Using Amazon Elastic Inference with MXNet on Amazon SageMaker
This notebook demonstrates how to enable and use Amazon Elastic Inference with the prebuilt Amazon SageMaker MXNet images.
Amazon Elastic Inference (EI) is a resource you can attach to your Amazon EC2 instances to accelerate your deep learning (DL) inference workloads. EI allows you to add inference acceleration to an Amazon SageMaker hosted endpoint or Jupyter notebook for a fraction of the cost of using a full GPU instance. For more information, see https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
This notebook is an adaptation of the SageMaker MXNet MNIST notebook, with modifications showing the changes needed to enable and use EI with MXNet on SageMaker.
If you are familiar with SageMaker and already have a trained model, skip ahead to the “Creating an inference endpoint and attaching an EI accelerator” section.
For this example, we use the SageMaker Python SDK, which makes it easy to train and deploy MXNet models. For our MXNet model, we train a simple neural network using the Apache MXNet Module API and the MNIST dataset.
MNIST dataset
The MNIST dataset is widely used for handwritten digit classification, and consists of 70,000 labeled 28x28 pixel grayscale images of hand-written digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). The task at hand is to train a model using the 60,000 training images and then test its classification accuracy on the 10,000 test images.
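If you want a quick look at the data locally, MXNet’s test utilities can download it; a small sketch (this local copy is separate from the S3 data the training job uses below):
import mxnet as mx

# Download MNIST locally and confirm the shapes described above.
mnist = mx.test_utils.get_mnist()
print(mnist["train_data"].shape)  # (60000, 1, 28, 28)
print(mnist["test_data"].shape)   # (10000, 1, 28, 28)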
Setup
Let’s start by creating a SageMaker session and specifying the IAM role ARN used to give training and hosting access to your data. See the documentation for how to create these. Note: if more than one role is required for notebook instances, training, and/or hosting, please replace sagemaker.get_execution_role() with the appropriate full IAM role ARN string(s).
[ ]:
import sagemaker
role = sagemaker.get_execution_role()
The training script
The mnist_ei.py script provides all the code we need to train and host a SageMaker model. The script also checkpoints the model at the end of every epoch, saving the model graph, parameters, and optimizer state in the folder /opt/ml/checkpoints. If that folder does not exist, checkpointing is skipped. The script we use is adapted from the Apache MXNet MNIST tutorial.
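A rough sketch of that per-epoch checkpointing pattern (hypothetical; mod stands in for the training Module that mnist_ei.py builds, and the real logic lives in the script itself):
import os
import mxnet as mx

CHECKPOINT_DIR = "/opt/ml/checkpoints"

# Checkpoint at the end of every epoch, but only if the directory exists.
epoch_end_callback = None
if os.path.exists(CHECKPOINT_DIR):
    # Saves model-symbol.json plus per-epoch .params (and optimizer states).
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, os.path.join(CHECKPOINT_DIR, "model"), save_optimizer_states=True
    )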
[ ]:
!pygmentize mnist_ei.py
SageMaker’s MXNet estimator class
The SageMaker MXNet estimator allows us to run single-machine or distributed training in SageMaker, using CPU- or GPU-based instances.
When we create the estimator, we pass in the filename of our training script, the name of our IAM execution role, and the S3 locations we defined in the setup section. We also provide a few other parameters. instance_count and instance_type determine the number and type of SageMaker instances that are used for the training job. The hyperparameters parameter is a dict of values that are passed to your training script. You can see how to access these values in the mnist_ei.py script above.
For this example, we use one ml.m4.xlarge instance for our training job.
[ ]:
from sagemaker.mxnet import MXNet
mnist_estimator = MXNet(
    entry_point="mnist_ei.py",
    role=role,
    instance_count=1,
    instance_type="ml.m4.xlarge",
    framework_version="1.7.0",
    py_version="py3",
    hyperparameters={"learning-rate": 0.1},
)
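The hyperparameters dict above reaches the training script as command-line arguments, so mnist_ei.py can read them with argparse; a minimal sketch (the script’s actual parsing may differ):
import argparse

# SageMaker turns hyperparameters={"learning-rate": 0.1} into
# the command-line argument --learning-rate 0.1.
parser = argparse.ArgumentParser()
parser.add_argument("--learning-rate", type=float, default=0.1)
args, _ = parser.parse_known_args()
print(args.learning_rate)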
Running the training job
After we’ve constructed our MXNet object, we can fit it using data stored in S3. In the next cell we run SageMaker training on two input channels: train and test.
During training, SageMaker makes the data stored in S3 available in the local filesystem where the MNIST script is running. The mnist_ei.py script simply loads the train and test data from disk.
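SageMaker exposes each channel to the script as a local directory, advertised through environment variables; a minimal sketch of how a training script can locate them (mnist_ei.py’s actual argument handling may differ):
import os

# Each channel appears under /opt/ml/input/data/<channel> and is
# advertised via an SM_CHANNEL_<NAME> environment variable.
train_dir = os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train")
test_dir = os.environ.get("SM_CHANNEL_TEST", "/opt/ml/input/data/test")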
[ ]:
%%time
import boto3
region = boto3.Session().region_name
train_data_location = "s3://sagemaker-sample-data-{}/mxnet/mnist/train".format(region)
test_data_location = "s3://sagemaker-sample-data-{}/mxnet/mnist/test".format(region)
mnist_estimator.fit({"train": train_data_location, "test": test_data_location})
Creating an inference endpoint and attaching an EI accelerator
After training, we call the deploy method of the MXNet estimator object to build and deploy an MXNetPredictor. This creates a SageMaker endpoint, which is a hosted prediction service that we can use to perform inference.
We pass the following arguments to the deploy method:
- initial_instance_count - how many instances to back the endpoint.
- instance_type - which EC2 instance type to use for the endpoint. For information on supported instance types, see the AWS documentation.
- accelerator_type - which EI accelerator type to attach to each of our instances. The supported accelerator types can be found in the AWS documentation.
How our models are loaded
You should provide a custom model_fn to use the EI accelerator attached to your endpoint. An example model_fn implementation is as follows:
import os
import mxnet as mx

def model_fn(model_dir):
    # Load the checkpointed symbol and parameters saved during training.
    ctx = mx.cpu()
    sym, args, aux = mx.model.load_checkpoint(os.path.join(model_dir, 'model'), 0)
    # Partition the graph so supported operators run on the EI accelerator.
    sym = sym.optimize_for('EIA')
    # data_names and data_shapes must be defined to match the model's inputs;
    # see mnist_ei.py above for the values used in this example.
    mod = mx.mod.Module(symbol=sym, context=ctx, data_names=data_names, label_names=None)
    mod.bind(for_training=False, data_shapes=data_shapes)
    mod.set_params(args, aux, allow_missing=True)
    return mod
Check mnist_ei.py above for the specific implementation of model_fn() in this notebook example.
In EI MXNet 1.5.1 and earlier, the predefined SageMaker MXNet containers have a default model_fn, which determines how your model is loaded. The default model_fn loads an MXNet Module object with a context based on the instance type of the endpoint.
If an EI accelerator is attached to your endpoint and a custom model_fn isn’t provided, then the default model_fn will load the MXNet Module object. This default model_fn works with the default save function. If a custom save function was defined, then you may need to write a custom model_fn function. For more information on model_fn, see the documentation for using MXNet with SageMaker.
For an example of how to load and serve an MXNet Module object explicitly, see our predefined default model_fn for MXNet: https://github.com/aws/sagemaker-mxnet-serving-container/blob/master/src/sagemaker_mxnet_serving_container/default_inference_handler.py#L36
[ ]:
%%time
predictor = mnist_estimator.deploy(
    initial_instance_count=1, instance_type="ml.m4.xlarge", accelerator_type="ml.eia1.medium"
)
The request handling behavior of the endpoint is determined by the mnist_ei.py script. In this case, the script doesn’t include any request handling functions, so the endpoint uses the default handlers provided by SageMaker. These default handlers allow us to perform inference on input data encoded as a multi-dimensional JSON array.
Making an inference request
Now that our endpoint is deployed and we have a predictor object, we can use it to classify handwritten digits.
[ ]:
data = [
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
]
Now we can use the predictor object to classify the handwritten digit:
[ ]:
%%time
response = predictor.predict(data)
print("Raw prediction result:")
print(response)
# Pair each digit label (0-9) with its predicted probability.
labeled_predictions = list(zip(range(10), response[0]))
print("Labeled predictions: ")
print(labeled_predictions)
# Sort by descending probability so the most likely digit comes first.
labeled_predictions.sort(key=lambda label_and_prob: 1.0 - label_and_prob[1])
print("Most likely answer: {}".format(labeled_predictions[0]))
Delete the endpoint
After you have finished with this example, remember to delete the prediction endpoint.
[ ]:
print("Endpoint name: " + predictor.endpoint_name)
[ ]:
import sagemaker
predictor.delete_endpoint()