Using SageMaker debugger to visualize class activation maps in CNNs
This notebook demonstrates how to use SageMaker debugger to plot class activation maps for image classification models. A class activation map (saliency map) is a heatmap that highlights the regions in the image that lead the model to make a certain prediction. This is especially useful:
if the model makes a misclassification and it is not clear why;
or to determine whether the model takes all important features of an object into account.
In this notebook we will train a ResNet model on the German Traffic Sign Dataset and we will use SageMaker debugger to plot class activation maps in real-time.
The following animation shows the saliency map for a particular traffic sign as training progresses. Red highlights regions with high activation that lead to the prediction; blue indicates regions with low activation that are less relevant for the prediction.
In the beginning the model makes a lot of misclassifications because it focuses on the wrong image regions, e.g. the obstacle in the lower left corner. As training progresses, the focus shifts to the center of the image and the model becomes more and more confident in predicting class 3 (which is the correct class).
There exist several methods to generate saliency maps, e.g. CAM and GradCAM. The paper Full-Gradient Representation for Neural Network Visualization [1] proposes a new method that produces state-of-the-art results. It requires intermediate features and their biases. With SageMaker debugger we can easily retrieve these tensors.
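For orientation, the FullGrad saliency map from [1] can be summarized (up to notational details; this is our paraphrase, not a quote from the paper) as
$$
S(x) \;=\; \psi\big(\nabla_x f(x) \odot x\big) \;+\; \sum_{l}\sum_{c}\psi\big(\nabla_{b_l} f(x) \odot b_l\big)_{c}
$$
where $f$ is the network, $b_l$ collects the explicit and implicit biases of layer $l$, and $\psi$ denotes post-processing: taking the absolute value, normalizing, and bilinearly upsampling to the input size. The analysis cells at the end of this notebook assemble both terms from the tensors saved by SageMaker debugger.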
[1] Suraj Srinivas and François Fleuret. Full-Gradient Representation for Neural Network Visualization. 2019. arXiv:1905.00780.
Customize the smdebug hook
To create saliency maps, the gradients of the prediction with respect to the intermediate features need to be computed. To obtain this information, we have to customize the smdebug hook. The custom hook is defined in entry_point/custom_hook.py. During the forward pass, we register a backward hook on the outputs. We also need the gradients of the input image, so we provide an additional function that registers a backward hook on the input tensor.
The paper Full-Gradient Representation for Neural Network Visualization [1] distinguishes between implicit and explicit biases. Implicit biases include the running mean and variance of BatchNorm layers. With SageMaker debugger we only get the explicit biases, which in the case of BatchNorm layers correspond to the beta parameter. We therefore extend the hook to also record the running means and variances of BatchNorm layers.
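As a reminder of where the implicit bias comes from (our notation, not taken from [1]): at inference time a BatchNorm layer computes
$$
y \;=\; \gamma\,\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}} + \beta
\;=\; \frac{\gamma}{\sqrt{\sigma^2+\epsilon}}\,x \;+\; \underbrace{\beta}_{\text{explicit bias}} \;\underbrace{-\,\frac{\gamma\mu}{\sqrt{\sigma^2+\epsilon}}}_{\text{implicit bias}}
$$
where $\mu$ and $\sigma^2$ are the running mean and variance. The hook below records $\mu$ and $\sigma^2$, and the compute_implicit_biases helper later in the notebook evaluates the implicit-bias term (ignoring $\epsilon$).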
import torch
import smdebug.pytorch as smd
from smdebug.core.modes import ModeKeys


class CustomHook(smd.Hook):

    # register input image for backward pass, to get image gradients
    def image_gradients(self, image):
        image.register_hook(self.backward_hook("image"))

    def forward_hook(self, module, inputs, outputs):
        module_name = module._module_name
        self._write_inputs(module_name, inputs)

        # register outputs for backward pass. This is expensive, so we only do it during EVAL mode
        if self.mode == ModeKeys.EVAL:
            outputs.register_hook(self.backward_hook(module_name + "_output"))

            # record running mean and var of BatchNorm layers
            if isinstance(module, torch.nn.BatchNorm2d):
                self._write_outputs(module_name + ".running_mean", module.running_mean)
                self._write_outputs(module_name + ".running_var", module.running_var)

        self._write_outputs(module_name, outputs)
        self.last_saved_step = self.step
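The following is a minimal, self-contained sketch of how such a hook could be attached to a model. It is not the actual entry_point/train.py: the out_dir, the toy inputs, and the manual hook construction (instead of create_from_json_file, which reads the debugger configuration inside a SageMaker job) are assumptions made for illustration only.
[ ]:
import torch
import torchvision.models as models
from smdebug import modes
from smdebug.pytorch import SaveConfig

# toy model and data; inside the real training job these come from train.py
model = models.resnet18(num_classes=43)
loss_fn = torch.nn.CrossEntropyLoss()
images = torch.rand(2, 3, 224, 224, requires_grad=True)  # requires_grad so a backward hook can fire
labels = torch.randint(0, 43, (2,))

# construct the hook manually; in a SageMaker training job one would typically
# call CustomHook.create_from_json_file() instead
hook = CustomHook(out_dir="/tmp/smdebug_demo", save_config=SaveConfig(save_interval=1))
hook.register_module(model)   # registers forward hooks on all submodules
hook.register_loss(loss_fn)   # records the inputs/outputs of the loss

hook.set_mode(modes.EVAL)
hook.image_gradients(images)  # backward hook on the input image
loss = loss_fn(model(images), labels)
loss.backward()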
Replace in-place operations
Additionally, we need to convert in-place operations, as they can potentially overwrite values that are required to compute gradients. In the pre-trained PyTorch ResNet model, ReLU activations are executed in-place by default. The following code sets inplace=False:
[ ]:
def relu_inplace(model):
    for child_name, child in model.named_children():
        if isinstance(child, torch.nn.ReLU):
            setattr(model, child_name, torch.nn.ReLU(inplace=False))
        else:
            relu_inplace(child)
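As a quick sanity check (a sketch that is independent of the actual training script), we can apply the function to a freshly constructed torchvision ResNet and verify that no in-place ReLUs remain:
[ ]:
import torch
import torchvision.models as models

# build a ResNet (random weights are fine for this check) and convert its ReLUs
model = models.resnet18()
relu_inplace(model)

# verify that every ReLU module is now out-of-place
assert all(not m.inplace for m in model.modules() if isinstance(m, torch.nn.ReLU))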
Download the dataset and upload it to Amazon S3
Now we download the German Traffic Sign Dataset and upload it to Amazon S3. The training dataset consists of 43 image classes.
[ ]:
import zipfile
! wget https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip
with zipfile.ZipFile("GTSRB-Training_fixed.zip", "r") as zip_ref:
zip_ref.extractall("./")
The test dataset:
[ ]:
import zipfile
! wget https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip
with zipfile.ZipFile("GTSRB_Final_Test_Images.zip", "r") as zip_ref:
zip_ref.extractall("./")
Now we upload the datasets to the SageMaker default bucket in Amazon S3.
[ ]:
import boto3
import sagemaker
import os


def upload_to_s3(path, directory_name, bucket, counter=-1):
    print("Upload files from " + path + " to " + bucket)
    client = boto3.client("s3")

    for path, subdirs, files in os.walk(path):
        path = path.replace("\\", "/")
        print(path)
        for file in files[0:counter]:
            client.upload_file(
                os.path.join(path, file),
                bucket,
                directory_name + "/" + path.split("/")[-1] + "/" + file,
            )


boto_session = boto3.Session()
sagemaker_session = sagemaker.Session(boto_session=boto_session)
bucket = sagemaker_session.default_bucket()

upload_to_s3("GTSRB/Training", directory_name="train", bucket=bucket)

# saliency maps are computed for every image in the test dataset, so we upload only 4 images
upload_to_s3("GTSRB/Final_Test", directory_name="test", bucket=bucket, counter=4)
Before starting the SageMaker training job, we need to install some libraries. We will use the smdebug library to read, filter and analyze the raw tensors that are stored in Amazon S3, and the opencv-python library to plot the saliency maps as heatmaps.
[ ]:
!apt-get update && apt-get install -y python3-opencv
[ ]:
import subprocess
import sys


def import_or_install(package):
    try:
        __import__(package)
    except ImportError:
        # pip.main() was removed in newer pip versions, so invoke pip via a subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
[ ]:
import_or_install("smdebug")
[ ]:
import_or_install("opencv-python")
SageMaker training
In the following we define the SageMaker estimator. The entry point script train.py defines the model training: it downloads a pre-trained ResNet model and performs transfer learning on the German Traffic Sign Dataset. Before creating the estimator, we configure the debugger hook and a built-in rule.
Debugger hook configuration
Next we define a custom collection where we specify a regular expression for the tensor names to be included. Tensors from the training phase are saved every 100 steps, while tensors from the validation phase are saved every step. A step represents one forward and backward pass.
[ ]:
from sagemaker.debugger import DebuggerHookConfig, CollectionConfig

debugger_hook_config = DebuggerHookConfig(
    collection_configs=[
        CollectionConfig(
            name="custom_collection",
            parameters={
                "include_regex": ".*bn|.*bias|.*downsample|.*ResNet_input|.*image|.*fc_output|.*CrossEntropyLoss",
                "train.save_interval": "100",
                "eval.save_interval": "1",
            },
        )
    ]
)
Builtin rule
In addition we run the training job with a built-in rule. We select here the class imbalance rule, which measures whether our training set is imbalanced and/or whether the model has lower accuracy for certain classes in the training dataset. The tensors that are passed into the loss function CrossEntropyLoss are the labels and predictions. In our example those tensors have the names CrossEntropyLoss_input_1 and CrossEntropyLoss_input_0. The rule uses those tensors to compute class imbalance.
[ ]:
from sagemaker.debugger import Rule, CollectionConfig, rule_configs

class_imbalance_rule = Rule.sagemaker(
    base_config=rule_configs.class_imbalance(),
    rule_parameters={
        "labels_regex": "CrossEntropyLoss_input_1",
        "predictions_regex": "CrossEntropyLoss_input_0",
        "argmax": "True",
    },
)
Define the estimator
With the debugger hook configuration and the built-in rule in place, we can now define the estimator:
[ ]:
from sagemaker.pytorch import PyTorch

role = sagemaker.get_execution_role()

pytorch_estimator = PyTorch(
    entry_point="train.py",
    source_dir="entry_point",
    role=role,
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    framework_version="1.12.0",
    py_version="py38",
    hyperparameters={
        "epochs": 5,
        "batch_size_train": 64,
        "batch_size_val": 4,
        "learning_rate": 0.001,
    },
    volume_size=100,
    debugger_hook_config=debugger_hook_config,
    rules=[class_imbalance_rule],
)
Now that we have defined the estimator, we can call fit, which will start the training job on an ml.p3.2xlarge instance:
[ ]:
pytorch_estimator.fit(
    inputs={"train": "s3://{}/train".format(bucket), "test": "s3://{}/test".format(bucket)},
    wait=False,
)
Check rule status
[ ]:
pytorch_estimator.latest_training_job.rule_job_summary()
Visualize saliency maps in real-time
Once the training job has started, SageMaker debugger will upload the tensors of our model into S3. We can check the location in S3:
[ ]:
path = pytorch_estimator.latest_job_debugger_artifacts_path()
print("Tensors are stored in: {}".format(path))
We can check the status of our training job by executing describe_training_job:
[ ]:
job_name = pytorch_estimator.latest_training_job.name
print("Training job name: {}".format(job_name))
client = pytorch_estimator.sagemaker_session.sagemaker_client
description = client.describe_training_job(TrainingJobName=job_name)
We can access the tensors from S3 once the training job is in status Training or Completed. In the following code cell we check the job status:
[ ]:
import time

if description["TrainingJobStatus"] != "Completed":
    while description["SecondaryStatus"] not in {"Training", "Completed"}:
        description = client.describe_training_job(TrainingJobName=job_name)
        primary_status = description["TrainingJobStatus"]
        secondary_status = description["SecondaryStatus"]
        print(
            "Current job status: [PrimaryStatus: {}, SecondaryStatus: {}]".format(
                primary_status, secondary_status
            )
        )
        time.sleep(30)
Once the job is in status Training or Completed, we can create the trial:
[ ]:
from smdebug.trials import create_trial
trial = create_trial(path)
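Optionally, we can first list a few of the recorded tensor names. This is just a sanity check (not part of the original workflow) that makes it easier to adapt the regular expressions below to a different architecture:
[ ]:
# print some of the tensor names that the debugger has saved so far
for tname in sorted(trial.tensor_names())[:25]:
    print(tname)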
Now we can compute the saliency maps. The method described in Full-Gradient Representation for Neural Network Visualization [1] requires all intermediate features and their biases. The following cell retrieves the gradients for the outputs of batchnorm and downsampling layers and the corresponding biases. If you use a model other than ResNet you may need to adjust the regular expressions in the following cell:
[ ]:
biases, gradients = [], []

for tname in trial.tensor_names(regex=".*gradient.*bn.*output|.*gradient.*downsample.1.*output"):
    gradients.append(tname)

for tname in trial.tensor_names(regex="^(?=.*bias)(?:(?!fc).)*$"):
    biases.append(tname)
As mentioned in the beginning of the notebook, in the case of BatchNorm layers, we need to compute the implicit biases. In the following code cell we retrieve the necessary tensors:
[ ]:
bn_weights, running_vars, running_means = [], [], []

for tname in trial.tensor_names(regex=".*running_mean"):
    running_means.append(tname)

for tname in trial.tensor_names(regex=".*running_var"):
    running_vars.append(tname)

for tname in trial.tensor_names(regex=".*bn.*weight|.*downsample.1.*weight"):
    bn_weights.append(tname)
We need to ensure that the tensors in the lists are in the same order, i.e. the bias vector and the gradients must belong to the same layer. Let's have a look at the tensors:
[ ]:
for bias, gradient, weight, running_var, running_mean in zip(
    biases, gradients, bn_weights, running_vars, running_means
):
    print(bias, gradient, weight, running_var, running_mean)
Here we define a helper function that is used later on to normalize tensors:
[ ]:
import numpy as np


def normalize(tensor):
    tensor = tensor - np.min(tensor)
    tensor = tensor / np.max(tensor)
    return tensor
A helper function to plot saliency maps:
[ ]:
import matplotlib.pyplot as plt
import numpy as np
import cv2


def plot(saliency_map, image, predicted_class, probability):
    # clear matplotlib figure
    plt.clf()

    # revert normalization
    mean = [[[0.485]], [[0.456]], [[0.406]]]
    std = [[[0.229]], [[0.224]], [[0.225]]]
    image = image * std + mean

    # transpose image: color channel in last dimension
    image = image.transpose(1, 2, 0)
    image = (image * 255).astype(np.uint8)

    # create heatmap: we multiply it with -1 because we use
    # matplotlib to plot output results which inverts the colormap
    saliency_map = -saliency_map * 255
    saliency_map = saliency_map.astype(np.uint8)
    heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)

    # overlay original image with heatmap
    output_image = heatmap.astype(np.float32) + image.astype(np.float32)

    # normalize
    output_image = output_image / np.max(output_image)

    # plot
    fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 5))
    ax0.imshow(image)
    ax1.imshow(output_image)
    ax0.set_axis_off()
    ax1.set_axis_off()
    ax0.set_title("Input image")
    ax1.set_title("Predicted class " + predicted_class + " with probability " + probability + "%")
    plt.show()
A helper function to compute implicit biases:
[ ]:
def compute_implicit_biases(bn_weights, running_vars, running_means, step):
    implicit_biases = []
    for weight_name, running_var_name, running_mean_name in zip(
        bn_weights, running_vars, running_means
    ):
        weight = trial.tensor(weight_name).value(step_num=step, mode=modes.EVAL)
        running_var = trial.tensor(running_var_name).value(step_num=step, mode=modes.EVAL)
        running_mean = trial.tensor(running_mean_name).value(step_num=step, mode=modes.EVAL)
        # implicit bias of a BatchNorm layer: -gamma * running_mean / sqrt(running_var)
        implicit_biases.append(-running_mean / np.sqrt(running_var) * weight)
    return implicit_biases
Get available steps:
[ ]:
import time

steps = []
while len(steps) == 0:
    steps = trial.steps()
    print("Waiting for tensors to become available...")
    time.sleep(3)
print("\nDone")

print("Getting tensors...")
rendered_steps = []
We iterate over the tensors from the validation steps and compute the saliency map for each item in the batch. To compute the saliency map, we perform the following steps:
compute the implicit bias
multiply gradients and bias (sum of explicit and implicit bias)
normalize result
interpolate tensor to the input size of the original input image
create heatmap and overlay it with the original input image
[ ]:
import numpy as np
import cv2
import scipy.ndimage
import scipy.special
from smdebug import modes
from smdebug.core.modes import ModeKeys
from smdebug.exceptions import TensorUnavailableForStep
import os

image_size = 224
loaded_all_steps = False

while not loaded_all_steps and description["SecondaryStatus"] != "Completed":

    # get available steps
    loaded_all_steps = trial.loaded_all_steps
    steps = trial.steps(mode=modes.EVAL)

    # quick way to get diff between two lists
    steps_to_render = list(set(steps).symmetric_difference(set(rendered_steps)))

    # iterate over available steps
    for step in sorted(steps_to_render):
        try:
            # get original input image
            image_batch = trial.tensor("ResNet_input_0").value(step_num=step, mode=modes.EVAL)

            # compute implicit biases from batchnorm layers
            implicit_biases = compute_implicit_biases(bn_weights, running_vars, running_means, step)

            for item in range(image_batch.shape[0]):

                # input image
                image = image_batch[item, :, :, :]

                # get gradients of input image
                image_gradient = trial.tensor("gradient/image").value(
                    step_num=step, mode=modes.EVAL
                )[item, :]
                image_gradient = np.sum(normalize(np.abs(image_gradient * image)), axis=0)
                saliency_map = image_gradient

                for gradient_name, bias_name, implicit_bias in zip(
                    gradients, biases, implicit_biases
                ):
                    # get gradients and bias vectors for corresponding step
                    gradient = trial.tensor(gradient_name).value(step_num=step, mode=modes.EVAL)[
                        item : item + 1, :, :, :
                    ]
                    bias = trial.tensor(bias_name).value(step_num=step, mode=modes.EVAL)
                    bias = bias + implicit_bias

                    # compute full gradient
                    bias = bias.reshape((1, bias.shape[0], 1, 1))
                    bias = np.broadcast_to(bias, gradient.shape)
                    bias_gradient = normalize(np.abs(bias * gradient))

                    # interpolate to original image size
                    for channel in range(bias_gradient.shape[1]):
                        interpolated = scipy.ndimage.zoom(
                            bias_gradient[0, channel, :, :],
                            image_size / bias_gradient.shape[2],
                            order=1,
                        )
                        saliency_map += interpolated

                # normalize
                saliency_map = normalize(saliency_map)

                # predicted class and probability
                predicted_class = trial.tensor("fc_output_0").value(step_num=step, mode=modes.EVAL)[
                    item, :
                ]
                print("Predicted class:", np.argmax(predicted_class))
                scores = np.exp(np.asarray(predicted_class))
                scores = scores / scores.sum(0)

                # plot image and heatmap
                plot(
                    saliency_map,
                    image,
                    str(np.argmax(predicted_class)),
                    str(int(np.max(scores) * 100)),
                )

        except TensorUnavailableForStep:
            print("Tensor unavailable for step {}".format(step))

    rendered_steps.extend(steps_to_render)

    time.sleep(5)

    description = client.describe_training_job(TrainingJobName=job_name)

print("\nDone")