Explaining Image Classification with SageMaker Clarify




  1. Overview

  2. Train and Deploy Image Classifier

    1. Permissions and environment variables

    2. Fine-tuning the Image classification model

    3. Training

    4. Input data specification

    5. Start the training

    6. Deploy SageMaker model

    7. List of object categories

  3. Amazon SageMaker Clarify

    1. Test Images

    2. Set up config objects

    3. SageMaker Clarify Processor

    4. Reading Results

  4. Clean Up

Overview

Amazon SageMaker Clarify helps you gain insight into your computer vision models. For each input image, Clarify generates a heat map that highlights feature importance, helping you understand the model's behavior. For computer vision, Clarify supports both image classification and object detection use cases. This notebook can be run inside SageMaker Studio with the conda_pytorch_latest_py36 kernel and inside a SageMaker notebook instance with the Python 3 (PyTorch 1.8 Python 3.6 GPU Optimized) kernel. This sample notebook walks you through:

  1. Key terms and concepts needed to understand SageMaker Clarify.

  2. Explaining the importance of image features (superpixels) for an image classification model.

  3. Accessing the reports and output images.

In doing so, the notebook will first train and deploy an image classification model with a SageMaker Estimator using the Caltech-256 dataset [1], then use SageMaker Clarify to run explainability on a subset of test images.

[1] Griffin, G., Holub, A. D., and Perona, P. The Caltech 256. Caltech Technical Report.

Let’s start by installing the latest versions of the SageMaker Python SDK, botocore, boto3, and the AWS CLI.

[ ]:
! pip install sagemaker botocore boto3 awscli --upgrade

Train and Deploy Image Classifier

Let’s first train and deploy an Image Classification model to SageMaker.

Permissions and environment variables

Here we set up the linkage and authentication to AWS services. There are three parts to this:

  • The role used to give training and hosting access to your data. This is automatically obtained from the role used to start the notebook.

  • The S3 bucket that you want to use for training and model data.

  • The Amazon SageMaker image classification Docker image, which need not be changed.

[ ]:
import boto3
import sagemaker
from sagemaker import get_execution_role

role = get_execution_role()
print(role)

region = boto3.Session().region_name

s3_client = boto3.client("s3")

sess = sagemaker.Session()

output_bucket = sess.default_bucket()
output_prefix = "ic-transfer-learning"

# download the files
s3_client.download_file(
    f"sagemaker-example-files-prod-{region}",
    "datasets/image/caltech-256/caltech-256-60-train.rec",
    "caltech-256-60-train.rec",
)
s3_client.download_file(
    f"sagemaker-example-files-prod-{region}",
    "datasets/image/caltech-256/caltech-256-60-val.rec",
    "caltech-256-60-val.rec",
)

s3_client.upload_file(
    "caltech-256-60-train.rec", output_bucket, output_prefix + "/train_rec/caltech-256-60-train.rec"
)

s3_client.upload_file(
    "caltech-256-60-val.rec",
    output_bucket,
    output_prefix + "/validation_rec/caltech-256-60-val.rec",
)
[ ]:
from sagemaker import image_uris

training_image = image_uris.retrieve(
    "image-classification", sess.boto_region_name, version="latest"
)

print(training_image)

Fine-tuning the Image classification model

The Caltech-256 dataset consists of images from 257 categories (the last one being a clutter category) and has about 30k images, with a minimum of 80 and a maximum of about 800 images per category.

The image classification algorithm can take two types of input formats: RecordIO and lst. Files for both formats are available at http://data.dmlc.ml/mxnet/data/caltech-256/. In this example, we use the RecordIO format for training, with the training/validation split that we downloaded and uploaded to S3 above.
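For reference, here is a minimal sketch of the lst format (the indices, labels, and paths below are hypothetical, for illustration only): each line holds an image index, a numeric class label, and a relative image path, separated by tabs.

[ ]:
# Hypothetical example of the lst format consumed by the image classification
# algorithm: image_index <TAB> class_label <TAB> relative_image_path
lst_lines = [
    "0\t0\t001.ak47/001_0001.jpg",
    "1\t0\t001.ak47/001_0002.jpg",
    "2\t1\t002.american-flag/002_0001.jpg",
]
with open("example.lst", "w") as lst_file:
    lst_file.write("\n".join(lst_lines) + "\n")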

[ ]:
# Four channels: train, validation, train_lst, and validation_lst
s3train = f"s3://{output_bucket}/{output_prefix}/train_rec/"
s3validation = f"s3://{output_bucket}/{output_prefix}/validation_rec/"

Training

Now that we are done with all the setup that is needed, we are ready to train our image classifier. To begin, let us create a sagemaker.estimator.Estimator object, which will launch the training job. The following parameters need to be set for the training job:

  • instance_count: The number of instances on which to run the training. When the number of instances is greater than one, the image classification algorithm runs in a distributed setting.

  • instance_type: The type of machine on which to run the training. Typically, we use GPU instances for such training jobs.

  • output_path: The S3 folder in which the training output is stored.

[ ]:
s3_output_location = f"s3://{output_bucket}/{output_prefix}/output"
ic_estimator = sagemaker.estimator.Estimator(
    training_image,
    role,
    instance_count=1,
    instance_type="ml.p2.xlarge",
    volume_size=50,
    max_run=360000,
    input_mode="File",
    output_path=s3_output_location,
    sagemaker_session=sess,
)

Apart from the above set of parameters, there are hyperparameters that are specific to the algorithm. These are:

  • num_layers: The number of layers (depth) for the network. We use 18 for this training but other values such as 50, 152 can also be used.

  • use_pretrained_model: Set to 1 to use pretrained model for transfer learning.

  • image_shape: The input image dimensions, ‘num_channels, height, width’, for the network. It should be no larger than the actual image size. The number of channels should be the same as in the actual images.

  • num_classes: This is the number of output classes for the new dataset. ImageNet was trained with 1000 output classes but the number of output classes can be changed for fine-tuning. For caltech, we use 257 because it has 256 object categories + 1 clutter class.

  • num_training_samples: The total number of training samples. It is set to 15420 for the Caltech-256 dataset with the current split (257 classes × 60 training images per class).

  • mini_batch_size: The number of training samples used for each mini batch. In distributed training, the number of training samples used per batch will be N * mini_batch_size where N is the number of hosts on which training is run.

  • epochs: Number of training epochs.

  • learning_rate: Learning rate for training.

  • precision_dtype: Training datatype precision (default: float32). If set to ‘float16’, the training will be done in mixed_precision mode and will be faster than float32 mode.

[ ]:
ic_estimator.set_hyperparameters(
    num_layers=18,
    use_pretrained_model=1,
    image_shape="3,224,224",
    num_classes=257,
    num_training_samples=15420,
    mini_batch_size=128,
    epochs=2,
    learning_rate=0.01,
    precision_dtype="float32",
)

Input data specification

Set the data type and channels used for training.

[ ]:
train_data = sagemaker.inputs.TrainingInput(
    s3train,
    distribution="FullyReplicated",
    content_type="application/x-recordio",
    s3_data_type="S3Prefix",
)
validation_data = sagemaker.inputs.TrainingInput(
    s3validation,
    distribution="FullyReplicated",
    content_type="application/x-recordio",
    s3_data_type="S3Prefix",
)

data_channels = {"train": train_data, "validation": validation_data}

Start the training

Start training by calling the fit method in the estimator.

[ ]:
ic_estimator.fit(inputs=data_channels, logs=True)

Deploy SageMaker model

Once trained, we use the estimator to create a model in SageMaker. Clarify will use this model to spin up a shadow endpoint and run inference on the test images.

[ ]:
from time import gmtime, strftime

timestamp_suffix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())

model_name = "DEMO-clarify-image-classification-model-{}".format(timestamp_suffix)
model = ic_estimator.create_model(name=model_name)
container_def = model.prepare_container_def()
sess.create_model(model_name, role, container_def)

List of object categories

[ ]:
with open("caltech_256_object_categories.txt", "r+") as object_categories_file:
    object_categories = [category.rstrip("\n") for category in object_categories_file.readlines()]

# Let's list top 10 entries from the object_categories list
object_categories[:10]
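As a quick sanity check of how these names line up with model output, note that the image classification algorithm returns one probability per class, so the predicted category is simply the argmax index into object_categories. A minimal sketch follows, using a made-up probability vector since we are not invoking the endpoint here.

[ ]:
import numpy as np

# Hypothetical stand-in for a real model response: one probability per class
probabilities = np.random.dirichlet(np.ones(len(object_categories)))
predicted_category = object_categories[int(np.argmax(probabilities))]
print(predicted_category)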

Amazon SageMaker Clarify

Now that we have our image classification model set up, let’s get started with SageMaker Clarify!

Test Images

We need some test images to explain the predictions made by the image classification model using Clarify. Let’s grab a few test images from the Caltech-256 dataset and upload them to our S3 bucket.

[ ]:
prefix = "sagemaker/DEMO-sagemaker-clarify-cv"
file_name_map = {
    "167.pyramid/167_0002.jpg": "pyramid.jpg",
    "038.chimp/038_0013.jpg": "chimp.jpg",
    "124.killer-whale/124_0013.jpg": "killer-whale.jpg",
    "170.rainbow/170_0001.jpg": "rainbow.jpg",
}


# Download each test image from the example-files bucket, then upload it to our bucket
for source_key, file_name in file_name_map.items():
    s3_client.download_file(
        f"sagemaker-example-files-prod-{region}",
        f"datasets/image/caltech-256/256_ObjectCategories/{source_key}",
        file_name,
    )
    s3_client.upload_file(file_name, output_bucket, f"{prefix}/{file_name}")

Set up config objects

Now we set up the config objects required for running the Clarify job:

  • explainability_data_config: Config object for the input and output dataset.

  • model_config: Config object for the model and the endpoint to be created.

    • content_type: Specifies the type of input expected by the model.

  • predictions_config: Config object to extract the predicted label from the model output.

    • label_headers: The list of all the classes on which the model was trained.

  • image_config: Config object for the image data type.

    • model_type: Specifies the type of CV model (IMAGE_CLASSIFICATION | OBJECT_DETECTION).

    • num_segments: Clarify uses scikit-image’s SLIC method for image segmentation to generate features/superpixels. num_segments specifies the approximate number of segments to generate.

    • segment_compactness: Balances color proximity and space proximity. Higher values give more weight to space proximity, making superpixel shapes more square/cubic. We recommend exploring possible values on a log scale, e.g., 0.01, 0.1, 1, 10, 100, before refining around a chosen value.

  • shap_config: Config object for the Kernel SHAP parameters.

    • num_samples: Total number of feature coalitions to be tested by Kernel SHAP.

[ ]:
from sagemaker import clarify

s3_data_input_path = "s3://{}/{}/".format(output_bucket, prefix)
clarify_output_prefix = f"{prefix}/cv_analysis_result"
analysis_result_path = "s3://{}/{}".format(output_bucket, clarify_output_prefix)
explainability_data_config = clarify.DataConfig(
    s3_data_input_path=s3_data_input_path,
    s3_output_path=analysis_result_path,
    dataset_type="application/x-image",
)

model_config = clarify.ModelConfig(
    model_name=model_name, instance_type="ml.m5.xlarge", instance_count=1, content_type="image/jpeg"
)

predictions_config = clarify.ModelPredictedLabelConfig(label_headers=object_categories)

image_config = clarify.ImageConfig(
    model_type="IMAGE_CLASSIFICATION", num_segments=20, segment_compactness=5
)

shap_config = clarify.SHAPConfig(num_samples=500, image_config=image_config)
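To build intuition for num_segments and segment_compactness, here is a minimal local sketch of SLIC segmentation using scikit-image directly. It is illustrative only and not part of the Clarify job; it assumes scikit-image and matplotlib are available in the kernel, and reuses the chimp.jpg file downloaded above.

[ ]:
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.segmentation import mark_boundaries, slic

# Segment the image into roughly 20 superpixels, mirroring the ImageConfig above
image = imread("chimp.jpg")
segments = slic(image, n_segments=20, compactness=5, start_label=1)

# Overlay the superpixel boundaries on the original image
plt.imshow(mark_boundaries(image, segments))
plt.axis("off")
plt.show()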

SageMaker Clarify Processor

Let’s get the execution role for running SageMakerClarifyProcessor.

[ ]:
import os

account_id = os.getenv("AWS_ACCOUNT_ID", "<your-account-id>")
sagemaker_iam_role = "<AmazonSageMaker-ExecutionRole>"

# Fetch the IAM role to initialize the sagemaker processing job
try:
    role = sagemaker.get_execution_role()
except ValueError as e:
    print(e)
    role = f"arn:aws:iam::{account_id}:role/{sagemaker_iam_role}"

clarify_processor = clarify.SageMakerClarifyProcessor(
    role=role, instance_count=1, instance_type="ml.m5.xlarge", sagemaker_session=sess
)

Finally, we run the explainability analysis with the Clarify processor.

[ ]:
clarify_processor.run_explainability(
    data_config=explainability_data_config,
    model_config=model_config,
    explainability_config=shap_config,
    model_scores=predictions_config,
)

Reading Results

Let’s download all the result images along with the PDF report.

[ ]:
%%time
output_objects = s3_client.list_objects(Bucket=output_bucket, Prefix=clarify_output_prefix)
result_images = []

for file_obj in output_objects["Contents"]:
    file_name = os.path.basename(file_obj["Key"])
    if os.path.splitext(file_name)[1] == ".jpeg":
        result_images.append(file_name)

    print(f"Downloading s3://{output_bucket}/{file_obj['Key']} ...")
    s3_client.download_file(output_bucket, file_obj["Key"], file_name)
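The downloaded files include the Clarify analysis report alongside the heatmap images. Assuming the job wrote the standard report.pdf into the output prefix (and the loop above downloaded it), one way to preview it inline is:

[ ]:
from IPython.display import IFrame

# Render the downloaded Clarify report inline; width/height are arbitrary
IFrame("report.pdf", width=900, height=600)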

Let’s visualize and understand the results. Each result image shows the segmented image and the heatmap.

  • Segments: Highlights the image segments.

  • Shades of blue: Represent positive Shapley values, indicating that the corresponding feature increases the overall confidence score.

  • Shades of red: Represent negative Shapley values, indicating that the corresponding feature decreases the overall confidence score.

[ ]:
from IPython.display import Image, display

for img in result_images:
    display(Image(img))

Clean Up

Finally, don’t forget to clean up the resources we set up and used for this demo!

[ ]:
%%time

# Delete the SageMaker model
model.delete_model()
