TensorFlow BYOM: Train with Custom Training Script, Compile with Neo, and Deploy on SageMaker




In this notebook you will compile a trained model using Amazon SageMaker Neo. The notebook is functionally similar to the TensorFlow MNIST training and serving notebook: you will complete the same classification task, but this time you will compile the trained model with the SageMaker Neo API on the backend. SageMaker Neo optimizes your model to run on your choice of hardware. At the end of this notebook you will set up a real-time hosting endpoint in SageMaker for your Neo-compiled model using the TensorFlow Model Server. Note: This notebook requires SageMaker Python SDK v2.x.x or above.

Set up the environment

[ ]:
import os
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

role = get_execution_role()

Download the MNIST dataset

[ ]:
import utils
from tensorflow.contrib.learn.python.learn.datasets import mnist
import tensorflow as tf

data_sets = mnist.read_data_sets("data", dtype=tf.uint8, reshape=False, validation_size=5000)

utils.convert_to(data_sets.train, "train", "data")
utils.convert_to(data_sets.validation, "validation", "data")
utils.convert_to(data_sets.test, "test", "data")

Upload the data

We use the sagemaker.Session.upload_data function to upload our datasets to an S3 location. The return value inputs identifies the location – we will use this later when we start the training job.

[ ]:
inputs = sagemaker_session.upload_data(path="data", key_prefix="data/DEMO-mnist")
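
As a rough illustration of what `inputs` holds: when `path` is a directory, `upload_data` returns a URI of the form `s3://{bucket}/{key_prefix}`, and when `path` is a single file the file name is appended. The sketch below only mimics that string format. The helper name `expected_upload_uri` and the bucket name are hypothetical, and nothing is uploaded.

```python
def expected_upload_uri(bucket, key_prefix, file_name=None):
    """Mimic the S3 URI format returned by upload_data (illustration only)."""
    uri = "s3://{}/{}".format(bucket, key_prefix)
    if file_name is not None:
        # A single-file upload appends the original file name.
        uri = "{}/{}".format(uri, file_name)
    return uri


# Hypothetical bucket name; the real one comes from the SageMaker session.
example_bucket = "sagemaker-us-west-2-111122223333"
example_inputs = expected_upload_uri(example_bucket, "data/DEMO-mnist")
print(example_inputs)
```

The printed value has the same shape as the string that is later passed to `fit`.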

Construct a script for distributed training

Here is the full code for the network model:

[ ]:
!cat 'mnist.py'

The script here is an adaptation of the TensorFlow MNIST example. It provides a model_fn(features, labels, mode), which is used for training, evaluation, and inference. See the TensorFlow MNIST training and serving notebook for more details about the training script.

Create a training job using the sagemaker.TensorFlow estimator

[ ]:
from sagemaker.tensorflow import TensorFlow

mnist_estimator = TensorFlow(
    entry_point="mnist.py",
    role=role,
    framework_version="1.15.3",
    py_version="py3",
    training_steps=1000,
    evaluation_steps=100,
    instance_count=2,
    instance_type="ml.c4.xlarge",
)

mnist_estimator.fit(inputs)

The ``fit`` method creates a training job on two ml.c4.xlarge instances. The logs above show the instances training, evaluating, and incrementing the number of training steps.

At the end of training, the training job generates a saved model for TensorFlow Serving.

Deploy the trained model to prepare for predictions (the old way)

The deploy() method creates an endpoint which serves prediction requests in real-time.

[ ]:
mnist_predictor = mnist_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")

Invoking the endpoint

[ ]:
import numpy as np
import json
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

for i in range(10):
    data = mnist.test.images[i].tolist()
    # Follow https://www.tensorflow.org/tfx/serving/api_rest guide to format input to the model server
    predict_response = mnist_predictor.predict({"instances": np.asarray(data).tolist()})

    print("========================================")
    label = np.argmax(mnist.test.labels[i])
    print("label is {}".format(label))
    prediction = np.argmax(predict_response["predictions"])
    print("prediction is {}".format(prediction))
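
The request and response handling in the loop above can be sketched without an endpoint. This is a minimal stand-in, assuming the same payload layout the loop sends (`{"instances": [...]}`, following the TensorFlow Serving REST guide). The helper names and the mock response are hypothetical, and the real response shape depends on the model's serving signature.

```python
import json


def build_request(pixels):
    """Wrap a flattened image in the same 'instances' payload the loop above sends."""
    return {"instances": list(pixels)}


def argmax(values):
    """Return the index of the largest value (plain-Python stand-in for np.argmax)."""
    return max(range(len(values)), key=lambda i: values[i])


# A blank 28x28 image flattened to 784 floats (placeholder input).
request = build_request([0.0] * 784)
body = json.dumps(request)  # JSON text, as a REST client would send it

# Made-up scores for illustration; real scores come from the endpoint.
mock_response = {"predictions": [0.01, 0.02, 0.05, 0.8, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]}
predicted_digit = argmax(mock_response["predictions"])
print("prediction is {}".format(predicted_digit))
```

Taking the index of the highest score is what turns the ten class scores into a digit label.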

Deleting the endpoint

[ ]:
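# SageMaker Python SDK v2 names this attribute endpoint_name (endpoint is the deprecated v1 name).
sagemaker.Session().delete_endpoint(mnist_predictor.endpoint_name)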

Deploy the trained model using Neo

Now the model is ready to be compiled by Neo and optimized for our hardware of choice. We use the TensorFlow estimator's compile_model method to do this. For this example, our target hardware is 'ml_c5'. You can change this to another supported target if you prefer.

Compiling the model

The input_shape defines the model's input tensor, and output_path is where the compiled model will be stored in S3. Important: If the following command results in a permission error, scroll up and locate the value of the execution role returned by ``get_execution_role()``. The role must have access to the S3 bucket specified in ``output_path``.

[ ]:
output_path = "/".join(mnist_estimator.output_path.split("/")[:-1])
optimized_estimator = mnist_estimator.compile_model(
    target_instance_family="ml_c5",
    input_shape={"data": [1, 784]},  # Batch size 1, 784 flattened pixels (28*28 image).
    output_path=output_path,
    framework="tensorflow",
    framework_version="1.15.3",
)
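
The `output_path` expression in the cell above is plain string manipulation: it drops everything after the final `/` of the estimator's output location. The sketch below applies the same expression to two hypothetical URIs. The bucket name is made up, and the actual value of `mnist_estimator.output_path` depends on your session.

```python
def strip_last_segment(s3_uri):
    """Drop everything after the final '/' (same expression as in the cell above)."""
    return "/".join(s3_uri.split("/")[:-1])


# Hypothetical URIs for illustration only.
with_trailing_slash = strip_last_segment("s3://example-bucket/")
with_path_component = strip_last_segment("s3://example-bucket/output")
print(with_trailing_slash)
print(with_path_component)

# The flattened input length used in input_shape matches a 28x28 image.
flattened_pixels = 28 * 28
print(flattened_pixels)
```

Both calls yield the bare bucket URI, so a trailing slash and a final path component are treated the same way.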

Set image URI (temporarily required)

Image URI: aws_account_id.dkr.ecr.aws_region.amazonaws.com/sagemaker-inference-tensorflow:1.15.3-instance_type-py3

Refer to the table at the bottom here to get the AWS account ID and region mapping.

[ ]:
optimized_estimator.image_uri = (
    "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-inference-tensorflow:1.15.3-cpu-py3"
)
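
The URI assigned above follows the template given earlier in this section. As a small sketch, the template can be filled in with a helper. The function name `neo_tf_image_uri` and its parameters are hypothetical. The account ID differs per region and must be looked up, so the values below simply reproduce the us-west-2 example.

```python
def neo_tf_image_uri(account_id, region, processor="cpu", version="1.15.3"):
    """Fill in the image URI template from this section (illustration only)."""
    return (
        "{account}.dkr.ecr.{region}.amazonaws.com/"
        "sagemaker-inference-tensorflow:{version}-{processor}-py3"
    ).format(account=account_id, region=region, version=version, processor=processor)


# Reproduces the us-west-2 CPU image URI used in the cell above.
example_uri = neo_tf_image_uri("301217895009", "us-west-2")
print(example_uri)
```

Building the string from its parts makes it easier to see which pieces change when you deploy in another region or on another processor type.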

Deploying the compiled model

[ ]:
optimized_predictor = optimized_estimator.deploy(
    initial_instance_count=1, instance_type="ml.c5.xlarge"
)

Invoking the endpoint

[ ]:
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

for i in range(10):
    data = mnist.test.images[i].tolist()
    # Follow https://www.tensorflow.org/tfx/serving/api_rest guide to format input to the model server
    predict_response = optimized_predictor.predict({"instances": np.asarray(data).tolist()})

    print("========================================")
    label = np.argmax(mnist.test.labels[i])
    print("label is {}".format(label))
    prediction = np.argmax(predict_response["predictions"])
    print("prediction is {}".format(prediction))

Deleting the endpoint

[ ]:
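# SageMaker Python SDK v2 names this attribute endpoint_name (endpoint is the deprecated v1 name).
sagemaker.Session().delete_endpoint(optimized_predictor.endpoint_name)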
