Explainability with SageMaker Clarify - Partial Dependence Plots (PDP)
This notebook’s CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.
Runtime
This notebook takes approximately 30 minutes to run.
Contents
Overview
Amazon SageMaker Clarify lets you gain insight into your model with Partial Dependence Plots (PDP). A PDP shows the marginal effect that features have on the predicted outcome of a machine learning model. Intuitively, you can interpret the partial dependence as the expected target response as a function of each input feature of interest.
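The intuition above is usually written as an average over the dataset. The formula below is the standard textbook estimate, not something taken from the Clarify documentation, and the symbols are introduced here for illustration only: \(\hat{f}\) is the trained model, \(x_S\) is the feature of interest, \(x_C^{(i)}\) holds the values of all other features in the \(i\)-th row, and \(n\) is the number of rows.

```latex
\[
\hat{f}_S(x_S) = \frac{1}{n} \sum_{i=1}^{n} \hat{f}\left(x_S, x_C^{(i)}\right)
\]
```

In words: fix the feature of interest at one value for every row, average the model's predictions, and repeat for each value on the horizontal axis of the plot.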
This sample notebook walks you through:
Key terms and concepts needed to understand SageMaker Clarify
Generating PDPs
Accessing the explainability report and viewing PDPs
In doing so, the notebook first trains a SageMaker XGBoost model using a training dataset, then uses SageMaker Clarify to explain the corresponding testing dataset in CSV format with PDP. In addition to CSV, SageMaker Clarify also supports analyzing datasets in SageMaker JSON Lines dense format.
This notebook focuses on model explainability with PDP. If you would like to learn more about model explainability with Kernel SHAP, please visit this notebook. You can find the detailed documentation of SageMaker Clarify at What Is Fairness and Model Explainability for Machine Learning Predictions and more demo notebooks at aws-sagemaker-examples GitHub repository.
Prerequisites and Data
Initialize SageMaker
[ ]:
from sagemaker import Session
session = Session()
bucket = session.default_bucket()
prefix = "sagemaker/DEMO-sagemaker-clarify-pdp"
region = session.boto_region_name
# Define IAM role
from sagemaker import get_execution_role
import pandas as pd
import os
import boto3
from datetime import datetime
import matplotlib.pyplot as plt
role = get_execution_role()
s3_client = boto3.client("s3")
Download data
Data Source: https://archive.ics.uci.edu/ml/machine-learning-databases/adult/
Let’s download the data from the UCI repository\(^{[1]}\) and save it in the local folder under the names adult.data and adult.test.
\(^{[1]}\)Dua Dheeru, and Efi Karra Taniskidou. “UCI Machine Learning Repository”. Irvine, CA: University of California, School of Information and Computer Science (2017).
[ ]:
adult_columns = [
    "Age",
    "Workclass",
    "fnlwgt",
    "Education",
    "Education-Num",
    "Marital Status",
    "Occupation",
    "Relationship",
    "Ethnic group",
    "Sex",
    "Capital Gain",
    "Capital Loss",
    "Hours per week",
    "Country",
    "Target",
]
if not os.path.isfile("adult.data"):
    s3_client.download_file(
        f"sagemaker-example-files-prod-{region}",
        "datasets/tabular/uci_adult/adult.data",
        "adult.data",
    )
    print("adult.data saved!")
else:
    print("adult.data already on disk.")
if not os.path.isfile("adult.test"):
    s3_client.download_file(
        f"sagemaker-example-files-prod-{region}",
        "datasets/tabular/uci_adult/adult.test",
        "adult.test",
    )
    print("adult.test saved!")
else:
    print("adult.test already on disk.")
Loading the data: Adult Dataset
From the UCI repository of machine learning datasets, this database contains 14 features concerning demographic characteristics of 45,222 rows once rows with missing values are dropped (30,162 for training and 15,060 for testing). The task is to predict whether a person has a yearly income that is more or less than $50,000.
Here are the features and their possible values:
Age: continuous.
Workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
Fnlwgt: continuous (the number of people the census takers believe that observation represents).
Education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
Education-num: continuous.
Marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
Occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
Relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
Ethnic group: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
Sex: Female, Male.
Note: this data is extracted from the 1994 Census and enforces a binary option on Sex
Capital-gain: continuous.
Capital-loss: continuous.
Hours-per-week: continuous.
Native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
[ ]:
training_data = pd.read_csv(
"adult.data", names=adult_columns, sep=r"\s*,\s*", engine="python", na_values="?"
).dropna()
testing_data = pd.read_csv(
"adult.test", names=adult_columns, sep=r"\s*,\s*", engine="python", na_values="?", skiprows=1
).dropna()
training_data.head()
Data inspection
Plotting histograms for the distribution of the different features is a good way to visualize the data. Let’s look at one feature that can be considered sensitive: the Sex of a census respondent. The first plot shows that there are fewer Female respondents overall. The second shows that the gap is even wider among positive outcomes, where they form about \(\frac{1}{7}\) of respondents.
[ ]:
%matplotlib inline
training_data["Sex"].value_counts().sort_values().plot(kind="bar", title="Counts of Sex", rot=0)
plt.show()
[ ]:
%matplotlib inline
training_data["Sex"].where(training_data["Target"] == ">50K").value_counts().sort_values().plot(
kind="bar", title="Counts of Sex earning >$50K", rot=0
)
plt.show()
Encode and Upload the Dataset
Here we encode the training and test data. Encoding input data is not necessary for SageMaker Clarify, but is necessary for the model.
[ ]:
from sklearn import preprocessing


def number_encode_features(df):
    result = df.copy()
    encoders = {}
    for column in result.columns:
        if result.dtypes[column] == object:
            encoders[column] = preprocessing.LabelEncoder()
            result[column] = encoders[column].fit_transform(result[column].fillna("None"))
        else:
            result[column] = result[column].astype("float")
    return result, encoders


training_data = pd.concat([training_data["Target"], training_data.drop(["Target"], axis=1)], axis=1)
training_data, _ = number_encode_features(training_data)
training_data.to_csv("train_data.csv", index=False, header=False)
testing_data, _ = number_encode_features(testing_data)
test_features = testing_data.drop(["Target"], axis=1)
testing_data.to_csv("test_data.csv", index=False, header=False)
A quick note about our encoding: the “Female” Sex value has been encoded as 0 and “Male” as 1.
[ ]:
training_data.head()
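To see why “Female” becomes 0 and “Male” becomes 1, it helps to recall that scikit-learn's LabelEncoder numbers the distinct values in sorted order. The sketch below emulates that behavior in plain Python. The helper name label_encode is hypothetical and is not part of the notebook or of scikit-learn.

```python
def label_encode(values):
    """Emulate label encoding: sort the distinct values, then number them from 0."""
    classes = sorted(set(values))
    mapping = {label: code for code, label in enumerate(classes)}
    return [mapping[value] for value in values], mapping


# Toy input, not taken from the dataset files.
codes, mapping = label_encode(["Male", "Female", "Female", "Male"])
print(mapping)
print(codes)
```

Because the codes depend only on the set of distinct values, encoders fitted separately on the training and testing data agree only when both splits contain the same categories in each column. This is worth checking before reusing the approach on other data.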
Lastly, let’s upload the data to S3.
[ ]:
from sagemaker.s3 import S3Uploader
from sagemaker.inputs import TrainingInput
train_uri = S3Uploader.upload("train_data.csv", "s3://{}/{}".format(bucket, prefix))
train_input = TrainingInput(train_uri, content_type="csv")
test_uri = S3Uploader.upload("test_data.csv", "s3://{}/{}".format(bucket, prefix))
Train XGBoost Model
Train Model
Since our focus is on understanding how to use SageMaker Clarify, we keep it simple by using a standard XGBoost model.
[ ]:
from sagemaker.image_uris import retrieve
from sagemaker.estimator import Estimator
container = retrieve("xgboost", region, version="1.5-1")
xgb = Estimator(
container,
role,
instance_count=1,
instance_type="ml.m5.xlarge",
disable_profiler=True,
sagemaker_session=session,
)
xgb.set_hyperparameters(
max_depth=5,
eta=0.2,
gamma=4,
min_child_weight=6,
subsample=0.8,
objective="binary:logistic",
num_round=800,
)
xgb.fit({"train": train_input}, logs=False)
Deploy Model
Here we create the SageMaker model.
[ ]:
model_name = "DEMO-clarify-xgb-model-{}".format(datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))
model = xgb.create_model(name=model_name)
container_def = model.prepare_container_def()
session.create_model(model_name, role, container_def)
Amazon SageMaker Clarify
Now that you have your model set up, let’s say hello to SageMaker Clarify!
[ ]:
from sagemaker import clarify
clarify_processor = clarify.SageMakerClarifyProcessor(
role=role, instance_count=1, instance_type="ml.m5.xlarge", sagemaker_session=session
)
Explaining Predictions with PDP
A PDP shows the marginal effect that features have on the predicted outcome of a model. SageMaker Clarify can generate PDPs for a user-specified list of features, or for the top k features with the largest SHAP values.
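To make the idea of a marginal effect concrete, here is a minimal sketch of how a partial dependence curve can be computed by hand. It is not how SageMaker Clarify implements the analysis. The function partial_dependence, the stand-in model toy_predict, and the toy rows are all assumptions made up for illustration.

```python
from statistics import mean


def partial_dependence(predict, rows, feature_index, grid_values):
    """Average prediction over all rows with one feature forced to each grid value."""
    curve = []
    for value in grid_values:
        # Copy every row, overwriting only the feature of interest.
        modified = [row[:feature_index] + [value] + row[feature_index + 1 :] for row in rows]
        curve.append(mean(predict(row) for row in modified))
    return curve


def toy_predict(row):
    """Hypothetical stand-in for a trained model: rises with feature 0, falls with feature 1."""
    return 0.1 * row[0] - 0.05 * row[1]


toy_rows = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]
curve_for_feature_0 = partial_dependence(toy_predict, toy_rows, feature_index=0, grid_values=[0.0, 10.0])
print(curve_for_feature_0)
```

Each point on the curve is the average prediction when the chosen feature takes a single value for every row, while all other features keep their observed values.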
Writing DataConfig and ModelConfig
A DataConfig object communicates some basic information about data I/O to SageMaker Clarify. We specify where to find the input dataset, where to store the output, the target column (label), the header names, and the dataset type.
[ ]:
pdp_explainability_output_path = "s3://{}/{}/clarify-explainability-pdp".format(bucket, prefix)
explainability_data_config = clarify.DataConfig(
s3_data_input_path=test_uri,
s3_output_path=pdp_explainability_output_path,
label="Target",
headers=testing_data.columns.to_list(),
dataset_type="text/csv",
)
A ModelConfig object communicates information about your trained model. To avoid additional traffic to your production models, SageMaker Clarify sets up and tears down a dedicated endpoint when processing.
instance_type and instance_count specify the instance type and the number of instances used to host your model during SageMaker Clarify’s processing. The testing dataset is small, so a single standard instance is good enough to run this example. If you have a large, complex dataset, you may want to use a more powerful instance type to speed up processing, or add more instances to enable Spark parallelization.
accept_type denotes the endpoint response payload format, and content_type denotes the payload format of the request to the endpoint.
[ ]:
model_config = clarify.ModelConfig(
model_name=model_name,
instance_type="ml.m5.xlarge",
instance_count=1,
accept_type="text/csv",
content_type="text/csv",
)
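As a rough illustration of what a text/csv request payload looks like, the sketch below serializes a couple of feature rows into header-less CSV text. The helper to_csv_payload and the sample values are hypothetical. How Clarify batches records into requests is not described in this notebook, so treat this only as a picture of the format.

```python
import csv
import io


def to_csv_payload(feature_rows):
    """Serialize feature rows as header-less CSV text, one record per line."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerows(feature_rows)
    return buffer.getvalue()


# Made-up rows with three numeric features each (no label column).
payload = to_csv_payload([[25.0, 2.0, 226802.0], [38.0, 2.0, 89814.0]])
print(payload)
```

With accept_type also set to text/csv, the endpoint is expected to answer in the same plain-text style, which for this XGBoost model means one probability per record.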
A ModelPredictedLabelConfig provides information on the format of your predictions. The XGBoost model outputs probabilities of samples, so SageMaker Clarify invokes the endpoint and then uses probability_threshold to convert each probability to a binary label for bias analysis. A prediction above the threshold is interpreted as label value 1, and a prediction below or equal to the threshold as label value 0.
[ ]:
predictions_config = clarify.ModelPredictedLabelConfig(probability_threshold=0.8)
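The thresholding rule described above can be sketched in a few lines. The helper to_binary_label and the sample scores are made up for illustration. Only the rule itself (strictly above the threshold gives 1, otherwise 0) comes from the text.

```python
def to_binary_label(probability, probability_threshold=0.8):
    """Return 1 when the probability is strictly above the threshold, else 0."""
    return 1 if probability > probability_threshold else 0


scores = [0.35, 0.8, 0.93]  # hypothetical model outputs
labels = [to_binary_label(score) for score in scores]
print(labels)
```

Note that a score exactly equal to the threshold maps to 0, so with a threshold of 0.8 only fairly confident predictions count as the positive class.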
Writing PDPConfig
A PDPConfig object specifies how the PDP analysis is done. There are two ways to plot PDPs with SageMaker Clarify, depending on the parameters specified:
If a list of features is specified, PDPs are plotted for each feature in the list.
If top_k_features is specified, feature importance is ranked based on SHAP values and PDPs are plotted for the top k features. If not specified, the default number of features is 10.
Let’s look at the standalone PDP analysis first. To generate PDPs without Kernel SHAP, we need to specify the features we want PDP analysis done on as a list of feature headers. We can also specify an optional grid_resolution parameter, which indicates the number of evenly spaced buckets a numerical feature is grouped into across its range of values. The feature values used to plot the PDP are the midpoints of each bucket. If not specified, the default grid_resolution is 20. For categorical features, each unique feature value is used in the PDP.
[ ]:
pdp_config = clarify.PDPConfig(
features=["Sex", "Age", "Education-Num"],
grid_resolution=15,
)
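The grid_resolution description above (evenly spaced buckets, plotted at their midpoints) can be sketched as follows. This is a plain reading of that description, not Clarify's actual code. The helper bucket_midpoints and the Age range of 17 to 90 are assumptions for illustration.

```python
def bucket_midpoints(min_value, max_value, grid_resolution=20):
    """Split [min_value, max_value] into equal-width buckets and return each midpoint."""
    width = (max_value - min_value) / grid_resolution
    return [min_value + (i + 0.5) * width for i in range(grid_resolution)]


# Assumed Age range, using the same grid_resolution as the PDPConfig above.
age_grid = bucket_midpoints(17, 90, grid_resolution=15)
print(len(age_grid), age_grid[0], age_grid[-1])
```

A higher grid_resolution gives a smoother curve at the cost of more model invocations, since every grid value requires predictions over the dataset.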
Now we can run the analysis with the above inputs. The command in the cell below creates a SageMaker Processing job and may take around 10 minutes to complete.
[ ]:
clarify_processor.run_explainability(
data_config=explainability_data_config,
model_config=model_config,
explainability_config=pdp_config,
model_scores=predictions_config,
logs=False,
)
View the Explainability Report
You can access the explainability report in PDF, HTML, and ipynb formats at the following S3 location:
[ ]:
pdp_explainability_output_path
For example, you can download a copy of the HTML report and view it in place here.
[ ]:
s3_client.download_file(
bucket, prefix + "/clarify-explainability-pdp/report.html", "explainability-report.html"
)
[ ]:
import IPython
IPython.display.HTML(filename="explainability-report.html")
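The download call above rebuilds the object key from prefix by hand. As an alternative sketch, the bucket and key can be derived from the output path itself. The helper split_s3_uri and the example bucket name are hypothetical. Only the report.html file name comes from the cell above.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split an s3://bucket/key URI into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3":
        raise ValueError(f"Not an S3 URI: {uri}")
    return parsed.netloc, parsed.path.lstrip("/")


# Placeholder path with the same shape as pdp_explainability_output_path.
example_output_path = "s3://example-bucket/sagemaker/DEMO-sagemaker-clarify-pdp/clarify-explainability-pdp"
bucket_name, key_prefix = split_s3_uri(example_output_path)
report_key = key_prefix + "/report.html"
print(bucket_name, report_key)
```

The resulting pair could then be passed to s3_client.download_file in place of the hand-built arguments, which keeps the download in sync if the output path changes.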
PDP Analysis with SHAP
Instead of specifying the features we want to run PDP analysis on, we can use SHAP values to rank feature importance and let top_k_features select the features to plot. The SHAP value of an input feature indicates how much that feature contributes to the model prediction.
[ ]:
pdp_config = clarify.PDPConfig(
top_k_features=5,
grid_resolution=25,
)
The SHAPConfig object communicates the information required by the Kernel SHAP algorithm to compute SHAP values.
A baseline dataset is required by the Kernel SHAP algorithm. If a baseline is not provided, SageMaker Clarify computes it automatically as the cluster centers of the dataset, using K-means or K-prototypes. The baseline dataset type must be the same as the dataset_type of DataConfig, and baseline samples must include only features. The baseline should be either an S3 URI to the baseline dataset file or an in-place list of samples. In this case we chose the latter, and put the first sample of the test dataset in the list. For more information on the baseline dataset, see SHAP Baselines for Explainability.
num_samples determines the size of the generated synthetic dataset used to compute the SHAP values. If it is not provided, the Clarify job chooses a suitable value according to the number of features.
agg_method specifies how the per-instance feature importance should be aggregated over the dataset to compute the overall (global) feature importance. For more details see the Amazon AI Fairness and Explainability Whitepaper.
[ ]:
shap_config = clarify.SHAPConfig(
baseline=[test_features.iloc[0].values.tolist()],
num_samples=15,
agg_method="mean_abs",
)
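To illustrate how per-instance SHAP values become a global ranking, here is a minimal sketch of aggregation followed by top-k selection. The method names mean_abs, median, and mean_sq are the agg_method choices offered by Clarify. The formulas are a plain reading of those names rather than Clarify's source code, and the helpers and toy SHAP values are made up for illustration.

```python
from statistics import mean, median

# One aggregation function per agg_method name.
AGGREGATORS = {
    "mean_abs": lambda values: mean(abs(v) for v in values),
    "median": median,
    "mean_sq": lambda values: mean(v * v for v in values),
}


def global_importance(local_shap, feature_names, agg_method="mean_abs"):
    """Aggregate per-row SHAP values (rows x features) into one score per feature."""
    aggregate = AGGREGATORS[agg_method]
    columns = zip(*local_shap)  # one tuple of per-row values per feature
    return {name: aggregate(list(column)) for name, column in zip(feature_names, columns)}


def top_k(importance, k):
    """Return the k feature names with the largest global importance."""
    return sorted(importance, key=importance.get, reverse=True)[:k]


# Toy SHAP values for two rows and three features.
toy_shap = [
    [0.5, -1.0, 0.0],
    [-0.5, 2.0, 0.25],
]
importance = global_importance(toy_shap, ["Age", "Sex", "Education-Num"])
print(top_k(importance, k=2))
```

With mean_abs, positive and negative contributions do not cancel out, which is why it is a common choice when the goal is to rank features by overall influence rather than direction.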
[ ]:
pdp_with_shap_explainability_output_path = "s3://{}/{}/clarify-explainability-pdp-with-shap".format(
bucket, prefix
)
explainability_data_config = clarify.DataConfig(
s3_data_input_path=test_uri,
s3_output_path=pdp_with_shap_explainability_output_path,
label="Target",
headers=testing_data.columns.to_list(),
dataset_type="text/csv",
)
Now we can run the analysis with the above inputs. The command in the cell below creates a SageMaker Processing job and may take around 15 minutes to complete.
[ ]:
clarify_processor.run_explainability(
data_config=explainability_data_config,
model_config=model_config,
explainability_config=[pdp_config, shap_config],
logs=False,
)
As with the previous explainability report, you can access the report in PDF, HTML, and ipynb formats at the following S3 location:
[ ]:
pdp_with_shap_explainability_output_path
Clean Up
Finally, don’t forget to clean up the resources we set up and used for this demo!
[ ]:
session.delete_model(model_name)
Notebook CI Test Results
This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.