Deploy a Trained TensorFlow V2 Model




In this notebook, we walk through the process of deploying a trained model to a SageMaker endpoint. If you recently ran the training notebook and saved the model artifact location with the %store magic, the model_data can be restored. Otherwise, we retrieve the model artifact from a public S3 bucket.

[ ]:
# setup

import os
import json

import sagemaker
from sagemaker.tensorflow import TensorFlowModel
from sagemaker import get_execution_role, Session
import boto3


sess = Session()
role = get_execution_role()

%store -r tf_mnist_model_data


try:
    tf_mnist_model_data
except NameError:
    # if the stored variable is not found, copy a pretrained model
    # from a public bucket to your default bucket
    s3 = boto3.client("s3")
    bucket = f"sagemaker-example-files-prod-{sess.boto_region_name}"
    key = "datasets/image/MNIST/model/tensorflow-training-2020-11-20-23-57-13-077/model.tar.gz"
    s3.download_file(bucket, key, "model.tar.gz")
    tf_mnist_model_data = sess.upload_data(
        path="model.tar.gz", bucket=sess.default_bucket(), key_prefix="model/tensorflow"
    )
    os.remove("model.tar.gz")
[ ]:
print(tf_mnist_model_data)

TensorFlow Model Object

The TensorFlowModel class allows you to define an environment for making inference using your model artifact. Like the TensorFlow estimator class we used in the training notebook, it is a high-level API used to set up a Docker image for your model hosting service.

Once it is properly configured, it can be used to create a SageMaker endpoint on an EC2 instance. The SageMaker endpoint is a containerized environment that uses your trained model to perform inference on incoming data via RESTful API calls.

Some common parameters used to initiate the TensorFlowModel class are:

  • role: an IAM role to make AWS service requests

  • model_data: the S3 URI of the compressed model artifact. It can be a path to a local file if the endpoint is to be deployed on the SageMaker instance you are using to run this notebook (local mode)

  • framework_version: the version of the TensorFlow package to be used

  • py_version: the Python version to be used

[ ]:
model = TensorFlowModel(
    role=role,
    model_data=tf_mnist_model_data,
    framework_version="2.3.1",
)

Execute the Inference Container

Once the TensorFlowModel class is initialized, we can call its deploy method to run the container for the hosting service. Some common parameters needed to call the deploy method are:

  • initial_instance_count: the number of SageMaker instances to be used to run the hosting service.

  • instance_type: the type of SageMaker instance to run the hosting service. Set it to local if you want to run the hosting service on the local SageMaker instance. Local mode is typically used for debugging.

Note: local mode is not supported in SageMaker Studio

[ ]:
# set local_mode to True if you want to run the hosting
# service on this SageMaker instance instead of a remote one

local_mode = False

if local_mode:
    instance_type = "local"
else:
    instance_type = "ml.c4.xlarge"

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
)
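
Deployment of a remote endpoint can take several minutes. To check progress programmatically, here is a minimal sketch using the boto3 SageMaker client (assuming remote mode, so that predictor.endpoint_name refers to a hosted endpoint):

[ ]:
import boto3

# query the status of the endpoint created by model.deploy
# (remote mode only; local mode does not create a hosted endpoint)
if not local_mode:
    sm_client = boto3.client("sagemaker")
    desc = sm_client.describe_endpoint(EndpointName=predictor.endpoint_name)
    print(desc["EndpointStatus"])  # e.g. Creating, InService, Failed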

Making Predictions Against a SageMaker Endpoint

Once you have the Predictor instance returned by model.deploy(...), you can send prediction requests to your endpoint. In this case, the model accepts normalized batch images in depth-minor (channels-last) convention.

[ ]:
# use some dummy inputs
import numpy as np

dummy_inputs = {"instances": np.random.rand(4, 28, 28, 1)}

res = predictor.predict(dummy_inputs)
print(res)

The formats of the input and output data correspond directly to the request and response formats of the Predict method in the TensorFlow Serving REST API. For example, the key of the array passed to the model in dummy_inputs needs to be called instances. Moreover, the input data needs to have a batch dimension.
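
Because the endpoint speaks the TensorFlow Serving REST format, you can also invoke it with the low-level SageMaker runtime client instead of the Predictor object. A minimal sketch, assuming remote mode and reusing the endpoint name from predictor.endpoint_name:

[ ]:
import json

runtime = boto3.client("sagemaker-runtime")

# raw JSON payload in TensorFlow Serving "instances" format,
# including the batch dimension
payload = {"instances": np.random.rand(1, 28, 28, 1).tolist()}

response = runtime.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    ContentType="application/json",
    Body=json.dumps(payload),
)
print(json.loads(response["Body"].read()))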

[ ]:
# Uncomment the following lines to see an example that cannot be processed
# by the endpoint (the input below is missing the batch dimension)

# dummy_data = {
#    'instances': np.random.rand(28, 28, 1).tolist()
# }
# print(predictor.predict(dummy_data))

Now, let’s use real MNIST test data to test the endpoint. We use helper functions defined in utils.mnist to download the MNIST data set and normalize the input data.
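
For reference, here is a rough sketch of what the normalize helper is assumed to do (the actual implementation lives in utils.mnist): standardize each image to zero mean and unit variance over the pixel axes.

[ ]:
import numpy as np

# a hypothetical stand-in for utils.mnist.normalize, shown only to
# illustrate the assumed behavior: per-image standardization
def normalize_sketch(x, axis):
    eps = np.finfo(np.float64).eps  # avoid division by zero
    mean = x.mean(axis=axis, keepdims=True)
    std = x.std(axis=axis, keepdims=True) + eps
    return (x - mean) / std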

[ ]:
from utils.mnist import mnist_to_numpy, normalize
import random
import matplotlib.pyplot as plt

%matplotlib inline

data_dir = "/tmp/data"
X, _ = mnist_to_numpy(data_dir, train=False)

# randomly sample 16 images to inspect
mask = random.sample(range(X.shape[0]), 16)
samples = X[mask]

# plot the images
fig, axs = plt.subplots(nrows=1, ncols=16, figsize=(16, 1))

for i, splt in enumerate(axs):
    splt.imshow(samples[i])

Since the model accepts normalized input, you will need to normalize the samples before sending them to the endpoint.

[ ]:
samples = normalize(samples, axis=(1, 2))
predictions = predictor.predict(np.expand_dims(samples, 3))["predictions"]  # add channel dim

# pick the class with the highest predicted probability
predictions = np.array(predictions, dtype=np.float32)
predictions = np.argmax(predictions, axis=1)
[ ]:
print("Predictions: ", predictions.tolist())

(Optional) Clean up

If you do not plan to use the endpoint, you should delete it to free up compute resources. If you used local mode, you will need to manually delete the Docker container bound to port 8080 (the port that listens for incoming requests).

[ ]:
import os

if not local_mode:
    predictor.delete_endpoint()
else:
    os.system("docker container ls | grep 8080 | awk '{print $1}' | xargs docker container rm -f")
