Deploying pre-trained PyTorch vision models with Amazon SageMaker Neo




Amazon SageMaker Neo is an API that compiles machine learning models and optimizes them for the hardware target you choose. Currently, Neo supports pre-trained PyTorch models from TorchVision. General support for other PyTorch models is forthcoming.

Runtime

This notebook takes approximately 8 minutes to run.

Contents

  1. Import ResNet18 from TorchVision

  2. Invoke Neo Compilation API

  3. Deploy the model

  4. Send requests

  5. Delete the Endpoint

Import ResNet18 from TorchVision

We import the ResNet18 model from TorchVision and package it as a model artifact, model.tar.gz. First, install the required packages.

[ ]:
import sys

!{sys.executable} -m pip install torch==1.13.0 torchvision==0.14.0
!{sys.executable} -m pip install s3transfer==0.5.0
!{sys.executable} -m pip install --upgrade sagemaker

Load the pre-trained model, specify the input data shape, trace the model with TorchScript, and archive the traced model. For more information, see Prepare Model for Compilation.

[ ]:
import sagemaker
import torch
import torchvision.models as models
import tarfile

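# Load ImageNet-pretrained weights (`pretrained=True` is deprecated since torchvision 0.13)
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
input_shape = [1, 3, 224, 224]
# Trace in eval mode with a dummy input of the shape the compiled model will accept
trace = torch.jit.trace(resnet18.float().eval(), torch.zeros(input_shape).float())
trace.save("model.pth")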

with tarfile.open("model.tar.gz", "w:gz") as f:
    f.add("model.pth")
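
Before uploading, it can be useful to confirm that model.pth actually ended up in the archive. The following is a minimal sketch that uses only the standard library. The helper name `list_archive_members` is hypothetical, and the snippet builds a throwaway archive around a placeholder file so that it runs without PyTorch installed.

```python
import os
import tarfile
import tempfile


def list_archive_members(archive_path):
    """Return the names of all members in a gzip-compressed tar archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        return tar.getnames()


# Demonstration with a placeholder file so the snippet runs anywhere.
# In the notebook you would call list_archive_members("model.tar.gz") instead.
with tempfile.TemporaryDirectory() as tmp:
    dummy_model = os.path.join(tmp, "model.pth")
    with open(dummy_model, "wb") as f:
        f.write(b"placeholder bytes, not a real model")

    archive = os.path.join(tmp, "model.tar.gz")
    with tarfile.open(archive, "w:gz") as tar:
        # arcname stores the file at the archive root, without the temp dir path
        tar.add(dummy_model, arcname="model.pth")

    members = list_archive_members(archive)
    print(members)
```

Run against the real archive, the list should contain the single entry model.pth.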

Upload the model archive to S3

Specify parameters for the compilation job and upload the model.tar.gz archive file.

[ ]:
import boto3
import sagemaker
import time
from sagemaker.utils import name_from_base

role = sagemaker.get_execution_role()
sess = sagemaker.Session()
region = sess.boto_region_name
bucket = sess.default_bucket()

compilation_job_name = name_from_base("TorchVision-ResNet18-Neo")
prefix = compilation_job_name + "/model"

model_path = sess.upload_data(path="model.tar.gz", key_prefix=prefix)

data_shape = '{"input0":[1,3,224,224]}'
target_device = "ml_c5"
framework = "PYTORCH"
framework_version = "1.13"
compiled_model_path = "s3://{}/{}/output".format(bucket, compilation_job_name)
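
The `data_shape` string repeats the shape that was used to trace the model. One way to keep the two in sync is to derive the JSON string from the shape list. This is a sketch only; the helper name `make_data_shape` is an assumption, while the input name `input0` is the one used in this notebook.

```python
import json


def make_data_shape(input_shape, input_name="input0"):
    """Serialize an input shape as a compact JSON string keyed by input name."""
    # Compact separators reproduce the literal used in the notebook cell
    return json.dumps({input_name: list(input_shape)}, separators=(",", ":"))


data_shape = make_data_shape([1, 3, 224, 224])
print(data_shape)
```

In the notebook, `make_data_shape(input_shape)` could replace the hand-written literal, so a change to the traced shape is reflected in the compilation parameters.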

Invoke Neo Compilation API

Create a PyTorch SageMaker model

Create a PyTorchModel, specifying the path to the model artifact, the entry_point script that performs inference, and the framework version and environment variables.

[ ]:
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.predictor import Predictor

sagemaker_model = PyTorchModel(
    model_data=model_path,
    predictor_cls=Predictor,
    framework_version=framework_version,
    role=role,
    sagemaker_session=sess,
    entry_point="resnet18.py",
    source_dir="code",
    py_version="py3",
    env={"MMS_DEFAULT_RESPONSE_TIMEOUT": "500"},
)

Use Neo compiler to compile the model

Run the compilation job. The compiled model is saved in S3 at the specified compiled_model_path location.

[ ]:
compiled_model = sagemaker_model.compile(
    target_instance_family=target_device,
    input_shape=data_shape,
    job_name=compilation_job_name,
    role=role,
    framework=framework.lower(),
    framework_version=framework_version,
    output_path=compiled_model_path,
)
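
To inspect a compilation job from outside the SDK call, for example from another process, its status can be polled through the low-level `describe_compilation_job` API. The sketch below takes the client as a parameter; `wait_for_compilation` and `FakeSageMakerClient` are hypothetical names, and the fake client exists only so the example runs without AWS credentials.

```python
import time

# States after which the job status no longer changes
TERMINAL_STATES = {"COMPLETED", "FAILED", "STOPPED"}


def wait_for_compilation(client, job_name, poll_seconds=30, max_polls=60):
    """Poll describe_compilation_job until the job reaches a terminal state."""
    for _ in range(max_polls):
        response = client.describe_compilation_job(CompilationJobName=job_name)
        status = response["CompilationJobStatus"]
        if status in TERMINAL_STATES:
            return status
        time.sleep(poll_seconds)
    raise TimeoutError(f"{job_name} did not finish after {max_polls} polls")


class FakeSageMakerClient:
    """Stand-in for boto3.client("sagemaker") so the sketch runs offline."""

    def __init__(self, statuses):
        self._statuses = iter(statuses)

    def describe_compilation_job(self, CompilationJobName):
        return {"CompilationJobStatus": next(self._statuses)}


fake_client = FakeSageMakerClient(["STARTING", "INPROGRESS", "COMPLETED"])
final_status = wait_for_compilation(fake_client, "example-job", poll_seconds=0)
print(final_status)
```

In the notebook, the real client would be `boto3.client("sagemaker", region_name=region)` and the job name would be `compilation_job_name`.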

Deploy the model

Deploy the compiled model to an endpoint so it can be used for inference.

[ ]:
predictor = compiled_model.deploy(initial_instance_count=1, instance_type="ml.c5.9xlarge")
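
The endpoint instance type should belong to the family the model was compiled for: here the target is `ml_c5` and the instance type is `ml.c5.9xlarge`. The sketch below derives the target string from an instance type. The helper name is hypothetical, and the naming pattern is an assumption inferred from the pair of values used in this notebook, so check the result against the list of targets that Neo supports.

```python
def target_family_from_instance_type(instance_type):
    """Map an instance type such as 'ml.c5.9xlarge' to a target such as 'ml_c5'."""
    parts = instance_type.split(".")
    if len(parts) != 3 or parts[0] != "ml":
        raise ValueError(f"unexpected instance type format: {instance_type!r}")
    # Assumed convention: 'ml.<family>.<size>' corresponds to 'ml_<family>'
    return f"{parts[0]}_{parts[1]}"


derived_target = target_family_from_instance_type("ml.c5.9xlarge")
print(derived_target)
```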

Send requests

Let’s send a picture to the endpoint to predict the image subject.

Open the image and pass the payload as a bytearray to the predictor, receiving a response.

[ ]:
import numpy as np
import json

with open("cat.jpg", "rb") as f:
    payload = f.read()
    payload = bytearray(payload)

response = predictor.predict(payload)
result = json.loads(response.decode())
print("Most likely class: {}".format(np.argmax(result)))

Use the predicted ImageNet class ID to look up the label of the image subject, and report the score the model assigned to it.

[ ]:
# Load names for ImageNet classes
object_categories = {}
with open("imagenet1000_clsidx_to_labels.txt", "r") as f:
    for line in f:
        # Split on the first colon only, so a label that contains ":" does not break unpacking
        key, val = line.strip().split(":", 1)
        object_categories[key] = val.strip(" ").strip(",")
print(
    "The label is",
    object_categories[str(np.argmax(result))],
    "with probability",
    str(np.amax(result))[:5],
)
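
Looking only at the single best class hides how close the runners-up are. A small helper can rank the scores and return the top few. This is a sketch using the standard library; `top_k` is a hypothetical name, and it assumes the scores are a flat list (in the notebook, something like `np.ravel(result).tolist()` would flatten a nested response first).

```python
def top_k(scores, k=5):
    """Return the k highest (index, score) pairs, best first."""
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy scores standing in for the 1000-element ImageNet output
example_scores = [0.05, 0.70, 0.20, 0.05]
best = top_k(example_scores, k=2)
print(best)
```

Each returned index can be converted with `str(index)` and looked up in `object_categories` in the same way as the single best class.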

Delete the Endpoint

Delete the model and the endpoint to avoid incurring costs now that they are no longer needed.

[ ]:
predictor.delete_model()
sess.delete_endpoint(predictor.endpoint_name)
