Deploying pre-trained PyTorch VGG19 model with Amazon SageMaker Neo


This notebook’s CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

[CI badge for us-west-2 is rendered here in the hosted notebook.]


Amazon SageMaker Neo is an API for compiling machine learning models to optimize them for your choice of hardware targets. Currently, Neo supports pre-trained PyTorch models from TorchVision. General support for other PyTorch models is forthcoming.

In this example notebook, we compare the performance of the pretrained PyTorch VGG19_bn model before and after compilation with Neo.

The PyTorch VGG19_bn model is one that benefits substantially from compilation with Neo. Here we verify that, with end-to-end compilation and inference on SageMaker endpoints, the Neo-compiled model can achieve roughly a seven-fold speedup with no loss in accuracy.

Make sure you have selected the Python 3 (Data Science) kernel.

[ ]:
%cd /root/amazon-sagemaker-examples/aws_sagemaker_studio/sagemaker_neo_compilation_jobs/pytorch_vgg19_bn

The SageMaker Python SDK >= 2.0 is required for this notebook.

[ ]:
import sys

!{sys.executable} -m pip install torch==1.6.0 torchvision==0.7.0
!{sys.executable} -m pip install --upgrade sagemaker

Import VGG19 from TorchVision

We’ll import the VGG19_bn model from TorchVision and create a model artifact model.tar.gz:

[ ]:
import torch
import torchvision.models as models
import tarfile
import sagemaker

sagemaker.__version__
[ ]:
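# Load the pretrained VGG19_bn weights, trace the model to TorchScript with a
# dummy input of the expected shape, and package it as model.tar.gz for upload.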
vgg19_bn = models.vgg19_bn(pretrained=True)
input_shape = [1, 3, 224, 224]
trace = torch.jit.trace(vgg19_bn.float().eval(), torch.zeros(input_shape).float())
trace.save("model.pth")

with tarfile.open("model.tar.gz", "w:gz") as f:
    f.add("model.pth")

Set up the environment

[ ]:
import boto3
import sagemaker
import time
from sagemaker.utils import name_from_base
from sagemaker import image_uris

role = sagemaker.get_execution_role()
sess = sagemaker.Session()
region = sess.boto_region_name
bucket = sess.default_bucket()

compilation_job_name = name_from_base("TorchVision-vgg19-Neo")
prefix = compilation_job_name + "/model"

model_path = sess.upload_data(path="model.tar.gz", key_prefix=prefix)

data_shape = '{"input0":[1,3,224,224]}'
target_device = "ml_c5"
framework = "pytorch"
framework_version = "1.6"
compiled_model_path = "s3://{}/{}/output".format(bucket, compilation_job_name)

inference_image_uri = image_uris.retrieve(
    f"neo-{framework}", region, framework_version, instance_type=target_device
)
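
Before moving on, it can help to print the resolved settings that the deployment and compilation steps below will use (purely a sanity check):

[ ]:
print("Model artifact:       " + model_path)
print("Compiled model path:  " + compiled_model_path)
print("Target device:        " + target_device)
print("Inference image URI:  " + inference_image_uri)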

Use the SageMaker PyTorchModel to load the pretrained PyTorch model

[ ]:
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.predictor import Predictor

pt_vgg = PyTorchModel(
    model_data=model_path,
    framework_version=framework_version,
    predictor_cls=Predictor,
    role=role,
    sagemaker_session=sess,
    entry_point="vgg19_bn_uncompiled.py",
    source_dir="code",
    py_version="py3",
    image_uri=inference_image_uri,
)

Deploy the pretrained model to prepare for predictions (the old way)

[ ]:
vgg_predictor = pt_vgg.deploy(initial_instance_count=1, instance_type="ml.c5.9xlarge")

Invoke the endpoint

Let’s test with a cat image.

[ ]:
from IPython.display import Image

Image("cat.jpg")
[ ]:
import json

with open("cat.jpg", "rb") as f:
    payload = f.read()
    payload = bytearray(payload)
[ ]:
import time

start = time.time()
for _ in range(1000):
    output = vgg_predictor.predict(payload)
inference_time = time.time() - start
print("Total inference time for 1000 requests is " + str(inference_time) + " seconds")
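
The loop above sends 1000 sequential requests, so the average per-request latency is the total wall-clock time divided by 1000; a small sketch using the inference_time measured above:

[ ]:
avg_latency_sec = inference_time / 1000  # 1000 requests were sent
print("Average latency per request: {:.1f} ms".format(avg_latency_sec * 1000))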
[ ]:
import numpy as np

result = json.loads(output.decode())
predicted = np.argmax(result)
[ ]:
# Load names for ImageNet classes
object_categories = {}
with open("imagenet1000_clsidx_to_labels.txt", "r") as f:
    for line in f:
        key, val = line.strip().split(":")
        object_categories[key] = val
[ ]:
print("Result: label - " + object_categories[str(predicted)])
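
Beyond the top-1 label, it can be useful to inspect the model's top-5 predictions; a minimal sketch, assuming result holds the 1000 raw class scores returned by the endpoint:

[ ]:
scores = np.array(result).squeeze()
top5 = np.argsort(scores)[-5:][::-1]  # indices of the five highest scores, best first
for idx in top5:
    print("{}: {}".format(idx, object_categories[str(idx)]))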

Clean-up

Delete the endpoint when you’re finished with it; the hosted instance continues to incur charges while it is running, and we will deploy a separate endpoint for the Neo-compiled model next.

[ ]:
sess.delete_endpoint(vgg_predictor.endpoint_name)

Neo optimization

Create a PyTorch SageMaker model

[ ]:
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.predictor import Predictor

sagemaker_model = PyTorchModel(
    model_data=model_path,
    predictor_cls=Predictor,
    framework_version=framework_version,
    role=role,
    sagemaker_session=sess,
    entry_point="vgg19_bn_compiled.py",
    source_dir="code",
    py_version="py3",
    env={"MMS_DEFAULT_RESPONSE_TIMEOUT": "500"},
)

Use Neo compiler to compile the model

[ ]:
compiled_model = sagemaker_model.compile(
    target_instance_family=target_device,
    input_shape=data_shape,
    job_name=compilation_job_name,
    role=role,
    framework=framework.lower(),
    framework_version=framework_version,
    output_path=compiled_model_path,
)
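
compile() typically blocks until the compilation job finishes, which can take several minutes. If you want to inspect the job afterwards, its status is available through the boto3 SageMaker client; a minimal sketch using the compilation_job_name defined earlier:

[ ]:
sm_client = boto3.client("sagemaker", region_name=region)
job = sm_client.describe_compilation_job(CompilationJobName=compilation_job_name)
print(job["CompilationJobStatus"])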
[ ]:
predictor = compiled_model.deploy(initial_instance_count=1, instance_type="ml.c5.9xlarge")
[ ]:
import time

start = time.time()
for _ in range(1000):
    response = predictor.predict(payload)
neo_inference_time = time.time() - start
print("Neo optimized total inference time for 1000 requests is " + str(neo_inference_time) + " seconds")
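
Since both endpoints were timed over the same 1000 requests, the speedup from Neo compilation is simply the ratio of the two totals:

[ ]:
print("Speedup from Neo compilation: {:.1f}x".format(inference_time / neo_inference_time))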
[ ]:
result = json.loads(response.decode())
print("Most likely class: {}".format(np.argmax(result)))
print(
    "Result: label - "
    + object_categories[str(np.argmax(result))]
    + " probability - "
    + str(np.amax(result))
)
[ ]:
sess.delete_endpoint(predictor.endpoint_name)

Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

[CI badges for us-east-1, us-east-2, us-west-1, ca-central-1, sa-east-1, eu-west-1, eu-west-2, eu-west-3, eu-central-1, eu-north-1, ap-southeast-1, ap-southeast-2, ap-northeast-1, ap-northeast-2, and ap-south-1 are rendered here in the hosted notebook.]