Explaining Object Detection model with Amazon SageMaker Clarify

In this notebook, we deploy a pre-trained object detection model to show how you can use the Amazon SageMaker Clarify explainability features for computer vision, specifically for object detection models, including your own.

1. We first load a model from the Gluon model zoo locally in the notebook, then compress it and upload it to S3.

2. We then use the SageMaker MXNet Serving feature to deploy the model to a managed SageMaker endpoint, using the model artifact we previously uploaded to S3.

3. We query the endpoint and visualize the detection results.

4. We explain the predictions of the model using Amazon SageMaker Clarify.

This notebook can be run with the conda_python3 kernel.

More on Amazon SageMaker Clarify:

Amazon SageMaker Clarify helps improve your machine learning models by detecting potential bias and helping explain how these models make predictions. The fairness and explainability functionality provided by SageMaker Clarify takes a step towards enabling AWS customers to build trustworthy and understandable machine learning models. SageMaker Clarify provides tools to help you with the following tasks:

• Measure biases that can occur during each stage of the ML lifecycle (data collection, model training and tuning, and monitoring of ML models deployed for inference).

• Generate model governance reports targeting risk and compliance teams and external regulators.

• Provide explanations of the data, models, and monitoring used to assess predictions for inputs of various modalities, such as numerical data, categorical data, text, and images.

Learn more about SageMaker Clarify here: https://aws.amazon.com/sagemaker/clarify/.

More on Gluon and Gluon CV:

• Gluon is the imperative Python front end of the Apache MXNet deep learning framework. Gluon notably features specialized toolkits that help reproduce state-of-the-art architectures: Gluon-CV, Gluon-NLP, and Gluon-TS. Gluon also has a number of excellent end-to-end tutorials mixing science with code, such as D2L.ai and The Straight Dope.

• Gluon-CV is an efficient computer vision toolkit written on top of Gluon and MXNet aiming to make state-of-the-art vision research reproducible.

This sample is provided for demonstration purposes. Make sure to conduct appropriate testing if you derive code from it for your own use cases!

Index:

1. Test a pre-trained detection model, locally

2. Instantiate model

3. Create endpoint and get predictions (optional)

4. Run Clarify and interpret predictions

[ ]:

! pip install -r requirements.txt


Let’s start by installing the latest versions of the SageMaker Python SDK, boto3, botocore, and the AWS CLI.

[ ]:

! pip install sagemaker botocore boto3 awscli --upgrade

[ ]:

import datetime
import json
import math
import os
import shutil
from subprocess import check_call
import tarfile

from PIL import Image
import numpy as np
from matplotlib import pyplot as plt

import boto3
import botocore

import sagemaker
from sagemaker import get_execution_role
from sagemaker.mxnet.model import MXNetModel

import gluoncv
from gluoncv import model_zoo, data, utils
import mxnet as mx
from mxnet import gluon, image, nd

[ ]:

sm_sess = sagemaker.Session()
sm_client = boto3.client("sagemaker")

# We use this bucket to store model weights - don't hesitate to change it.
s3_bucket = sm_sess.default_bucket()
print(f"using bucket {s3_bucket}")

# When running on a SageMaker notebook instance
sm_role = sagemaker.get_execution_role()
# Override the role if you are executing locally:
# sm_role = "arn:aws:iam::<account>:role/service-role/AmazonSageMaker-ExecutionRole"


Constants

[ ]:

TEST_IMAGE_DIR = "caltech"  # directory with test images
MODEL_NAME = "yolo3_darknet53_coco"
S3_KEY_PREFIX = "clarify_object_detection"  # S3 Key to store model artifacts
ENDPOINT_INSTANCE_TYPE = "ml.g4dn.xlarge"
ANALYZER_INSTANCE_TYPE = "ml.c5.xlarge"
ANALYZER_INSTANCE_COUNT = 1

[ ]:

def gen_unique_name(model_name: str) -> str:
    """Generate a unique name for this user / host combination."""
    import getpass
    import hashlib
    import socket

    user = getpass.getuser()
    host = socket.gethostname()
    h = hashlib.sha256()
    h.update(user.encode())
    h.update(host.encode())
    res = model_name + "-" + h.hexdigest()[:8]
    # SageMaker resource names may not contain underscores or dots
    res = res.replace("_", "-").replace(".", "")
    return res


Test a pre-trained detection model, locally

The Gluon model zoo contains a variety of models. In this demo we use a YoloV3 detection model (Redmon and Farhadi). More about YoloV3:

• Paper: https://pjreddie.com/media/files/papers/YOLOv3.pdf

• Website: https://pjreddie.com/darknet/yolo/

The Gluon CV model zoo contains a number of architectures with different trade-offs between speed and accuracy. If you need more speed or more accuracy, don’t hesitate to change the model.

[ ]:

net = model_zoo.get_model(MODEL_NAME, pretrained=True)


The model we downloaded above is trained on the COCO dataset and can detect 80 classes. In this demo, we restrict the model to detect only specific classes of interest. This idea is derived from the official Gluon CV tutorial: https://gluon-cv.mxnet.io/build/examples_detection/skip_fintune.html

COCO contains the following classes:

[ ]:

print("coco classes: ", sorted(net.classes))

[ ]:

# in this demo we reset the detector to the following classes
classes = ["dog", "elephant", "zebra", "bear"]
net.reset_class(classes=classes, reuse_weights=classes)
print("new classes: ", net.classes)
net.hybridize()  # hybridize to optimize computation


Download RGB images from the Caltech 256 dataset [Griffin, G., Holub, A.D., Perona, P. The Caltech 256. Caltech Technical Report.]

[ ]:

import os
import urllib.request

list_of_images = [
    "009.bear/009_0001.jpg",
    "009.bear/009_0002.jpg",
    "056.dog/056_0023.jpg",
    "056.dog/056_0001.jpg",
    "064.elephant-101/064_0003.jpg",
    "064.elephant-101/064_0004.jpg",
    "064.elephant-101/064_0006.jpg",
    "250.zebra/250_0001.jpg",
    "250.zebra/250_0002.jpg",
]

source_url = "https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/caltech-256/256_ObjectCategories/"

os.makedirs(TEST_IMAGE_DIR, exist_ok=True)

for image_name in list_of_images:
    url = source_url + image_name
    # flatten "class/file.jpg" into "class_file.jpg" inside TEST_IMAGE_DIR
    file_name = os.path.join(TEST_IMAGE_DIR, image_name.replace("/", "_"))
    urllib.request.urlretrieve(url, file_name)


Test locally

[ ]:

import glob

test_images = glob.glob(f"{TEST_IMAGE_DIR}/*.jpg")
test_images


gluoncv comes with built-in pre-processing logic for popular detectors, including YoloV3:

https://gluon-cv.mxnet.io/_modules/gluoncv/data/transforms/presets/yolo.html

https://gluon-cv.mxnet.io/build/examples_detection/demo_yolo.html

Let’s see how the network computes detections for a single image. We first have to resize and reshape the image: it is loaded with the channels in the last dimension, while MXNet expects a shape of (batch_size, channels, height, width). The gluoncv transform_test preset takes care of both steps.
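To make the layout change concrete, here is a minimal NumPy sketch of what the preprocessing does to one image, assuming the ImageNet mean/std normalization that the GluonCV presets use by default. The function name to_nchw_batch is hypothetical, and the sketch skips the resizing step that transform_test also performs.

```python
import numpy as np

# ImageNet mean/std per RGB channel (assumed to match the GluonCV preset defaults)
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_nchw_batch(img_hwc: np.ndarray) -> np.ndarray:
    """Convert one H x W x 3 uint8 image into a normalized (1, 3, H, W) float32 batch."""
    x = img_hwc.astype(np.float32) / 255.0  # scale pixels to [0, 1]
    x = (x - MEAN) / STD                    # per-channel normalization, broadcast over H and W
    x = np.transpose(x, (2, 0, 1))          # HWC -> CHW
    return x[np.newaxis, ...]               # add the batch dimension -> NCHW


# Example: a dummy 240 x 320 RGB image
img = np.zeros((240, 320, 3), dtype=np.uint8)
batch = to_nchw_batch(img)
print(batch.shape)  # (1, 3, 240, 320)
```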

[ ]:

transformed_image, _ = data.transforms.presets.yolo.transform_test(image.imread(test_images[-1]))


The network returns 3 tensors: class IDs, scores, and bounding boxes. By default it returns up to 100 detections, so each tensor has shape (num_batches, num_detections, …), where the last dimension is 4 for the bounding boxes: the coordinates of the upper left and lower right corners.
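GluonCV detectors fill unused detection slots with -1. A minimal sketch, using a hypothetical helper keep_valid_detections on synthetic NumPy arrays, of how to keep only the real, confident detections for one image (with the MXNet outputs above you would first convert, for example cids[0].asnumpy()):

```python
import numpy as np


def keep_valid_detections(cids, scores, bboxes, thresh=0.5):
    """Drop padded slots (class id -1) and low-confidence detections for one image.

    cids and scores have shape (N, 1); bboxes has shape (N, 4) holding
    (xmin, ymin, xmax, ymax) pixel coordinates.
    """
    cids = cids.reshape(-1)
    scores = scores.reshape(-1)
    mask = (cids >= 0) & (scores >= thresh)
    return cids[mask].astype(int), scores[mask], bboxes[mask]


# Synthetic output for one image: two real detections and one padded slot
cids = np.array([[0.0], [3.0], [-1.0]])
scores = np.array([[0.95], [0.40], [-1.0]])
bboxes = np.array([[10, 20, 110, 220], [5, 5, 50, 50], [-1, -1, -1, -1]], dtype=float)
ids, sc, boxes = keep_valid_detections(cids, scores, bboxes, thresh=0.5)
print(ids, sc, boxes)
```

Only the first detection survives: the second is below the threshold and the third is padding.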

[ ]:

(cids, scores, bboxs) = net(transformed_image)

[ ]:

cids.shape

[ ]:

scores.shape

[ ]:

bboxs.shape

[ ]:

bboxs[:, 0, :]

[ ]:

n_pics = len(test_images)
n_cols = 3
n_rows = max(math.ceil(n_pics / n_cols), 2)  # at least 2 rows so that axes stays 2-D
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15))
for ax in axes.flat:
    ax.axis("off")
for i, pic in enumerate(test_images):
    curr_col = i % n_cols
    curr_row = i // n_cols
    print(pic)
    # load the image and apply the YoloV3 test-time preprocessing
    im_array = image.imread(pic)
    x, orig_img = data.transforms.presets.yolo.transform_test(im_array)
    # forward pass and display
    box_ids, scores, bboxes = net(x)
    ax = utils.viz.plot_bbox(
        orig_img,
        bboxes[0],
        scores[0],
        box_ids[0],
        class_names=classes,
        thresh=0.9,
        ax=axes[curr_row, curr_col],
    )
    ax.axis("off")
fig.tight_layout()
fig.show()


Deploy the detection server

1. We first need to send the model to S3, as we provide the S3 model path to the Amazon SageMaker endpoint creation API.

2. We create a serving script containing the model deserialization code and the inference logic. This logic lives in the repo folder.

3. We deploy the endpoint with a SageMaker SDK call.

Save local model, compress and send to S3

Clarify needs a SageMaker model because it spins up its own inference endpoint to get explanations. We now export the local model, archive it, and then create a SageMaker model from this archive, which other resources (such as endpoints and Clarify jobs) can then reference.

[ ]:

# save the full local model (both weights and graph)
net.export(MODEL_NAME, epoch=0)

[ ]:

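# compress the exported graph and weights into a tar file
model_file = "model.tar.gz"
with tarfile.open(model_file, "w:gz") as tar:
    tar.add(f"{MODEL_NAME}-symbol.json")
    tar.add(f"{MODEL_NAME}-0000.params")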
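Because tarfile.open(..., "w:gz") happily writes an empty archive, it is worth checking what ended up inside before uploading. A small sketch with hypothetical helper names; it assumes net.export produced <MODEL_NAME>-symbol.json and <MODEL_NAME>-0000.params, the file names MXNet uses for epoch=0:

```python
import tarfile


def list_archive(path: str) -> list:
    """Return the member names stored in a .tar.gz archive."""
    with tarfile.open(path, "r:gz") as tar:
        return tar.getnames()


def check_model_archive(path: str, prefix: str, epoch: int = 0) -> None:
    """Raise ValueError if the exported symbol/params files are missing from the archive."""
    # net.export(prefix, epoch) writes "<prefix>-symbol.json" and "<prefix>-<epoch:04d>.params"
    expected = {f"{prefix}-symbol.json", f"{prefix}-{epoch:04d}.params"}
    missing = expected - set(list_archive(path))
    if missing:
        raise ValueError(f"{path} is missing {sorted(missing)}")
```

In the notebook you would call check_model_archive(model_file, MODEL_NAME) right after creating the archive.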

[ ]:

# upload to s3
model_data_s3_uri = sm_sess.upload_data(model_file, key_prefix=S3_KEY_PREFIX)
model_data_s3_uri


Instantiate model

The predictor entry_point processes batches of images, which achieves higher performance than processing one image at a time because it makes better use of the resources.

[ ]:

model = MXNetModel(
    model_data=model_data_s3_uri,
    role=sm_role,
    py_version="py37",
    entry_point="detection_server_batch.py",
    source_dir="repo",
    framework_version="1.8.0",
    sagemaker_session=sm_sess,
)

container_def = model.prepare_container_def(instance_type=ENDPOINT_INSTANCE_TYPE)
model_name = gen_unique_name(MODEL_NAME)
sm_sess.create_model(role=sm_role, name=model_name, container_defs=[container_def])


(Optional) Create an endpoint and get predictions: model IO in depth

In this optional section we deploy an endpoint to get predictions and dive deep into details that can help troubleshoot issues related to the expected model IO format, serialization, and tensor shapes.

Common pitfalls are usually solved by making sure we use the right serializer and deserializer, and that the model output conforms to Clarify's expectations regarding the shapes and semantics of the output tensors.

In general, Clarify expects the model to receive a batch of images and to output a batch of image detections, where each detection is a tensor with the following elements: class ID, prediction score, and normalized bounding box.
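A quick local check can catch many of these pitfalls before starting a Clarify job. The sketch below is a hypothetical helper, not part of Clarify or the SageMaker SDK; it verifies the layout described in this notebook, (batch, detections, 6) rows of class ID, score, and normalized box. The score range and the 5% tolerance on box coordinates are this sketch's own assumptions.

```python
import numpy as np


def check_clarify_detections(pred, tol=0.05) -> np.ndarray:
    """Sanity-check a batch of detections shaped (batch, num_detections, 6).

    Each row is [class_id, score, xmin/w, ymin/h, xmax/w, ymax/h].
    """
    arr = np.asarray(pred, dtype=float)
    if arr.ndim != 3 or arr.shape[-1] != 6:
        raise ValueError(f"expected shape (batch, detections, 6), got {arr.shape}")
    scores, boxes = arr[..., 1], arr[..., 2:]
    if np.any((scores < 0) | (scores > 1)):
        raise ValueError("scores are expected to lie in [0, 1]")
    # values far outside [0, 1] usually mean the boxes were not normalized
    if np.any((boxes < -tol) | (boxes > 1 + tol)):
        raise ValueError("box coordinates do not look normalized")
    if np.any(boxes[..., 0] > boxes[..., 2]) or np.any(boxes[..., 1] > boxes[..., 3]):
        raise ValueError("expected xmin <= xmax and ymin <= ymax")
    return arr
```

For example, check_clarify_detections(predictor.predict(im_np)) can be run once the endpoint is deployed.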

[ ]:

endpoint_name = gen_unique_name(MODEL_NAME)
endpoint_name


Delete any previous endpoint

[ ]:

try:
    sm_sess.delete_endpoint(endpoint_name)
except botocore.exceptions.ClientError as e:
    # raised when the endpoint does not exist yet
    print(e)


Delete any stale endpoint config

[ ]:

try:
    sm_sess.delete_endpoint_config(endpoint_name)
except botocore.exceptions.ClientError as e:
    print(e)


Deploy the model to a SageMaker endpoint

[ ]:

import sagemaker.serializers
import sagemaker.deserializers

print(model.name)
predictor = model.deploy(
    initial_instance_count=1,
    instance_type=ENDPOINT_INSTANCE_TYPE,
    endpoint_name=endpoint_name,
    serializer=sagemaker.serializers.NumpySerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

[ ]:

predictor.deserializer

[ ]:

predictor.serializer

[ ]:

predictor.accept


Let’s look in detail at how the detection server works, using the following test image as an example:

[ ]:

im = Image.open(test_images[0])
im


Because we overrode transform_fn to support batches and to normalize the detection boxes, we send a tensor of shape (1, H, W, 3) as input: a batch of one image, with its height, width, and 3 color channels.

[ ]:

im_np = np.array([np.asarray(im.convert("RGB"))])  # make sure we send 3 color channels

[ ]:

im_np.shape

[ ]:

(H, W) = im_np.shape[1:3]

[ ]:

(H, W)


Send the image to the predictor and get detections

[ ]:

tensor = np.array(predictor.predict(im_np))

[ ]:

tensor

[ ]:

tensor.shape


Our prediction has one batch with 3 detections of 6 elements each: class_id, score, and the normalized box coordinates of the upper left and lower right corners.

[ ]:

box_scale = np.array([W, H, W, H])


To display the detections, we undo the normalization and split the Clarify detection format into separate class IDs, scores, and non-normalized boxes, so that we can use the Gluon plot_bbox function.
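The per-detection loop a couple of cells below can also be written with vectorized NumPy operations. A sketch with a hypothetical split_detections helper; it truncates the de-normalized coordinates with np.trunc, which rounds toward zero just like the .astype("int") conversion in the loop:

```python
import numpy as np


def split_detections(tensor, width, height):
    """Split a (1, D, 6) Clarify-style prediction into class ids, scores and pixel boxes."""
    dets = np.asarray(tensor, dtype=float)[0]  # first (and only) image in the batch
    cids = dets[:, 0]
    scores = dets[:, 1]
    scale = np.array([width, height, width, height], dtype=float)
    bboxes = np.trunc(dets[:, 2:] * scale)     # de-normalize, then drop the fractional part
    return cids, scores, bboxes


c, s, b = split_detections([[[0, 0.9, 0.25, 0.5, 0.75, 1.0]]], width=200, height=100)
print(b)  # [[ 50.  50. 150. 100.]]
```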

[ ]:

box_scale

[ ]:

numdet = tensor.shape[1]
cids = np.zeros(numdet)
scores = np.zeros(numdet)
bboxes = np.zeros((numdet, 4))
for i, det in enumerate(tensor[0]):
    cids[i] = det[0]
    scores[i] = det[1]
    # de-normalize the box to pixel coordinates and truncate to integers
    bboxes[i] = det[2:]
    bboxes[i] *= box_scale
    bboxes[i] = bboxes[i].astype("int")

[ ]:

bboxes

[ ]:

scores

[ ]:

utils.viz.plot_bbox(np.asarray(im), bboxes, scores, cids, class_names=classes, thresh=0.8)


We can group the logic above into a function to make it more convenient to use.

[ ]:

def detect(pic, predictor):
    """Elementary function to send a picture to a predictor."""
    im = Image.open(pic)
    im = im.convert("RGB")
    im_np = np.array([np.asarray(im)])
    (h, w) = im_np.shape[1:3]
    prediction = np.array(predictor.predict(im_np))
    box_scale = np.array([w, h, w, h])
    numdet = prediction.shape[1]
    cids = np.zeros(numdet)
    scores = np.zeros(numdet)
    bboxes = np.zeros((numdet, 4))
    for i, det in enumerate(prediction[0]):
        cids[i] = det[0]
        scores[i] = det[1]
        # de-normalize the box to pixel coordinates of the original image
        bboxes[i] = det[2:]
        bboxes[i] *= box_scale
        bboxes[i] = bboxes[i].astype("int")
    return (cids, scores, bboxes)

[ ]:

%%time
pic = test_images[0]
cids, scores, bboxes = detect(pic, predictor)

[ ]:

cids

[ ]:

bboxes

[ ]:

# detect() de-normalizes the boxes to the original image size, so plot them on the original image
utils.viz.plot_bbox(
    np.asarray(Image.open(pic).convert("RGB")), bboxes, scores, cids, class_names=classes, thresh=0.9
)

[ ]:

cids


There’s a single detection of a dog, which is class index 0, following the class order we set when calling reset_class at the beginning of the notebook.

Amazon SageMaker Clarify

We now show how to use SageMaker Clarify to explain the model's detections. We have already done some of the work in detection_server_batch.py: it filters out missing detections (class index -1) and normalizes the boxes to the image dimensions. We only need to upload the data to S3, provide the Clarify configuration (analysis_config.json) describing the explainability job parameters, and run the processing job with the data and configuration as inputs. As a result, we get the explanations for the model's detections in S3.

Clarify expects detections in the format explored in the cells above: a tensor of shape (batch_size, num_detections, 6). The first number of each detection is the predicted class label, the second is the associated confidence score, and the last four are the bounding box coordinates [xmin / w, ymin / h, xmax / w, ymax / h]. These corner coordinates are normalized by the image dimensions, where w is the width of the image and h is its height.
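Producing this format amounts to dropping padded slots and dividing the pixel boxes by the image size. The sketch below is a hypothetical helper, not the code in detection_server_batch.py (which is not reproduced here); it assumes width and height are the dimensions of the image the boxes were predicted on:

```python
import numpy as np


def to_clarify_format(cids, scores, bboxes, width, height):
    """Pack one image's detections as rows of [class_id, score, xmin/w, ymin/h, xmax/w, ymax/h].

    cids, scores: shape (N,) or (N, 1); bboxes: shape (N, 4) in pixel coordinates of
    a width x height image. Padded slots (class id -1) are dropped.
    """
    cids = np.asarray(cids, dtype=float).reshape(-1)
    scores = np.asarray(scores, dtype=float).reshape(-1)
    bboxes = np.asarray(bboxes, dtype=float).reshape(-1, 4)
    keep = cids >= 0
    scale = np.array([width, height, width, height], dtype=float)
    norm_boxes = bboxes[keep] / scale
    dets = np.column_stack([cids[keep], scores[keep], norm_boxes])
    return dets[np.newaxis, ...]  # add the batch dimension -> (1, N_kept, 6)


out = to_clarify_format([[1], [-1]], [[0.8], [-1]], [[50, 25, 150, 75], [-1, -1, -1, -1]], 200, 100)
print(out)  # one row: class 1, score 0.8, box [0.25, 0.25, 0.75, 0.75]
```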

Upload some test images to get explanations

[ ]:

s3_test_images = f"{S3_KEY_PREFIX}/test_images"

[ ]:

!mkdir -p test_images
!cp {TEST_IMAGE_DIR}/009.bear_009_0002.jpg test_images
!cp {TEST_IMAGE_DIR}/064.elephant-101_064_0003.jpg test_images

[ ]:

dataset_uri = sm_sess.upload_data("test_images", key_prefix=s3_test_images)
dataset_uri


We use this noise image as a baseline to mask different segments of the image during the explainability process
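We do not know exactly how noise_rgb.png was generated; as an assumption, the sketch below creates a uniform RGB noise image with NumPy and Pillow, and illustrates the general masking idea behind Kernel SHAP for images: segments switched off in a coalition are replaced by the baseline pixels. Both helpers are hypothetical, and the segment map is given by hand rather than produced by Clarify's own segmentation.

```python
import numpy as np
from PIL import Image


def make_noise_baseline(width, height, seed=0):
    """Create a uniform RGB noise image (a hypothetical stand-in for noise_rgb.png)."""
    rng = np.random.default_rng(seed)
    pixels = rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)
    return Image.fromarray(pixels)  # uint8 H x W x 3 arrays map to an RGB image


def mask_segments(image, baseline, segments, coalition):
    """Replace every segment switched off in `coalition` with the baseline pixels.

    image, baseline: (H, W, 3) arrays; segments: (H, W) integer segment ids;
    coalition: 1-D sequence of 0/1 flags, one per segment id.
    """
    keep = np.asarray(coalition, dtype=bool)[segments]  # (H, W): True where the segment is kept
    return np.where(keep[..., np.newaxis], image, baseline)
```

Each synthetic sample used to estimate SHAP values corresponds to one such coalition, which is why more samples (num_samples) means more model invocations.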

[ ]:

baseline_uri = sm_sess.upload_data("noise_rgb.png", key_prefix=S3_KEY_PREFIX)


It’s very important that content_type and accept_type in the model configuration below match the serializer and deserializer of the predictor above (sagemaker.serializers.NumpySerializer sends application/x-npy, and sagemaker.deserializers.JSONDeserializer expects application/json), so that the Clarify job uses the right (de)serialization.
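For reference, the application/x-npy payload is the NumPy .npy file format. The sketch below assumes NumpySerializer behaves roughly like np.save into an in-memory buffer; the helper names are hypothetical:

```python
import io

import numpy as np


def to_npy_bytes(array: np.ndarray) -> bytes:
    """Serialize an array in NumPy .npy format (an application/x-npy payload)."""
    buffer = io.BytesIO()
    np.save(buffer, array)
    return buffer.getvalue()


def from_npy_bytes(payload: bytes) -> np.ndarray:
    """Deserialize .npy bytes back into an array."""
    return np.load(io.BytesIO(payload))


# A batch of one small RGB image survives the round trip with shape and dtype intact
batch = np.zeros((1, 4, 5, 3), dtype=np.uint8)
restored = from_npy_bytes(to_npy_bytes(batch))
print(restored.shape, restored.dtype)  # (1, 4, 5, 3) uint8
```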

Clarify job configuration for object detection type of models

We configure the important parameters of the Clarify job for object detection in the SHAP configuration and in its image_config:

• num_samples: size of the synthetic dataset generated to compute the SHAP values. More samples produce more accurate explanations but consume more computational resources.

• baseline: image used to mask segments during Kernel SHAP.

• num_segments: number of segments to partition the detection image into.

• max_objects: maximum number of detected objects to explain, taken in order of decreasing predicted score.

• iou_threshold: minimum IoU for matching predictions on masked images against the original detections, since detection boxes shift during masking.

• context: whether to mask the image background when running SHAP; takes the values 0 or 1.
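To make max_objects and iou_threshold concrete, here is a plain-Python sketch of IoU between two corner-format boxes, of top-score selection, and of one possible matching rule. The matching rule (same class, highest IoU at or above the threshold) is an assumption for illustration, not Clarify's documented algorithm, and all function names are hypothetical.

```python
def iou(box_a, box_b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes."""
    ix_min, iy_min = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix_max, iy_max = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix_max - ix_min) * max(0.0, iy_max - iy_min)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def top_objects(detections, max_objects):
    """Keep the max_objects highest-scoring [class_id, score, box...] rows."""
    return sorted(detections, key=lambda d: d[1], reverse=True)[:max_objects]


def match_original(original, masked_dets, iou_threshold):
    """Return the same-class masked detection with the highest IoU against
    `original`, or None if no IoU reaches iou_threshold (hypothetical rule)."""
    best, best_iou = None, 0.0
    for det in masked_dets:
        if det[0] != original[0]:
            continue  # only compare detections of the same class
        overlap = iou(original[2:], det[2:])
        if overlap >= iou_threshold and overlap > best_iou:
            best, best_iou = det, overlap
    return best


print(iou([0, 0, 2, 2], [1, 1, 3, 3]))  # 1/7: overlap area 1, union area 7
```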

Below we use the SageMaker Python SDK, which builds the analysis configuration from higher-level Python classes.

[ ]:

from sagemaker.clarify import (
    SageMakerClarifyProcessor,
    ModelConfig,
    DataConfig,
    SHAPConfig,
    ImageConfig,
    ModelPredictedLabelConfig,
)
from sagemaker.utils import unique_name_from_base


Configure the parameters of the Clarify processing job. The job reads the test images and the analysis configuration as inputs and produces one output: the resulting analysis of the model.

[ ]:

output_bucket = sm_sess.default_bucket()
# Here we specify where to store the results.
analysis_result_path = "s3://{}/{}/{}".format(output_bucket, S3_KEY_PREFIX, "cv_analysis_result")

clarify_processor: SageMakerClarifyProcessor = SageMakerClarifyProcessor(
    role=sm_role,
    instance_count=ANALYZER_INSTANCE_COUNT,
    instance_type=ANALYZER_INSTANCE_TYPE,
    max_runtime_in_seconds=3600,
    sagemaker_session=sm_sess,
)

# content_type / accept_type must match the predictor's NumpySerializer / JSONDeserializer
model_config: ModelConfig = ModelConfig(
    model_name=model_name,
    instance_count=1,
    instance_type=ENDPOINT_INSTANCE_TYPE,
    content_type="application/x-npy",
    accept_type="application/json",
)

data_config: DataConfig = DataConfig(
    s3_data_input_path=dataset_uri,
    s3_output_path=analysis_result_path,
    dataset_type="application/x-image",
)

image_config: ImageConfig = ImageConfig(
    model_type="OBJECT_DETECTION",
    feature_extraction_method="segmentation",
    num_segments=20,
    segment_compactness=5,
    max_objects=5,
    iou_threshold=0.5,
    context=1.0,
)

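shap_config: SHAPConfig = SHAPConfig(
    baseline=baseline_uri,
    num_samples=500,
    image_config=image_config,
)

# passed as model_scores to run_explainability below; 0.8 is an illustrative score threshold
predictions_config = ModelPredictedLabelConfig(probability_threshold=0.8)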



Now run the processing job; it takes approximately 6 minutes.

[ ]:

clarify_processor.run_explainability(
    data_config=data_config,
    model_config=model_config,
    model_scores=predictions_config,
    explainability_config=shap_config,
    job_name=unique_name_from_base("clarify-cv-object-detection"),
    wait=True,
)


[ ]:

!aws s3 cp --recursive {analysis_result_path} cv_analysis_result

[ ]:

im = Image.open("cv_analysis_result/shap_064.elephant-101_064_0003_box1.jpeg")
im

[ ]:

im = Image.open("cv_analysis_result/shap_064.elephant-101_064_0003_box2.jpeg")
im

[ ]:

im = Image.open("cv_analysis_result/064.elephant-101_064_0003_objects.jpeg")
im


Cleanup of resources

Delete the endpoint created above to stop incurring charges.

[ ]:

sm_sess.delete_endpoint(endpoint_name)
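The endpoint is not the only resource this notebook creates: sm_sess.create_model registered model_name, and model.deploy registered its own model plus an endpoint configuration. A hedged sketch of a fuller cleanup with a hypothetical cleanup helper, assuming the session object exposes delete_endpoint, delete_endpoint_config, and delete_model (as sagemaker.Session does) and that the endpoint configuration shares the endpoint's name, as the earlier delete_endpoint_config call assumes:

```python
def cleanup(session, endpoint_name, model_names):
    """Delete the endpoint, its endpoint config and the given models; return any failures.

    `session` is expected to provide delete_endpoint, delete_endpoint_config and
    delete_model methods, as sagemaker.Session does.
    """
    steps = [
        (session.delete_endpoint, endpoint_name),
        (session.delete_endpoint_config, endpoint_name),
    ]
    steps += [(session.delete_model, name) for name in model_names]
    errors = []
    for delete, name in steps:
        try:
            delete(name)
        except Exception as exc:  # e.g. a botocore ClientError if the resource is already gone
            errors.append(f"{delete.__name__}({name}): {exc}")
    return errors
```

For example, cleanup(sm_sess, endpoint_name, [model_name, model.name]). Failures are collected rather than raised so that one missing resource does not stop the remaining deletions.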