Deploy a Trained PyTorch Model




In this notebook, we walk through the process of deploying a trained model to a SageMaker endpoint. If you recently ran the training notebook and saved the model location with the %store magic, model_data can be restored here. Otherwise, we copy a pretrained model artifact from a public S3 bucket to your default bucket.

[ ]:
# Setup

import os
import json

import boto3
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role, Session

sess = Session()

role = get_execution_role()

%store -r pt_mnist_model_data

try:
    # Use the artifact location restored by %store, if available
    pt_mnist_model_data
except NameError:
    # Copy a pretrained model from a public bucket to your default bucket
    s3 = boto3.client("s3")
    bucket = f"sagemaker-example-files-prod-{sess.boto_region_name}"
    key = "datasets/image/MNIST/model/pytorch-training-2020-11-21-22-02-56-203/model.tar.gz"
    s3.download_file(bucket, key, "model.tar.gz")

    # Upload to the default bucket
    pt_mnist_model_data = sess.upload_data(
        path="model.tar.gz", bucket=sess.default_bucket(), key_prefix="model/pytorch"
    )
[ ]:
print(pt_mnist_model_data)

PyTorch Model Object

The PyTorchModel class allows you to define an environment for running inference with your model artifact. Like the PyTorch estimator class used in the training notebook, it is a high-level API that sets up a Docker image for your model hosting service.

Once it is properly configured, it can be used to create a SageMaker endpoint on an EC2 instance. The SageMaker endpoint is a containerized environment that uses your trained model to run inference on incoming data received via RESTful API calls.

Some common parameters used to initialize the PyTorchModel class are:

  • entry_point: a user-defined Python file that the inference image uses to handle incoming requests

  • source_dir: the directory containing the entry_point

  • role: an IAM role used to make AWS service requests

  • model_data: the S3 location of the compressed model artifact. It can be a path to a local file if the endpoint is deployed on the SageMaker instance you are using to run this notebook (local mode)

  • framework_version: the version of the PyTorch package to use

  • py_version: the Python version to use

We elaborate on the entry_point below.

[ ]:
model = PyTorchModel(
    entry_point="inference.py",
    source_dir="code",
    role=role,
    model_data=pt_mnist_model_data,
    framework_version="1.5.0",
    py_version="py3",
)

Entry Point for the Inference Image

The model artifact that model_data points to is pulled by PyTorchModel, decompressed, and saved in the Docker image it defines. The extracted files are regular model checkpoint files, the same kind you would produce outside SageMaker. This means that to serve your trained model, you need to tell the PyTorchModel class how to recover a PyTorch model from the static checkpoint.
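To make the unpacking step concrete, the sketch below builds a model.tar.gz containing a placeholder model.pth and extracts it into a directory standing in for model_dir, roughly what the hosting container does before it calls model_fn. It is a stdlib-only illustration under stated assumptions: the placeholder bytes replace a real torch.save checkpoint, the function name package_and_extract is hypothetical, and the temporary directory layout is not the container's actual path.

```python
import os
import tarfile
import tempfile


def package_and_extract(checkpoint_bytes: bytes) -> list:
    """Hypothetical helper: pack a fake checkpoint into model.tar.gz, then extract it."""
    with tempfile.TemporaryDirectory() as work:
        # Write a placeholder checkpoint (a real one would come from torch.save)
        ckpt_path = os.path.join(work, "model.pth")
        with open(ckpt_path, "wb") as f:
            f.write(checkpoint_bytes)

        # Store model.pth at the root of the archive
        archive = os.path.join(work, "model.tar.gz")
        with tarfile.open(archive, "w:gz") as tar:
            tar.add(ckpt_path, arcname="model.pth")

        # Extract into a separate directory that plays the role of model_dir
        model_dir = os.path.join(work, "model_dir")
        os.makedirs(model_dir)
        with tarfile.open(archive, "r:gz") as tar:
            # Use the safe "data" extraction filter where this Python version supports it
            if hasattr(tarfile, "data_filter"):
                tar.extractall(model_dir, filter="data")
            else:
                tar.extractall(model_dir)

        return sorted(os.listdir(model_dir))


print(package_and_extract(b"placeholder weights"))  # ['model.pth']
```

Keeping model.pth at the root of the archive matters because the model_fn in this example opens os.path.join(model_dir, "model.pth").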

In addition, because the deployed endpoint receives RESTful API calls, you need to tell it how to parse an incoming request into input your model can consume.

These instructions are defined as functions in the Python file that entry_point points to.

By convention, we name this entry point file inference.py and we put it in the code directory.

To tell the inference image how to load the model checkpoint, you need to implement a function called model_fn. This function takes one positional argument:

  • model_dir: the directory of the static model checkpoints in the inference image.

The model_fn function returns a PyTorch model. In this example, it looks like:

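import os
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def model_fn(model_dir):
    # Net is the network class used during training; it must be importable in inference.py
    model = Net()
    with open(os.path.join(model_dir, "model.pth"), "rb") as f:
        # map_location lets a GPU-trained checkpoint load on a CPU-only host
        model.load_state_dict(torch.load(f, map_location=device))
    model.to(device).eval()
    return model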

Next, you need to tell the hosting service how to handle the incoming data. This includes:

  • How to parse the incoming request

  • How to use the trained model to make inference

  • How to return the prediction to the caller of the service

You do this by implementing three functions:

input_fn

The SageMaker PyTorch model server invokes the input_fn function in your inference entry point. This function handles data decoding. The input_fn function has the following signature:

def input_fn(request_body, request_content_type)

The two positional arguments are:

  • request_body: the payload of the incoming request

  • request_content_type: the content type of the incoming request

The return value of input_fn is an object that can be passed to predict_fn.

In this example, the input_fn looks like:

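import json
import torch

def input_fn(request_body, request_content_type):
    # Only JSON payloads are accepted
    assert request_content_type == "application/json"
    # The decoded payload must hold the model input under the "inputs" key
    data = json.loads(request_body)["inputs"]
    # device is the module-level torch.device also used by model_fn
    data = torch.tensor(data, dtype=torch.float32, device=device)
    return data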

It requires the request payload to be a JSON string, and it assumes the decoded payload contains a key inputs that maps to the input data for the model.
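To see what input_fn expects without running a container, the following stdlib-only sketch mirrors its parsing logic on a tiny JSON payload. The function name parse_request is hypothetical; unlike the example, it raises ValueError instead of asserting, and it stops before the torch.tensor conversion so it runs without PyTorch installed.

```python
import json


def parse_request(request_body, request_content_type):
    """Hypothetical stand-in for input_fn: check the content type and return "inputs"."""
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    payload = json.loads(request_body)
    # The decoded JSON must be an object that carries the model input under "inputs"
    if not isinstance(payload, dict) or "inputs" not in payload:
        raise ValueError("Expected a JSON object with an 'inputs' key")
    return payload["inputs"]


# One 1x2x2 "image", nested as [batch][channel][row][col]
body = json.dumps({"inputs": [[[[0.0, 0.5], [0.25, 1.0]]]]})
inputs = parse_request(body, "application/json")
print(len(inputs), len(inputs[0]), len(inputs[0][0]), len(inputs[0][0][0]))  # 1 1 2 2
```

A bare JSON list, with no inputs key, is rejected by this check, which is the same situation that makes the real input_fn fail.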

predict_fn

After the inference request has been deserialized by input_fn, the SageMaker PyTorch model server invokes predict_fn on the return value of input_fn.

The predict_fn function has the following signature:

def predict_fn(input_object, model)

The two positional arguments are:

  • input_object: the return value from input_fn

  • model: the return value from model_fn

The return value of predict_fn is passed as the first argument to output_fn.

In this example, the predict_fn function looks like:

import torch

def predict_fn(input_object, model):
    # Disable gradient tracking, which inference does not need
    with torch.no_grad():
        prediction = model(input_object)
    return prediction

Note that we feed the return value of input_fn directly to predict_fn. This means you should invoke the SageMaker PyTorch model server with data the model can consume as-is, i.e. data that is already normalized and has batch and channel dimensions.
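The client-side preparation this implies can be sketched with NumPy. This assumes per-image zero-mean, unit-variance normalization and uses random arrays in place of MNIST images; the notebook's own utils.mnist.normalize helper may compute its statistics differently.

```python
import numpy as np

# Random stand-in for 16 grayscale 28x28 images, shaped (N, H, W)
images = np.random.rand(16, 28, 28).astype(np.float32)

# Assumed normalization: zero mean and unit variance over each image's pixels
mean = images.mean(axis=(1, 2), keepdims=True)
std = images.std(axis=(1, 2), keepdims=True)
normalized = (images - mean) / std

# Insert the channel axis so the shape becomes (N, C, H, W)
batch = np.expand_dims(normalized, axis=1)
print(batch.shape)  # (16, 1, 28, 28)

# JSON-friendly request body in the format input_fn expects
payload = {"inputs": batch.tolist()}
```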

output_fn

After invoking predict_fn, the model server invokes output_fn to post-process the prediction. The output_fn function has the following signature:

def output_fn(prediction, content_type)

The two positional arguments are:

  • prediction: the return value from predict_fn

  • content_type: the content type of the response

The return value of output_fn should be the prediction serialized to content_type; the example below returns a JSON string.

In this example, the output_fn function looks like:

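import json

def output_fn(predictions, content_type):
    # Only JSON responses are supported
    assert content_type == "application/json"
    # Move the tensor to CPU and convert it to nested Python lists for JSON encoding
    res = predictions.cpu().numpy().tolist()
    return json.dumps(res)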

After inference, output_fn uses content_type to encode the prediction into the content type of the response. In this example, the function requires the caller of the service to accept a JSON string (application/json).
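The round trip between output_fn and the client can be illustrated without a running endpoint. In this sketch the prediction array is made-up data with 10 scores per image, json.dumps plays the role of output_fn's serialization, and json.loads stands in for the client-side JSON deserializer; taking the argmax of each row recovers the predicted digit.

```python
import json

import numpy as np

# Made-up model output: one row of 10 class scores per image (batch of 2)
prediction = np.array(
    [
        [0.1, 2.3, -0.5, 0.0, 0.2, 0.1, 0.0, 0.3, 0.0, 0.1],
        [0.0, 0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 3.1, 0.0, 0.4],
    ],
    dtype=np.float32,
)

# Server side: serialize to a JSON string, as output_fn does after .cpu().numpy()
body = json.dumps(prediction.tolist())

# Client side: decode the JSON, then pick the highest-scoring class per row
decoded = json.loads(body)
digits = np.argmax(np.array(decoded, dtype=np.float32), axis=1).tolist()
print(digits)  # [1, 7]
```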

For more information on handler functions, see the SageMaker Python SDK documentation.

Execute the inference container

Once the PyTorchModel object is created, we can call its deploy method to run the container for the hosting service. Some common parameters of the deploy method are:

  • initial_instance_count: the number of SageMaker instances to be used to run the hosting service.

  • instance_type: the type of SageMaker instance that runs the hosting service. Set it to "local" to run the hosting service on the SageMaker instance running this notebook. Local mode is typically used for debugging.

  • serializer: an object used to serialize (encode) the request data.

  • deserializer: an object used to deserialize (decode) the response data.

Commonly used serializers and deserializers are implemented in the sagemaker.serializers and sagemaker.deserializers submodules of the SageMaker Python SDK.

Since input_fn requires JSON-encoded requests, we use a JSON serializer to encode the request data into a JSON string. Since output_fn returns a JSON string, we use a JSON deserializer to parse the response back into a Python list that holds one prediction vector per input image; the predicted hand-written digit is the index of the largest value in each vector.

Note: local mode is not supported in SageMaker Studio.

[ ]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Set local_mode to True to run the endpoint in a local Docker container
# instead of on a remote SageMaker instance

local_mode = False

if local_mode:
    instance_type = "local"
else:
    instance_type = "ml.c4.xlarge"

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

The predictor returned above can be used to send prediction requests to the SageMaker endpoint. For more information, see the API reference for the SageMaker Predictor.

Now, let’s test the endpoint with some dummy data.

[ ]:
import random
import numpy as np

dummy_data = {"inputs": np.random.rand(16, 1, 28, 28).tolist()}

In input_fn, we declared that the parsed data is a Python dictionary with a key inputs, whose value is fed directly to the model. Hence dummy_data holds a batch of 16 single-channel 28x28 images, a nested list of shape (16, 1, 28, 28).

[ ]:
res = predictor.predict(dummy_data)
[ ]:
print("Predictions:", res)

If the input data is not structured like dummy_data, the endpoint will raise an exception, because input_fn is strict about the payload format. Let’s test the following example.

[ ]:
dummy_data = [random.random() for _ in range(784)]

When this dummy_data is parsed in input_fn, it has no inputs field, so input_fn raises an error and the request fails.

[ ]:
# Uncomment the following line to run inference on incorrectly formatted input data
# res = predictor.predict(dummy_data)
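One way to catch malformed payloads before they reach the endpoint is a client-side check that mirrors what input_fn and the model expect. The helper check_inputs below is hypothetical and not part of the notebook's code; its default image shape (1, 28, 28) is taken from the dummy data above.

```python
import numpy as np


def check_inputs(payload, image_shape=(1, 28, 28)):
    """Hypothetical client-side check mirroring what input_fn and the model expect."""
    # input_fn indexes the decoded JSON with "inputs", so it must be a dict with that key
    if not isinstance(payload, dict) or "inputs" not in payload:
        raise ValueError("payload must be a dict with an 'inputs' key")
    # The model expects a batch shaped (N, C, H, W)
    shape = np.asarray(payload["inputs"], dtype=np.float32).shape
    if shape[1:] != tuple(image_shape):
        dims = ", ".join(str(d) for d in image_shape)
        raise ValueError(f"expected shape (N, {dims}), got {shape}")
    return shape


good = {"inputs": np.random.rand(4, 1, 28, 28).tolist()}
print(check_inputs(good))  # (4, 1, 28, 28)

bad = [0.5] * 784  # flat list with no "inputs" key, like the malformed dummy_data
try:
    check_inputs(bad)
except ValueError as err:
    print("Rejected:", err)
```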

Now, let’s test the endpoint with real MNIST test data. We use helper functions defined in utils.mnist to download the MNIST dataset and normalize the input data.

[ ]:
from utils.mnist import mnist_to_numpy, normalize
import random
import matplotlib.pyplot as plt

%matplotlib inline

data_dir = "/tmp/data"
X, Y = mnist_to_numpy(data_dir, train=False)

# randomly sample 16 images to inspect
mask = random.sample(range(X.shape[0]), 16)
samples = X[mask]
labels = Y[mask]
# plot the images
fig, axs = plt.subplots(nrows=1, ncols=16, figsize=(16, 1))

for i, splt in enumerate(axs):
    splt.imshow(samples[i])
[ ]:
print(samples.shape, samples.dtype)

Before we invoke the SageMaker PyTorch model server with samples, we need to do some pre-processing:

  • convert the data type to 32-bit floating point

  • normalize each channel (MNIST has only one channel)

  • add a channel dimension

[ ]:
samples = normalize(samples.astype(np.float32), axis=(1, 2))

res = predictor.predict({"inputs": np.expand_dims(samples, axis=1).tolist()})

The response is a list containing one prediction vector (one score per digit class) for each sample.

[ ]:
predictions = np.argmax(np.array(res, dtype=np.float32), axis=1).tolist()
print("Predicted digits: ", predictions)

Test and debug the entry point before deployment

When deploying a model to a SageMaker endpoint, it is good practice to test the entry point first. The following snippet shows how you can test and debug the model_fn, input_fn, predict_fn and output_fn functions you implemented in the entry point for the inference image.

[ ]:
!pygmentize code/test_inference.py

The test function simulates how the inference container works. It pulls the model artifact and loads the model into memory by calling model_fn and passing model_dir to it. When it receives a request, it calls input_fn, predict_fn and output_fn in sequence.

Implementing such a test function helps you debug the entry point before putting it into production. If the test runs correctly, you can be reasonably confident that the endpoint will work as expected, as long as the incoming data and its content type are what the handlers expect.
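The handler chain described above can be sketched without a container or PyTorch. This is not the contents of code/test_inference.py: the four handlers below are hypothetical stubs (the "model" just sums each sample), chosen only to show the order in which the model server calls them.

```python
import json


# Hypothetical stub handlers; the real ones in inference.py use torch
def model_fn(model_dir):
    # Stub "model": maps each sample to a one-element score vector holding its sum
    return lambda batch: [[sum(sample)] for sample in batch]


def input_fn(request_body, request_content_type):
    assert request_content_type == "application/json"
    return json.loads(request_body)["inputs"]


def predict_fn(input_object, model):
    return model(input_object)


def output_fn(prediction, content_type):
    assert content_type == "application/json"
    return json.dumps(prediction)


def handle_request(model, request_body, content_type="application/json"):
    """Run one request through input_fn, predict_fn and output_fn, in that order."""
    data = input_fn(request_body, content_type)
    prediction = predict_fn(data, model)
    return output_fn(prediction, content_type)


# Load once, then serve a request with the already-loaded model
model = model_fn("/tmp/model")  # hypothetical path; the stub ignores it
print(handle_request(model, json.dumps({"inputs": [[1, 2], [3, 4]]})))  # [[3], [7]]
```

Loading the model outside the per-request function mirrors how the serving container calls model_fn once at startup rather than on every request.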

(Optional) Clean up

If you do not plan to use the endpoint, you should delete it to free up compute resources. If you use local mode, you need to manually delete the Docker container bound to port 8080 (the port that listens for incoming requests).

[ ]:
import os

if not local_mode:
    predictor.delete_endpoint()
else:
    os.system("docker container ls | grep 8080 | awk '{print $1}' | xargs docker container rm -f")

Notebook CI Test Results

This notebook was tested in the following regions: us-west-2, us-east-1, us-east-2, us-west-1, ca-central-1, sa-east-1, eu-west-1, eu-west-2, eu-west-3, eu-central-1, eu-north-1, ap-southeast-1, ap-southeast-2, ap-northeast-1, ap-northeast-2, and ap-south-1.