Building your own algorithm container for Causal Inference




With Amazon SageMaker, you can package your own algorithms that can then be trained and deployed in the SageMaker environment. This notebook guides you through an example that shows how to build a Docker container for SageMaker that hosts a causal model, and how you can use it for training, inference, and interventions on the model.

This example shows how to build a container for the CausalNex causal inference library, using the building your own algorithm container tutorial as a base. We are going to use the Conda Python 3 kernel in this notebook.

Permissions

Running this notebook requires permissions in addition to the normal SageMakerFullAccess permissions, because we will be creating new repositories in Amazon ECR. The easiest way to add these permissions is to attach the managed policy AmazonEC2ContainerRegistryFullAccess to the role that you used to start your notebook instance. There’s no need to restart your notebook instance when you do this; the new permissions will be available immediately.
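
If you prefer to attach the policy programmatically instead of through the IAM console, a minimal sketch with boto3 is shown below. The role name is a placeholder, and the credentials running this call need iam:AttachRolePolicy permission (which the notebook role itself may not have):

[ ]:
import boto3

iam = boto3.client("iam")

# Placeholder: replace with the name of the execution role attached to your notebook instance.
notebook_role_name = "MySageMakerNotebookRole"

# Attach the managed ECR policy so the role can create and push to ECR repositories.
iam.attach_role_policy(
    RoleName=notebook_role_name,
    PolicyArn="arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess",
)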

Installing prerequisites

[ ]:
!pip install causalnex

The parts of the sample container

In the container directory are all the components you need to package the sample algorithm for Amazon SageMaker:

.
|-- Dockerfile
|-- build_and_push.sh
`-- causal_nex
    |-- nginx.conf
    |-- predictor.py
    |-- serve
    |-- train
    `-- wsgi.py
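
The files under causal_nex/ follow SageMaker's standard bring-your-own-container layout: train is run for training jobs, serve starts the inference stack (nginx in front of gunicorn), wsgi.py wraps the Flask app, and predictor.py implements the HTTP contract that SageMaker hosting expects: a GET /ping health check and a POST /invocations prediction route. As a rough, hypothetical sketch of that contract (not the actual predictor.py in this repository):

[ ]:
import json

import flask

app = flask.Flask(__name__)


@app.route("/ping", methods=["GET"])
def ping():
    # SageMaker calls /ping to verify that the container is up and healthy.
    return flask.Response(response="\n", status=200, mimetype="application/json")


@app.route("/invocations", methods=["POST"])
def invocations():
    # SageMaker forwards InvokeEndpoint request bodies here.
    payload = json.loads(flask.request.data.decode("utf-8"))
    # Placeholder response; the real predictor.py runs the prediction or intervention logic.
    result = {"received_keys": sorted(payload.keys())}
    return flask.Response(response=json.dumps(result), status=200, mimetype="application/json")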

The Dockerfile

[ ]:
!cat container/Dockerfile

Building and registering the container

The following shell code shows how to build the container image using docker build and push it to ECR using docker push. This code is also available as the shell script container/build_and_push.sh, which you can run as build_and_push.sh causal-nex-container to build the image causal-nex-container.

[ ]:
%%sh

# The name of our algorithm
algorithm_name=causal-nex-container

cd container

chmod +x causal_nex/train
chmod +x causal_nex/serve

account=$(aws sts get-caller-identity --query Account --output text)

# Get the region defined in the current configuration
region=$(aws configure get region)
# Uncomment to fall back to a default region if none is defined:
# region=${region:-eu-west-1}

fullname="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:latest"

# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${algorithm_name}" > /dev/null 2>&1

if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${algorithm_name}" > /dev/null
fi

# Get the login command from ECR and execute it directly
$(aws ecr get-login --region $region --registry-ids $account --no-include-email)
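# Note: `aws ecr get-login` was removed in AWS CLI v2. If the command above fails, the
# v2 equivalent is roughly:
#   aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin ${account}.dkr.ecr.${region}.amazonaws.com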

# Build the docker image locally with the image name and then push it to ECR
# with the full name.

docker build -t ${algorithm_name} .
docker tag ${algorithm_name} ${fullname}

docker push ${fullname}

Using CausalNex in Amazon SageMaker

Once you have your container packaged, you can use it to train models and use the resulting model for hosting or batch transforms. Let’s do that with the algorithm we made above. In addition, we can perform model interventions, a common feature of causal models.

Set up the environment and create the session

Here we specify a bucket to use and the role that will be used for working with SageMaker. The session remembers our connection parameters to SageMaker.

[ ]:
import boto3
import re

import os
import numpy as np
import pandas as pd
from sagemaker import get_execution_role

import sagemaker as sage
from time import gmtime, strftime

role = get_execution_role()
sess = sage.Session()
region = boto3.Session().region_name
s3_client = boto3.client("s3")

Upload the data for training

When training large models with huge amounts of data, you’ll typically use big data tools, like Amazon Athena, AWS Glue, or Amazon EMR, to create your data in S3. For the purposes of this example, we are using a heart failure dataset of 299 patients.

Davide Chicco, Giuseppe Jurman: “Machine learning can predict survival of patients with heart failure from serum creatinine and ejection fraction alone”. BMC Medical Informatics and Decision Making 20, 16 (2020). Web Link

Let’s download it from the public bucket and then upload it to our default SageMaker bucket:

[ ]:
!mkdir -p data

# S3 bucket where the training data is located.
data_bucket = "sagemaker-sample-files"
data_prefix = "datasets/tabular/uci_heart_failure/"
data_bucket_path = f"s3://{data_bucket}"

# S3 prefix for saving code and model artifacts.
prefix = "DEMO-causal-nex"
WORK_DIRECTORY = "data/"

s3_client.download_file(
    data_bucket,
    data_prefix + "heart_failure_clinical_records_dataset.csv",
    WORK_DIRECTORY + "heart_failure_clinical_records_dataset.csv",
)
data_location = sess.upload_data(WORK_DIRECTORY, key_prefix=prefix)

Intro to causal modeling

Causal models are mathematical models representing causal relationships. They facilitate inferences about causal relationships from statistical data, and they can teach us a good deal about causation and about the relationship between causation and probability. We will walk through how to modify the training script, which is located in container/causal_nex/train. Let’s take a look at it in detail:

[ ]:
!pygmentize -g container/causal_nex/train

For demo purposes, we use a dataset of 299 patients with heart failure collected in 2015 that contains thirteen clinical features:

  • age: age of the patient (years)

  • anaemia: decrease of red blood cells or hemoglobin (boolean)

  • high blood pressure: if the patient has hypertension (boolean)

  • creatinine phosphokinase (CPK): level of the CPK enzyme in the blood (mcg/L)

  • diabetes: if the patient has diabetes (boolean)

  • ejection fraction: percentage of blood leaving the heart at each contraction (percentage)

  • platelets: platelets in the blood (kilo platelets/mL)

  • sex: woman or man (binary)

  • serum creatinine: level of serum creatinine in the blood (mg/dL)

  • serum sodium: level of serum sodium in the blood (mEq/L)

  • smoking: if the patient smokes or not (boolean)

  • time: follow-up period (days)

  • [target] death event: if the patient deceased during the follow-up period (boolean)

In this paper, the authors identify the two most important features: serum creatinine and ejection fraction.

Bayesian Networks in CausalNex support only discrete distributions. Any continuous features, or features with a large number of categories, should be discretized prior to fitting the Bayesian Network. Models containing variables with many possible values will typically be badly fit, and exhibit poor performance.

As a first step, let’s discretize the data. CausalNex provides a helper class causalnex.discretiser.Discretiser, which supports several discretization methods. For our data the fixed method will be applied, providing static values that define the bucket boundaries. The split points can be chosen, for example, from statistical quartiles (a type of quantile that divides the data points into four more-or-less equal parts) or from the quantitative summary statistics of the numeric features. Therefore, ejection_fraction will be discretized into four buckets defined by the split points 30, 38, and 42 (roughly < 30, 30 to 38, 38 to 42, and >= 42). Each bucket will be labelled with an integer starting from zero.

[ ]:
from causalnex.discretiser import Discretiser
import pandas as pd

initial_df = pd.read_csv(WORK_DIRECTORY + "/heart_failure_clinical_records_dataset.csv")

initial_df["age"] = Discretiser(method="fixed", numeric_split_points=[60]).transform(
    initial_df["age"].values
)
initial_df["serum_sodium"] = Discretiser(method="fixed", numeric_split_points=[136]).transform(
    initial_df["serum_sodium"].values
)
initial_df["serum_creatinine"] = Discretiser(
    method="fixed", numeric_split_points=[1.1, 1.4]
).transform(initial_df["serum_sodium"].values)

initial_df["ejection_fraction"] = Discretiser(
    method="fixed", numeric_split_points=[30, 38, 42]
).transform(initial_df["ejection_fraction"].values)

initial_df["creatinine_phosphokinase"] = Discretiser(
    method="fixed", numeric_split_points=[120, 540, 670]
).transform(initial_df["creatinine_phosphokinase"].values)

initial_df["platelets"] = Discretiser(method="fixed", numeric_split_points=[263358]).transform(
    initial_df["platelets"].values
)
initial_df.head()

We can manually define a structure model (SM), or it can be created by a domain expert, by specifying the relationships between different features. Defining an appropriate SM is key to a causal analysis. For example, in our case the defined relationship (“ejection_fraction”, “DEATH_EVENT”) can be read as: the “ejection_fraction” node causes “DEATH_EVENT”.

[ ]:
import networkx as nx

causal_graph = nx.DiGraph(
    [
        ("ejection_fraction", "DEATH_EVENT"),
        ("creatinine_phosphokinase", "DEATH_EVENT"),
        ("age", "DEATH_EVENT"),
        ("serum_sodium", "DEATH_EVENT"),
        ("high_blood_pressure", "DEATH_EVENT"),
        ("anaemia", "DEATH_EVENT"),
        ("creatinine_phosphokinase", "DEATH_EVENT"),
        ("smoking", "DEATH_EVENT"),
    ]
)

We can visualize the statistical dependencies between these variables using a graph:

[ ]:
import matplotlib.pyplot as plt

nx.draw_networkx(causal_graph, arrows=True)
plt.show()

Create an estimator and fit the model

In order to use SageMaker to fit our algorithm, we’ll create an Estimator that defines how to use the container to train. This includes the configuration we need to invoke SageMaker training:

  • The container name. This is constructed as in the shell commands above.

  • The role. As defined above.

  • The instance count which is the number of machines to use for training.

  • The instance type which is the type of machine to use for training.

  • The output path determines where the model artifact will be written.

  • The session is the SageMaker session object that we defined above.

Then we use fit() on the estimator to train against the data that we uploaded above.

[ ]:
account = sess.boto_session.client("sts").get_caller_identity()["Account"]
region = sess.boto_session.region_name
image = "{}.dkr.ecr.{}.amazonaws.com/causal-nex-container:latest".format(account, region)

bn = sage.estimator.Estimator(
    image_uri=image,
    role=role,
    instance_count=1,
    instance_type="ml.c4.2xlarge",
    output_path="s3://{}/output".format(sess.default_bucket()),
    sagemaker_session=sess,
)

bn.fit(data_location)

Deploy the model

Deploying the model to SageMaker hosting just requires a deploy call on the fitted model. This call takes an instance count, instance type, and optionally serializer and deserializer functions. These are used when the resulting predictor is created on the endpoint.

[ ]:
from datetime import datetime

# to create unique endpoint
now = datetime.now()
dt_string = now.strftime("-%d-%m-%Y-%H-%M-%S")

endpoint_name = "causal-endpoint" + dt_string
predictor = bn.deploy(initial_instance_count=1, instance_type="ml.m5d.xlarge", endpoint_name=endpoint_name)

Likelihood Estimation and Predictions

Once the graph has been determined and the conditional probability distributions have been estimated (using maximum likelihood or Bayesian parameter estimation; in our case this is done by calling .fit()), they can be used to predict node states and the corresponding probabilities. A conditional probability is the chance that a random variable takes a specific value given that another random variable has already taken a value. For more details on how to use CausalNex, you can refer to this article.
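
The container's train script performs these steps for us, but here is a rough local sketch of what fitting and prediction look like with CausalNex, using the discretized initial_df and the same edge list defined above (the exact logic inside the container may differ):

[ ]:
from causalnex.network import BayesianNetwork
from causalnex.structure import StructureModel

# Rebuild the structure model from the edges defined earlier.
sm = StructureModel()
sm.add_edges_from(
    [
        ("ejection_fraction", "DEATH_EVENT"),
        ("creatinine_phosphokinase", "DEATH_EVENT"),
        ("age", "DEATH_EVENT"),
        ("serum_sodium", "DEATH_EVENT"),
        ("high_blood_pressure", "DEATH_EVENT"),
        ("anaemia", "DEATH_EVENT"),
        ("smoking", "DEATH_EVENT"),
    ]
)

# Use only the columns that appear in the structure model.
train_df = initial_df[list(sm.nodes)]

# Learn node states and conditional probability distributions from the discretized data.
bn_local = BayesianNetwork(sm)
bn_local = bn_local.fit_node_states(train_df)
bn_local = bn_local.fit_cpds(train_df, method="BayesianEstimator", bayes_prior="K2")

# Predict the most likely DEATH_EVENT state and its probability for each row.
print(bn_local.predict(train_df, "DEATH_EVENT").head())
print(bn_local.predict_probability(train_df, "DEATH_EVENT").head())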

Choose some data and use it for a prediction

In order to make some predictions, we’ll create a list with 2 test cases - feel free to add more! The variables that differ between the two cases are age and smoking, so we can see how they impact the survival outcome (remember that we discretized the variables earlier).

[ ]:
import boto3
import json

client = boto3.client("sagemaker-runtime")

test_cases = [
    {
        "age": 1,
        "anaemia": 0,
        "creatinine_phosphokinase": 2,
        "diabetes": 0,
        "ejection_fraction": 0,
        "high_blood_pressure": 1,
        "platelets": 1,
        "serum_creatinine": 0,
        "serum_sodium": 0,
        "sex": 1,
        "smoking": 0,
        "time": 4,
    },
    {
        "age": 0,
        "anaemia": 0,
        "creatinine_phosphokinase": 2,
        "diabetes": 0,
        "ejection_fraction": 0,
        "high_blood_pressure": 1,
        "platelets": 1,
        "serum_creatinine": 0,
        "serum_sodium": 0,
        "sex": 1,
        "smoking": 1,
        "time": 4,
    },
]

target_node = "DEATH_EVENT"
payload = json.dumps({"data": test_cases, "pred_type": "prediction", "target_node": target_node})

# invoke endpoint
response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/json", Body=payload
)

# decode the endpoint response
response_body = response["Body"]
response_str = response_body.read().decode("utf-8")

# print the prediction state
print("DEATH_EVENT Predictions:")
print(response_str)

Making interventions

To explore the effect of actions on the target variable and examine the effect of such interventions, we can apply do-calculus to the Bayesian network. One of the goals of causal analysis is not only to understand exactly what causes a specific effect, but also to be able to intervene in the process and control the outcome.

Actually performing the intervention might be unfeasible or unethical; side-stepping actual interventions and still getting at causal effects is the whole point of this approach to causal inference. To read more about interventions, go here or here. For a deeper understanding, we recommend reading “The Book of Why” by Judea Pearl.

Some example questions that can be answered with causal analysis are:

  • Does treatment X help to cure the disease?

  • What happens if we change the type of part Y in the production line?

  • What is the effect of a new route on the revenue of item Z?

Let’s examine the effect of an intervention on the ejection_fraction node by changing its states.

[ ]:
import boto3

client = boto3.client("sagemaker-runtime")
node = "ejection_fraction"
target_node = "DEATH_EVENT"
# build the JSON payload for the intervention: the node to intervene on, its states, and the target node
payload = json.dumps(
    {
        "data": {
            "node": node,
            "states": [
                {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0},
                {0: 0.0, 1: 1.0, 2: 0.0, 3: 0.0},
                {0: 0.0, 1: 0.0, 2: 1.0, 3: 0.0},
                {0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0},
            ],
            "target_node": target_node,
        },
        "pred_type": "intervention",
    }
)


response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/json", Body=payload
)
# decode output
response_body = response["Body"]
response_str = response_body.read().decode("utf-8")


# show output
print(target_node, "prediction with intervention/s on", node, "node:\n", response_str)

We used the Discretiser to bin the ejection_fraction values into 4 buckets defined by the split points 30, 38, and 42 (roughly < 30, 30 to 38, 38 to 42, and >= 42), each labelled with an integer starting from zero. Therefore, “states”: [{0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0}] means that we want to examine whether the target_node “DEATH_EVENT” changes if we set (intervene on) the node ejection_fraction to be < 30.
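
For reference, a rough local equivalent of this intervention, using CausalNex's InferenceEngine and the bn_local model from the earlier sketch (the deployed container performs the analogous steps server-side), could look like:

[ ]:
from causalnex.inference import InferenceEngine

ie = InferenceEngine(bn_local)

# Marginal distribution of DEATH_EVENT before any intervention.
print("before:", ie.query()["DEATH_EVENT"])

# Force ejection_fraction into its lowest bucket (< 30) with probability 1.
ie.do_intervention("ejection_fraction", {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0})
print("after do(ejection_fraction=0):", ie.query()["DEATH_EVENT"])

# Remove the intervention to restore the original distribution.
ie.reset_do("ejection_fraction")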

Optional cleanup

When you’re done with the endpoint, you’ll want to clean it up.

[ ]:
predictor.delete_model()
predictor.delete_endpoint()
