Deploying pre-trained PyTorch vision models with Amazon SageMaker Neo



Amazon SageMaker Neo is an API for compiling machine learning models so that they are optimized for your choice of hardware target. Currently, Neo supports pre-trained PyTorch models from TorchVision. General support for other PyTorch models is forthcoming.

Runtime

This notebook takes approximately 8 minutes to run.

Contents

  1. Import ResNet18 from TorchVision

  2. Invoke Neo Compilation API

  3. Deploy the model

  4. Send requests

  5. Delete the Endpoint

Import ResNet18 from TorchVision

We import the ResNet18 model from TorchVision and create a model artifact model.tar.gz.

[2]:
import sys

!{sys.executable} -m pip install torch==1.6.0 torchvision==0.7.0
!{sys.executable} -m pip install s3transfer==0.5.0
!{sys.executable} -m pip install --upgrade sagemaker
Successfully installed future-0.18.2 torch-1.6.0 torchvision-0.7.0
Successfully installed s3transfer-0.5.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
awscli 1.22.7 requires botocore==1.23.7, but you have botocore 1.24.42 which is incompatible.
Successfully installed boto3-1.21.42 botocore-1.24.42 sagemaker-2.86.2

Specify the input data shape, trace the model with a dummy input of that shape, and package the traced model as model.tar.gz. For more information, see Prepare Model for Compilation.

[3]:
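import sagemaker
import torch
import torchvision.models as models
import tarfile

# Load ResNet18 with pre-trained ImageNet weights.
resnet18 = models.resnet18(pretrained=True)

# Trace the model in evaluation mode with a dummy input of the expected shape (NCHW).
input_shape = [1, 3, 224, 224]
trace = torch.jit.trace(resnet18.float().eval(), torch.zeros(input_shape).float())
trace.save("model.pth")

# Package the traced model as model.tar.gz for upload to S3.
with tarfile.open("model.tar.gz", "w:gz") as f:
    f.add("model.pth")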
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
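
Before uploading, it can be useful to confirm what the archive actually contains. The following is a minimal sketch that uses only the standard library. `list_archive_members` is a hypothetical helper, and a placeholder file stands in for the real model.pth so that the snippet runs without PyTorch installed.

```python
import os
import tarfile
import tempfile


def list_archive_members(archive_path):
    """Return the names of all entries in a gzip-compressed tar archive."""
    with tarfile.open(archive_path, "r:gz") as archive:
        return archive.getnames()


# Demonstration with a placeholder file instead of a real traced model.
workdir = tempfile.mkdtemp()
placeholder = os.path.join(workdir, "model.pth")
with open(placeholder, "wb") as f:
    f.write(b"placeholder bytes, not a real TorchScript file")

archive_path = os.path.join(workdir, "model.tar.gz")
with tarfile.open(archive_path, "w:gz") as archive:
    # arcname stores the entry at the archive root, which is what
    # f.add("model.pth") does when run from the file's own directory.
    archive.add(placeholder, arcname="model.pth")

members = list_archive_members(archive_path)
print(members)
```

In the notebook itself, calling `list_archive_members("model.tar.gz")` after the cell above should list the single entry model.pth.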

Upload the model archive to S3

Specify parameters for the compilation job and upload the model.tar.gz archive file.

[4]:
import boto3
import sagemaker
import time
from sagemaker.utils import name_from_base

role = sagemaker.get_execution_role()
sess = sagemaker.Session()
region = sess.boto_region_name
bucket = sess.default_bucket()

compilation_job_name = name_from_base("TorchVision-ResNet18-Neo")
prefix = compilation_job_name + "/model"

model_path = sess.upload_data(path="model.tar.gz", key_prefix=prefix)

data_shape = '{"input0":[1,3,224,224]}'
target_device = "ml_c5"
framework = "PYTORCH"
framework_version = "1.6"
compiled_model_path = "s3://{}/{}/output".format(bucket, compilation_job_name)
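
The data_shape value is a JSON string that maps the input name to its shape, so it should agree with the input_shape used when tracing. As a sketch, the string can be generated from that list instead of being typed by hand. The name "input0" is taken from the cell above; everything else here is illustrative.

```python
import json

# Same shape that was used to trace the model.
input_shape = [1, 3, 224, 224]

# separators=(",", ":") drops the spaces that json.dumps inserts by default,
# which reproduces the compact string used in the notebook.
data_shape = json.dumps({"input0": input_shape}, separators=(",", ":"))
print(data_shape)

# Round-trip check: the string parses back to the original shape.
parsed = json.loads(data_shape)
print(parsed["input0"] == input_shape)
```

Deriving the string this way keeps the traced shape and the compilation shape from drifting apart if the input size is changed later.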

Invoke Neo Compilation API

Create a PyTorch SageMaker model

Use the PyTorchModel and define parameters including the path to the model, the entry_point script that is used to perform inference, and other version and environment variables.

[5]:
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.predictor import Predictor

sagemaker_model = PyTorchModel(
    model_data=model_path,
    predictor_cls=Predictor,
    framework_version=framework_version,
    role=role,
    sagemaker_session=sess,
    entry_point="resnet18.py",
    source_dir="code",
    py_version="py3",
    env={"MMS_DEFAULT_RESPONSE_TIMEOUT": "500"},
)

Use Neo compiler to compile the model

Run the compilation job. The compiled model is saved in S3 at the location specified by compiled_model_path.

[6]:
compiled_model = sagemaker_model.compile(
    target_instance_family=target_device,
    input_shape=data_shape,
    job_name=compilation_job_name,
    role=role,
    framework=framework.lower(),
    framework_version=framework_version,
    output_path=compiled_model_path,
)
?????????????????????????????............................................!
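
The line of progress characters above is printed while the SDK waits for the job to finish. The status of a compilation job can also be checked through the `describe_compilation_job` call of the boto3 SageMaker client. The sketch below assumes that approach: `wait_for_compilation` is a hypothetical helper, and `FakeSageMakerClient` is a stand-in used only so that the example runs without AWS credentials.

```python
import time


def wait_for_compilation(sm_client, job_name, poll_seconds=30, max_polls=60):
    """Poll a compilation job until it reaches a terminal status.

    sm_client is expected to behave like boto3.client("sagemaker").
    """
    terminal = {"COMPLETED", "FAILED", "STOPPED"}
    for _ in range(max_polls):
        description = sm_client.describe_compilation_job(CompilationJobName=job_name)
        status = description["CompilationJobStatus"]
        if status in terminal:
            if status != "COMPLETED":
                reason = description.get("FailureReason", "no reason reported")
                raise RuntimeError(
                    "Compilation job ended with status {}: {}".format(status, reason)
                )
            return description
        time.sleep(poll_seconds)
    raise TimeoutError("Compilation job did not finish within the polling limit")


class FakeSageMakerClient:
    """Stand-in client that replays a fixed sequence of statuses (illustration only)."""

    def __init__(self, statuses):
        self._statuses = iter(statuses)

    def describe_compilation_job(self, CompilationJobName):
        return {
            "CompilationJobName": CompilationJobName,
            "CompilationJobStatus": next(self._statuses),
        }


fake = FakeSageMakerClient(["STARTING", "INPROGRESS", "COMPLETED"])
final = wait_for_compilation(fake, "example-job", poll_seconds=0)
print(final["CompilationJobStatus"])
```

With real credentials, the same helper could be called as `wait_for_compilation(boto3.client("sagemaker"), compilation_job_name)`.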

Deploy the model

Deploy the compiled model to an endpoint so it can be used for inference.

[7]:
predictor = compiled_model.deploy(initial_instance_count=1, instance_type="ml.c5.9xlarge")
----!

Send requests

Let’s send a picture to the endpoint to predict the image subject.

Download the image and pass the payload as a bytearray to the predictor, which returns a response.

[8]:
import json
from urllib.request import urlopen

import numpy as np

# The built-in open() cannot read a URL, so fetch the image bytes over HTTP.
image_url = "https://raw.githubusercontent.com/aws/amazon-sagemaker-examples/main/sagemaker_neo_compilation_jobs/pytorch_torchvision/cat.jpg"
with urlopen(image_url) as f:
    payload = bytearray(f.read())

# The endpoint returns JSON-encoded class scores.
response = predictor.predict(payload)
result = json.loads(response.decode())
print("Most likely class: {}".format(np.argmax(result)))
Most likely class: 282
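
np.argmax reports only the single best class. To see the runners-up as well, the scores can be ranked. The sketch below is illustrative: `top_k` is a hypothetical helper, the scores are made up, and it assumes a flat list of scores (if the response carries a batch dimension, pass `result[0]` instead).

```python
def top_k(scores, k=5):
    """Return (class_id, score) pairs for the k highest scores, best first."""
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up scores for illustration; a real response has one score per ImageNet class.
example_scores = [0.01, 0.10, 0.645, 0.20, 0.045]
for class_id, score in top_k(example_scores, k=3):
    print(class_id, score)
```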

Use the ImageNet class ID response to look up which subject the image contains, and with what probability.

[9]:
# Load names for ImageNet classes
object_categories = {}
with open("imagenet1000_clsidx_to_labels.txt", "r") as f:
    for line in f:
        key, val = line.strip().split(":")
        object_categories[key] = val.strip(" ").strip(",")
print(
    "The label is",
    object_categories[str(np.argmax(result))],
    "with probability",
    str(np.amax(result))[:5],
)
The label is 'tiger cat' with probability 0.645
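
The parsing loop above uses `split(":")`, which raises a ValueError on any line that does not contain exactly one colon, and it leaves the quotes around the label (as the output shows). A more tolerant variant is sketched below. `parse_label_line` is a hypothetical helper, and the sample lines are an assumption about the layout of imagenet1000_clsidx_to_labels.txt rather than a copy of the file.

```python
def parse_label_line(line):
    """Split one 'id: label,' line into (id, label); return None if there is no colon."""
    key, sep, val = line.strip().partition(":")
    if not sep:
        return None
    # Tolerate the braces of a dict literal on the first and last lines.
    key = key.strip().lstrip("{")
    val = val.strip().rstrip("}").rstrip(",").strip()
    # Drop one pair of matching surrounding quotes, if present.
    if len(val) >= 2 and val[0] == val[-1] and val[0] in "'\"":
        val = val[1:-1]
    return key, val


# Assumed file layout, for illustration only.
sample_lines = [
    "{0: 'tench, Tinca tinca',",
    " 282: 'tiger cat',",
    " 999: 'toilet tissue, toilet paper, bathroom tissue'}",
    "",
]

object_categories = {}
for line in sample_lines:
    parsed_line = parse_label_line(line)
    if parsed_line is not None:
        object_categories[parsed_line[0]] = parsed_line[1]

print(object_categories["282"])
```

Because `partition` splits only at the first colon, a label that itself contains a colon is kept whole.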

Delete the Endpoint

Delete the endpoint to avoid incurring costs now that it is no longer needed.

[10]:
predictor.delete_model()
sess.delete_endpoint(predictor.endpoint_name)
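
If a request fails partway through the notebook, the cleanup cell may never run and the endpoint keeps accruing charges. One way to guard against that is to wrap the requests in a context manager, as sketched below. `cleaned_up` is a hypothetical helper that assumes the predictor exposes `delete_model()` and `delete_endpoint()`, as the SageMaker Python SDK `Predictor` does. `FakePredictor` exists only so that the example runs without an endpoint.

```python
from contextlib import contextmanager


@contextmanager
def cleaned_up(predictor):
    """Yield the predictor, then delete its model and endpoint even if the body raises."""
    try:
        yield predictor
    finally:
        # Same order as the notebook cell: model first, then the endpoint.
        predictor.delete_model()
        predictor.delete_endpoint()


class FakePredictor:
    """Stand-in that records cleanup calls (illustration only)."""

    def __init__(self):
        self.calls = []

    def delete_model(self):
        self.calls.append("delete_model")

    def delete_endpoint(self):
        self.calls.append("delete_endpoint")


fake_predictor = FakePredictor()
try:
    with cleaned_up(fake_predictor) as p:
        raise ValueError("simulated request failure")
except ValueError:
    pass

print(fake_predictor.calls)
```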

Notebook CI Test Results

This notebook was tested in the following regions: us-west-2, us-east-1, us-east-2, us-west-1, ca-central-1, sa-east-1, eu-west-1, eu-west-2, eu-west-3, eu-central-1, eu-north-1, ap-southeast-1, ap-southeast-2, ap-northeast-1, ap-northeast-2, and ap-south-1.