Use SageMaker Batch Transform for PyTorch Batch Inference



In this notebook, we show how to run a Batch Transform job with PyTorch in Amazon SageMaker.

First, an image classification model is built on the MNIST dataset. Then, we demonstrate batch transform by using the SageMaker Python SDK PyTorch framework with different configurations:

  - data_type=S3Prefix: uses all objects that match the specified S3 prefix for batch inference.

  - data_type=ManifestFile: a manifest file contains a list of object keys to use in batch inference.

  - instance_count>1: distributes the batch inference dataset to multiple inference instances.

For batch transform in TensorFlow in Amazon SageMaker, you can follow other Jupyter notebooks in the sagemaker_batch_transform directory.

Runtime

This notebook takes approximately 15 minutes to run.

Contents

  1. Setup

  2. Model training

  3. Prepare batch inference data

  4. Create model transformer

  5. Batch inference

  6. Look at all transform jobs

  7. Conclusion

Setup

We’ll begin with some necessary installs and imports, and get an Amazon SageMaker session to help perform certain tasks, as well as an IAM role with the necessary permissions.

[ ]:
!pip install nvidia-ml-py3
!pip uninstall -y torchvision
!pip install torchvision
[ ]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
from os import listdir
from os.path import isfile, join
from shutil import copyfile
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name
role = get_execution_role()

bucket = sagemaker_session.default_bucket()
prefix = "sagemaker/DEMO-pytorch-batch-inference-script"
print("Bucket: {}".format(bucket))

Model training

Since the main purpose of this notebook is to demonstrate SageMaker PyTorch batch transform, we reuse a SageMaker Python SDK PyTorch MNIST example to train a PyTorch model. It takes around 7 minutes to finish the training.

[ ]:
from torchvision.datasets import MNIST
from torchvision import transforms

local_dir = "data"
MNIST.mirrors = [
    f"https://sagemaker-example-files-prod-{region}.s3.amazonaws.com/datasets/image/MNIST/"
]
MNIST(
    local_dir,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    ),
)


inputs = sagemaker_session.upload_data(path=local_dir, bucket=bucket, key_prefix=prefix)
print("input spec (in this case, just an S3 path): {}".format(inputs))

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="model-script/mnist.py",
    role=role,
    framework_version="1.8.0",
    py_version="py3",
    instance_count=3,
    instance_type="ml.c5.2xlarge",
    hyperparameters={
        "epochs": 1,
        "backend": "gloo",
    },  # set epochs to a more realistic number for real training
)

estimator.fit({"training": inputs})

Prepare batch inference data

Convert the test data into PNG image format.

[ ]:
!ls data/MNIST/raw
[ ]:
# Decompress the gzipped IDX file into a NumPy array of 28x28 images
# (the PNG conversion happens in a later cell)

import gzip
import numpy as np
import os

with gzip.open(os.path.join(local_dir, "MNIST/raw", "t10k-images-idx3-ubyte.gz"), "rb") as f:
    images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
[ ]:
print(len(images), "test images")
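
The offset=16 above skips the IDX file header. As a minimal sketch of what that header holds, the standard-library example below parses a synthetic IDX3 buffer; the helper name read_idx_images and the two-image test buffer are illustrative assumptions, not part of the notebook.

```python
import gzip
import struct


def read_idx_images(raw: bytes):
    """Split an IDX3 image buffer into its header fields and per-image bytes."""
    # The header is four big-endian unsigned 32-bit integers:
    # magic number, image count, rows, columns (16 bytes in total).
    magic, count, rows, cols = struct.unpack(">IIII", raw[:16])
    if magic != 2051:  # 0x00000803: unsigned-byte data with three dimensions
        raise ValueError(f"unexpected magic number: {magic}")
    pixels = raw[16:]
    size = rows * cols
    if len(pixels) != count * size:
        raise ValueError("pixel data does not match the header")
    images = [pixels[i * size : (i + 1) * size] for i in range(count)]
    return (count, rows, cols), images


# Synthetic stand-in for t10k-images-idx3-ubyte.gz: two blank 28x28 images.
fake_file = gzip.compress(struct.pack(">IIII", 2051, 2, 28, 28) + bytes(2 * 28 * 28))
shape, images = read_idx_images(gzip.decompress(fake_file))
print(shape, len(images[0]))
```

Reading the counts from the header, rather than relying on reshape(-1, 28, 28), also catches a truncated download early.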

Randomly sample 100 test images and upload them to S3.

[ ]:
import random
from PIL import Image as im

ids = random.sample(range(len(images)), 100)
ids = np.array(ids, dtype=int)  # np.int was removed in NumPy 1.24; use the built-in int
selected_images = images[ids]

image_dir = "data/images"

if not os.path.exists(image_dir):
    os.makedirs(image_dir)

for i, img in enumerate(selected_images):
    pngimg = im.fromarray(img)
    pngimg.save(os.path.join(image_dir, f"{i}.png"))
[ ]:
inference_prefix = "batch_transform"
inference_inputs = sagemaker_session.upload_data(
    path=image_dir, bucket=bucket, key_prefix=inference_prefix
)
print("Input S3 path: {}".format(inference_inputs))

Create model transformer

Now, we create a transformer object for creating and interacting with Amazon SageMaker transform jobs. We can create the transformer in two ways:

  1. Use a fitted estimator directly.

  2. First create a PyTorchModel from a saved model artifact, and then create a transformer from the PyTorchModel object.

Here, we implement the model_fn, input_fn, predict_fn, and output_fn functions in model-script/mnist.py to override the default PyTorch inference handlers.

In input_fn(), each image to run inference on arrives as a Python bytearray. The load_from_bytearray() helper therefore wraps the bytes in io.BytesIO and opens them with PIL.Image before converting the image to a tensor.

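# Excerpt from model-script/mnist.py. Net is assumed to be the model class
# defined in that script.
import io
import json
import os

import torch
from PIL import Image
from torchvision.transforms import ToTensor


def model_fn(model_dir):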
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.DataParallel(Net())
    with open(os.path.join(model_dir, "model.pth"), "rb") as f:
        model.load_state_dict(torch.load(f))
    return model.to(device)


def load_from_bytearray(request_body):
    image_as_bytes = io.BytesIO(request_body)
    image = Image.open(image_as_bytes)
    image_tensor = ToTensor()(image).unsqueeze(0)
    return image_tensor


def input_fn(request_body, request_content_type):
    # if set content_type as "image/jpg" or "application/x-npy",
    # the input is also a python bytearray
    if request_content_type == "application/x-image":
        image_tensor = load_from_bytearray(request_body)
    else:
        print("not support this type yet")
        raise ValueError("not support this type yet")
    return image_tensor


# Perform prediction on the deserialized object, with the loaded model
def predict_fn(input_object, model):
    output = model.forward(input_object)
    pred = output.max(1, keepdim=True)[1]

    return {"predictions": pred.item()}


# Serialize the prediction result into the desired response content type
def output_fn(predictions, response_content_type):
    return json.dumps(predictions)
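
To make the role of each handler concrete, here is a simplified sketch of how the four functions are understood to be chained: the model is loaded once, then every request passes through input_fn, predict_fn, and output_fn in turn. The stand-in model and the handle_request helper are assumptions for illustration; this is not the serving container's actual code.

```python
import json


def model_fn(model_dir):
    # Stand-in model: "predicts" the payload length modulo 10.
    return lambda payload: len(payload) % 10


def input_fn(request_body, request_content_type):
    if request_content_type != "application/x-image":
        raise ValueError(f"unsupported content type: {request_content_type}")
    return bytes(request_body)


def predict_fn(input_object, model):
    return {"predictions": model(input_object)}


def output_fn(predictions, response_content_type):
    return json.dumps(predictions)


def handle_request(model, body, content_type, accept):
    """Chain the per-request handlers: deserialize, predict, serialize."""
    data = input_fn(body, content_type)
    prediction = predict_fn(data, model)
    return output_fn(prediction, accept)


model = model_fn("/opt/ml/model")  # loaded once, reused for every request
response = handle_request(
    model, bytearray(b"\x00" * 13), "application/x-image", "application/json"
)
print(response)  # {"predictions": 3}
```

In a batch transform job, each input object becomes one such request, so the JSON returned by output_fn is what ends up in the corresponding output file.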
[ ]:
# Use fitted estimator directly
transformer = estimator.transformer(instance_count=1, instance_type="ml.c5.xlarge")
[ ]:
# You can also create a Transformer object from saved model artifact

# Get model artifact location by estimator.model_data, or give an S3 key directly
model_artifact_s3_location = estimator.model_data  # "s3://<BUCKET>/<PREFIX>/model.tar.gz"

# Create PyTorchModel from saved model artifact
pytorch_model = PyTorchModel(
    model_data=model_artifact_s3_location,
    role=role,
    framework_version="1.8.0",
    py_version="py3",
    source_dir="model-script/",
    entry_point="mnist.py",
)

# Create transformer from PyTorchModel object
transformer = pytorch_model.transformer(instance_count=1, instance_type="ml.c5.xlarge")

Batch inference

Next, we run batch inference on the 100 sampled MNIST images.

Input images directly from S3 location

We set S3DataType=S3Prefix to use all objects that match the specified S3 prefix for batch inference.

[ ]:
transformer.transform(
    data=inference_inputs,
    data_type="S3Prefix",
    content_type="application/x-image",
    wait=True,
)
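
As a rough illustration of what "match the specified S3 prefix" means, the sketch below assumes plain string-prefix matching on object keys, as with S3 list operations. The key names and the match_prefix helper are hypothetical.

```python
def match_prefix(keys, prefix):
    """Return the keys that start with the given prefix (string comparison only)."""
    return [key for key in keys if key.startswith(prefix)]


# Hypothetical object keys in one bucket.
keys = [
    "batch_transform/0.png",
    "batch_transform/1.png",
    "batch_transform_old/9.png",
    "other/2.png",
]

without_slash = match_prefix(keys, "batch_transform")
with_slash = match_prefix(keys, "batch_transform/")
print(without_slash)
print(with_slash)
```

Under that assumption, a prefix without a trailing slash can also pick up sibling "folders" whose names merely begin with the same characters, so it is worth checking what else lives under the prefix before starting a job.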

Input images by manifest file

First, we generate a manifest file. Then we use the manifest file, which contains a list of object keys, as the input to batch inference. Some key points:

  - content_type = "application/x-image" (here the content_type describes the actual objects used for inference, not the manifest file)

  - data_type = "ManifestFile"

  - The manifest file must follow the format that S3DataSource specifies, shown below.

We create the manifest file by using the jsonlines package.

[
    {"prefix": "s3://customer_bucket/some/prefix/"},
    "relative/path/to/custdata-1",
    "relative/path/custdata-2",
    ...
    "relative/path/custdata-N"
]
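
The format above can also be sketched with only the standard library: the first entry carries the shared prefix, and every later entry is a key relative to it. The bucket name and the helper functions below are hypothetical and serve only to show how the entries combine into full S3 URIs.

```python
import json


def build_manifest(prefix, relative_keys):
    """Return the manifest structure: a prefix entry followed by relative keys."""
    return [{"prefix": prefix}] + list(relative_keys)


def expand_manifest(manifest):
    """Resolve every relative key against the manifest's prefix."""
    prefix = manifest[0]["prefix"]
    return [prefix + key for key in manifest[1:]]


manifest = build_manifest("s3://example-bucket/batch_transform/", ["0.png", "1.png"])
manifest_text = json.dumps(manifest)  # a single JSON array, as in the format above
uris = expand_manifest(json.loads(manifest_text))
print(manifest_text)
print(uris)
```

Expanding the manifest locally like this is a cheap way to confirm that the prefix and keys point at the objects you intended before submitting the job.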
[ ]:
!pip install -q jsonlines
[ ]:
import jsonlines

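# Build image list. The prefix must point to where the images were uploaded
# above (inference_inputs), because the manifest keys are relative to it.
manifest_prefix = f"{inference_inputs}/"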

path = image_dir
img_files = [f for f in listdir(path) if isfile(join(path, f))]

print("img_files\n", img_files)

manifest_content = [{"prefix": manifest_prefix}]
manifest_content.extend(img_files)

print("manifest_content\n", manifest_content)

# Write jsonl file
manifest_file = "manifest.json"
with jsonlines.open(manifest_file, mode="w") as writer:
    writer.write(manifest_content)

# Upload to S3
manifest_obj = sagemaker_session.upload_data(path=manifest_file, key_prefix=prefix)

print("manifest_obj\n", manifest_obj)
[ ]:
# Batch transform with manifest file
transform_job = transformer.transform(
    data=manifest_obj,
    data_type="ManifestFile",
    content_type="application/x-image",
    wait=False,
)
[ ]:
print("Latest transform job:", transformer.latest_transform_job.name)
[ ]:
# look at the status of the transform job
import pprint as pp

sm_cli = sagemaker_session.sagemaker_client

job_info = sm_cli.describe_transform_job(TransformJobName=transformer.latest_transform_job.name)

pp.pprint(job_info)
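
Because this job was started with wait=False, the status printed above may still be InProgress. Below is a minimal polling sketch, assuming a describe callable that returns a dictionary with a TransformJobStatus field, as describe_transform_job does. The helper name, the polling limits, and the fake describe function are illustrative assumptions; the SDK's transformer.wait() is the simpler option when blocking is acceptable.

```python
import time

# Statuses after which a transform job no longer changes.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_transform_job(describe, job_name, poll_seconds=30, max_polls=120, sleep=time.sleep):
    """Poll describe() until the job reaches a terminal status, then return it."""
    for _ in range(max_polls):
        status = describe(TransformJobName=job_name)["TransformJobStatus"]
        if status in TERMINAL_STATUSES:
            return status
        sleep(poll_seconds)
    raise TimeoutError(f"{job_name} did not finish after {max_polls} polls")


# Fake describe function so the sketch runs without AWS access.
fake_statuses = iter(["InProgress", "InProgress", "Completed"])


def fake_describe(TransformJobName):
    return {"TransformJobName": TransformJobName, "TransformJobStatus": next(fake_statuses)}


final_status = wait_for_transform_job(
    fake_describe, "demo-job", poll_seconds=0, sleep=lambda seconds: None
)
print(final_status)
```

In the notebook, the same helper could be called as wait_for_transform_job(sm_cli.describe_transform_job, transformer.latest_transform_job.name).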

Multiple instances

We use instance_count > 1 to create multiple inference instances. When a batch transform job starts, Amazon SageMaker initializes compute instances and distributes the inference or preprocessing workload between them. Batch Transform partitions the Amazon S3 objects in the input by key and maps Amazon S3 objects to instances. Given multiple files, one instance might process input1.csv, and another instance might process input2.csv. Read more at Use Batch Transform.

[ ]:
dist_transformer = estimator.transformer(instance_count=2, instance_type="ml.c4.xlarge")

dist_transformer.transform(
    data=inference_inputs,
    data_type="S3Prefix",
    content_type="application/x-image",
    wait=True,
)
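
The sketch below is purely illustrative of the idea that whole objects, not individual records, are mapped to instances by key. The hashing scheme is an assumption made for the example; how SageMaker actually assigns objects is internal to the service and may differ.

```python
import zlib
from collections import defaultdict


def assign_objects(keys, instance_count):
    """Map each object key to one instance index (hypothetical scheme)."""
    assignment = defaultdict(list)
    for key in keys:
        # A stable hash of the key picks the instance; each object is
        # processed by exactly one instance.
        index = zlib.crc32(key.encode("utf-8")) % instance_count
        assignment[index].append(key)
    return dict(assignment)


keys = [f"batch_transform/{i}.png" for i in range(10)]
assignment = assign_objects(keys, instance_count=2)
for index in sorted(assignment):
    print(index, assignment[index])
```

The practical consequence is the same regardless of the exact scheme: extra instances only help when the input consists of multiple objects, since a single object cannot be shared between instances this way.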

Look at all transform jobs

We list and describe the transform jobs to retrieve information about them.

[ ]:
transform_jobs = sm_cli.list_transform_jobs()["TransformJobSummaries"]
for job in transform_jobs:
    pp.pprint(job)
[ ]:
job_info = sm_cli.describe_transform_job(
    TransformJobName=dist_transformer.latest_transform_job.name
)

pp.pprint(job_info)
[ ]:
import re


def get_bucket_and_prefix(s3_output_path):
    trim = re.sub("^s3://", "", s3_output_path)
    # Split on the first "/" only, so prefixes that contain "/" do not raise
    bucket, _, prefix = trim.partition("/")
    return bucket, prefix


local_path = "output"  # Where to save the output locally

bucket, output_prefix = get_bucket_and_prefix(job_info["TransformOutput"]["S3OutputPath"])
print(bucket, output_prefix)

sagemaker_session.download_data(path=local_path, bucket=bucket, key_prefix=output_prefix)
[ ]:
!ls {local_path}
[ ]:
# Inspect the output

import json

for name in os.listdir(local_path):
    output_path = os.path.join(local_path, name)
    with open(output_path, "r") as fh:  # avoid reusing the loop variable for the file handle
        pred = json.load(fh)
        print(pred)
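
Batch Transform writes one output object per input object, named after the input with .out appended, and each file here holds the JSON produced by output_fn. As a small sketch, the outputs can be collected into a mapping from image name to predicted digit; the temporary directory and its three fake output files below are stand-ins so the example runs on its own.

```python
import json
import os
import tempfile
from collections import Counter


def load_predictions(output_dir):
    """Map each input file name to its predicted digit."""
    results = {}
    for name in sorted(os.listdir(output_dir)):
        if not name.endswith(".out"):
            continue  # skip anything that is not a transform output
        with open(os.path.join(output_dir, name), "r") as fh:
            results[name[: -len(".out")]] = json.load(fh)["predictions"]
    return results


# Fake outputs standing in for the downloaded "output" directory.
with tempfile.TemporaryDirectory() as tmp:
    for image_name, digit in [("0.png", 7), ("1.png", 2), ("2.png", 7)]:
        with open(os.path.join(tmp, image_name + ".out"), "w") as fh:
            json.dump({"predictions": digit}, fh)
    predictions = load_predictions(tmp)

digit_counts = Counter(predictions.values())
print(predictions)
print(digit_counts)
```

With the real job output, load_predictions(local_path) would give the same kind of mapping for the 100 sampled images.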

Conclusion

In this notebook, we trained a PyTorch model, created a transformer from it, and then performed batch inference using S3 inputs, manifest files, and on multiple instances. This shows a variety of options that are available when running SageMaker Batch Transform jobs for batch inference.

Notebook CI Test Results

This notebook was tested in multiple regions: us-west-2, us-east-1, us-east-2, us-west-1, ca-central-1, sa-east-1, eu-west-1, eu-west-2, eu-west-3, eu-central-1, eu-north-1, ap-southeast-1, ap-southeast-2, ap-northeast-1, ap-northeast-2, and ap-south-1.