Amazon SageMaker Multi-Model Endpoints using PyTorch


This notebook is CI-tested in us-west-2 and in several other regions, which are listed at the end of the notebook.


This notebook works well with the SageMaker Studio kernel "Python 3 (Data Science)" or the SageMaker Notebook Instance kernel "conda_python3".

With Amazon SageMaker multi-model endpoints, customers can create an endpoint that seamlessly hosts up to thousands of models. These endpoints are well suited to use cases where any one of many models, which can be served from a common inference container, needs to be callable on-demand and where it is acceptable for infrequently invoked models to incur some additional latency. For applications which require consistently low inference latency, a traditional endpoint is still the best choice.

Where variable latency is tolerable and cost optimization matters more, customers may also use multi-model endpoints (MMEs) for A/B/n testing, in place of the more typical strategy based on production variants.

To demonstrate how multi-model endpoints can be created and used, this notebook provides an example using models trained with the SageMaker PyTorch framework container. We’ll take an A/B scenario for simplicity, training and deploying just two models to our endpoint.
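To build intuition for the latency trade-off described above, here is a toy simulation (plain Python, no AWS calls) of an endpoint that loads models on demand and unloads one when it runs out of room. SageMaker manages loading and unloading itself; the fixed capacity and the least-recently-used (LRU) eviction policy here are simplifying assumptions, and ToyModelCache is a hypothetical name.

```python
from collections import OrderedDict


class ToyModelCache:
    """Toy model of on-demand loading; NOT SageMaker's actual implementation.

    Assumes a fixed number of model "slots" per instance and
    least-recently-used (LRU) eviction when all slots are full.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.loaded = OrderedDict()  # model name -> True, ordered from least to most recently used
        self.cold_starts = 0

    def invoke(self, model_name: str) -> str:
        if model_name in self.loaded:
            # Warm hit: the model is already in memory, so just mark it as recently used.
            self.loaded.move_to_end(model_name)
            return "warm"
        # Cold start: the model would first have to be fetched from S3 and loaded.
        self.cold_starts += 1
        if len(self.loaded) >= self.capacity:
            self.loaded.popitem(last=False)  # evict the least recently used model
        self.loaded[model_name] = True
        return "cold"


cache = ToyModelCache(capacity=2)
results = [cache.invoke(m) for m in ["A", "B", "A", "C", "B"]]
print(results)  # ['cold', 'cold', 'warm', 'cold', 'cold']
```

Only the third request finds its model already loaded; every other request pays a cold-start cost in this toy, which is why infrequently invoked models can see extra latency.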

For other MME use cases, you can also refer to:

Contents

  1. The example use case: MNIST

  2. Train multiple models

  3. Check single-model deployment

  4. Create the Multi-Model Endpoint with the SageMaker SDK

  5. Deploy the Multi-Model Endpoint

  6. Dynamically deploying models to the endpoint

  7. Get predictions from the endpoint

  8. Updating a model

  9. Clean up

Before these sections, we'll load the libraries needed for this notebook and define some configuration you can edit, such as where the data will be saved in Amazon S3:

[ ]:
# Python Built-Ins:
from datetime import datetime
import os
import json
import logging
from tempfile import TemporaryFile
import time

# External Dependencies:
import boto3
from botocore.exceptions import ClientError
import numpy as np
import sagemaker
from sagemaker.multidatamodel import MultiDataModel
from sagemaker.pytorch import PyTorch as PyTorchEstimator, PyTorchModel

smsess = sagemaker.Session()
region = smsess.boto_region_name
role = sagemaker.get_execution_role()

# Configuration:
bucket_name = smsess.default_bucket()
prefix = "mnist/"
output_path = f"s3://{bucket_name}/{prefix[:-1]}"

The example use case: MNIST

MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images.
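The MNIST files used here (for example train-images-idx3-ubyte.gz) are in the simple IDX binary format: a 4-byte magic number whose third byte gives the data type (0x08 for unsigned bytes) and whose fourth byte gives the number of dimensions, followed by one big-endian 32-bit size per dimension and then the raw data. As a minimal sketch, assuming you have a local copy of one of the files, it could be parsed with NumPy as below. read_idx_gz is a hypothetical helper; the training script used in this notebook may read the data differently.

```python
import gzip
import struct

import numpy as np


def read_idx_gz(path_or_bytes):
    """Parse a gzipped IDX file (hypothetical helper) into a uint8 NumPy array.

    Accepts either a file path or the raw gzipped bytes.
    """
    if isinstance(path_or_bytes, (bytes, bytearray)):
        raw = gzip.decompress(path_or_bytes)
    else:
        with gzip.open(path_or_bytes, "rb") as f:
            raw = f.read()
    # Magic number: two zero bytes, a data-type code, and the number of dimensions.
    zero, dtype_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError("not an unsigned-byte IDX file")
    # Each dimension size is a big-endian unsigned 32-bit integer.
    dims = struct.unpack(f">{ndim}I", raw[4 : 4 + 4 * ndim])
    # The remaining bytes are the data, stored row-major.
    data = np.frombuffer(raw, dtype=np.uint8, offset=4 + 4 * ndim)
    return data.reshape(dims)
```

Applied to train-images-idx3-ubyte.gz this should yield an array of shape (60000, 28, 28), and applied to train-labels-idx1-ubyte.gz an array of 60,000 labels.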

In this example, we download the MNIST data from a public S3 bucket and upload it to your default SageMaker bucket as selected above.

[ ]:
def fetch_sample_data(
    to_bucket: str,
    to_prefix: str,
    from_bucket: str = f"sagemaker-example-files-prod-{region}",
    from_prefix: str = "datasets/image/MNIST",
    dataset: str = "mnist-train",
):
    DATASETS = {
        "mnist-train": ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz"],
        "mnist-test": ["t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"],
    }

    if dataset not in DATASETS:
        raise ValueError(f"dataset '{dataset}' not in known set: {set(DATASETS.keys())}")

    if len(from_prefix) and not from_prefix.endswith("/"):
        from_prefix += "/"
    if len(to_prefix) and not to_prefix.endswith("/"):
        to_prefix += "/"

    s3client = boto3.client("s3")
    for key in DATASETS[dataset]:
        # If you're in the same region as the source bucket, you could use copy_object() instead:
        with TemporaryFile() as ftmp:
            s3client.download_fileobj(from_bucket, f"{from_prefix}{key}", ftmp)
            ftmp.seek(0)
            s3client.upload_fileobj(ftmp, to_bucket, f"{to_prefix}{key}")


train_prefix = f"{prefix}data/train"
fetch_sample_data(to_bucket=bucket_name, to_prefix=train_prefix, dataset="mnist-train")
train_s3uri = f"s3://{bucket_name}/{train_prefix}"
print(f"Uploaded training data to {train_s3uri}")

test_prefix = f"{prefix}data/test"
fetch_sample_data(to_bucket=bucket_name, to_prefix=test_prefix, dataset="mnist-test")
test_s3uri = f"s3://{bucket_name}/{test_prefix}"
print(f"Uploaded test data to {test_s3uri}")
[ ]:
print("Training data:")
!aws s3 ls --recursive $train_s3uri
print("Test data:")
!aws s3 ls --recursive $test_s3uri

Train multiple models

In the following section, we'll train multiple models on the same dataset, using the SageMaker PyTorch Framework Container.

To keep the example simple, we'll create just two models, A and B, using the same code but slightly different hyperparameters (here, the weight decay).

[ ]:
def get_estimator(base_job_name, hyperparam_overrides=None):
    hyperparameters = {
        "batch-size": 128,
        "epochs": 20,
        "learning-rate": 1e-3,
        "log-interval": 100,
    }
    # Apply any per-model overrides on top of the shared defaults (None means no overrides):
    hyperparameters.update(hyperparam_overrides or {})

    return PyTorchEstimator(
        base_job_name=base_job_name,
        entry_point="train.py",
        source_dir="code",  # directory of your training script
        role=role,
        # At the time of writing, this example gives a deployment error in container v1.8.1 with
        # upgraded TorchServe: so specifically setting "1.8.0". But "1.7" and "1.6" should be fine.
        framework_version="1.8.0",
        py_version="py3",
        instance_type="ml.c4.xlarge",
        instance_count=1,
        output_path=output_path,
        hyperparameters=hyperparameters,
    )


estimatorA = get_estimator(base_job_name="mnist-a", hyperparam_overrides={"weight-decay": 1e-4})
estimatorB = get_estimator(base_job_name="mnist-b", hyperparam_overrides={"weight-decay": 1e-2})

By default, calling the SageMaker Python SDK’s Estimator.fit() method waits for the training job to complete, streaming progress information and logs to the notebook.

This is not the only supported configuration, though. For example, we can also start jobs asynchronously by setting wait=False, and later call wait() on previously started jobs (optionally streaming their logs).

The cell below starts both training jobs in parallel, streams the logs from B as it runs, and then waits for A to complete if it hasn't already.

[ ]:
estimatorA.fit({"training": train_s3uri, "testing": test_s3uri}, wait=False)
print("Started estimator A training in background (logs will not show)")

print("Training estimator B with logs:")
estimatorB.fit({"training": train_s3uri, "testing": test_s3uri})

print("\nWaiting for estimator A to complete:")
estimatorA.latest_training_job.wait(logs=False)
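The SageMaker SDK handles the background job and the waiting for us. As a plain-Python analogy of the same start-in-background, run-in-foreground, then-wait pattern (no AWS calls; fake_training_job is a hypothetical stand-in that just sleeps), consider concurrent.futures:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_training_job(name: str, seconds: float) -> str:
    """Hypothetical stand-in for a training job: it only sleeps."""
    time.sleep(seconds)
    return f"{name} done"


start = time.perf_counter()
with ThreadPoolExecutor(max_workers=1) as pool:
    # Like estimatorA.fit(..., wait=False): start A in the background and return immediately.
    future_a = pool.submit(fake_training_job, "A", 0.2)
    # Like estimatorB.fit(...): run B in the foreground, blocking until it finishes.
    result_b = fake_training_job("B", 0.1)
    # Like estimatorA.latest_training_job.wait(): block until A finishes, if it hasn't already.
    result_a = future_a.result()
elapsed = time.perf_counter() - start

print(f"{result_a}, {result_b} after {elapsed:.2f}s")
```

Because the two jobs overlap, the total wall-clock time is roughly the longer of the two durations rather than their sum.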

Check single-model deployment

Before trying to set up a multi-model deployment, it may be helpful to quickly check that a single model can be deployed and invoked as expected:

[ ]:
modelA = estimatorA.create_model(role=role, source_dir="code", entry_point="inference.py")
[ ]:
predictorA = modelA.deploy(
    initial_instance_count=1,
    instance_type="ml.c5.xlarge",
)
predictorA.serializer = sagemaker.serializers.JSONSerializer()
predictorA.deserializer = sagemaker.deserializers.JSONDeserializer()
[ ]:
def get_dummy_request():
    """Create a dummy request payload for predictor.predict(): 16 images of random pixels."""
    return {"inputs": np.random.rand(16, 1, 28, 28).tolist()}


dummy_data = get_dummy_request()

start_time = time.time()
predicted_value = predictorA.predict(dummy_data)
duration = time.time() - start_time

print(f"Model took {int(duration * 1000):,d} ms")
np.array(predicted_value)[0]
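The raw response is a list of numbers per image. As a sketch, assuming (since inference.py is not shown here) that the endpoint returns one 10-element score vector per input image, such as log-probabilities, the scores could be turned into predicted digits with an argmax. to_digits is a hypothetical helper.

```python
import numpy as np


def to_digits(predicted_value):
    """Convert a batch of per-class scores into predicted digit labels.

    Assumes one 10-element score vector per input image (not confirmed by this notebook).
    """
    scores = np.asarray(predicted_value, dtype=float)
    if scores.ndim != 2 or scores.shape[1] != 10:
        raise ValueError(f"expected shape (batch, 10), got {scores.shape}")
    # The highest-scoring class in each row is the predicted digit.
    return scores.argmax(axis=1)
```

For example, to_digits(predicted_value) would give 16 labels for the dummy request; since the inputs are random noise, the labels themselves are meaningless.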

Assuming the test worked, this endpoint is no longer needed, so we can delete it:

[ ]:
predictorA.delete_endpoint(delete_endpoint_config=True)

Create the Multi-Model Endpoint with the SageMaker SDK

Create a SageMaker Model from one of the Estimators

Multi-Model Endpoints load models on demand in a shared container, so we’ll first create a Model from any of our estimators to define this runtime:

[ ]:
model = estimatorA.create_model(role=role, source_dir="code", entry_point="inference.py")

Create the Amazon SageMaker MultiDataModel entity

We create the multi-model endpoint using the MultiDataModel class (https://sagemaker.readthedocs.io/en/stable/api/inference/multi_data_model.html).

You can create a MultiDataModel by passing in a sagemaker.model.Model object directly. In that case, once the MultiDataModel is deployed, the endpoint inherits the Model's container image, as well as any environment variables, network isolation settings, and so on.

Alternatively, a MultiDataModel can be created without a sagemaker.model.Model object, by passing settings such as the container image URI and execution role directly. Please refer to the documentation for details.

[ ]:
# This is where our MME will read models from on S3.
multi_model_prefix = f"{prefix}multi-model/"
multi_model_s3uri = f"s3://{bucket_name}/{multi_model_prefix}"
print(multi_model_s3uri)
[ ]:
mme = MultiDataModel(
    name="mnist-multi-" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
    model_data_prefix=multi_model_s3uri,
    model=model,  # passing our model
    sagemaker_session=smsess,
)

Deploy the Multi-Model Endpoint

You need to consider the appropriate instance type and number of instances for the projected prediction workload across all the models you plan to host behind your multi-model endpoint. The number and size of the individual models will also drive memory requirements.
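As a rough sizing aid, here is some back-of-the-envelope arithmetic for how many models might fit in memory at once on one instance. All inputs are assumptions you would replace with your own measurements; only the 8 GiB figure is a property of ml.c5.xlarge, and max_resident_models is a hypothetical helper.

```python
import math


def max_resident_models(instance_memory_gib: float, reserved_gib: float, per_model_gib: float) -> int:
    """Rough upper bound on how many models can be held in memory at once on one instance.

    Real usage depends on the framework, the serving stack, and per-model worker processes.
    """
    usable = instance_memory_gib - reserved_gib
    if usable <= 0 or per_model_gib <= 0:
        return 0
    return math.floor(usable / per_model_gib)


# Hypothetical numbers: an ml.c5.xlarge (8 GiB), 2 GiB assumed reserved for the OS and
# serving stack, and an assumed ~0.5 GiB per loaded model.
print(max_resident_models(8, 2, 0.5))  # 12
```

If the set of models you invoke frequently is larger than this estimate, expect more cold-start loads, and consider a larger instance type or more instances.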

[ ]:
try:
    predictor.delete_endpoint(delete_endpoint_config=True)
    print("Deleting previous endpoint...")
    time.sleep(10)
except (NameError, ClientError):
    pass

predictor = mme.deploy(
    initial_instance_count=1,
    instance_type="ml.c5.xlarge",
)
predictor.serializer = sagemaker.serializers.JSONSerializer()
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()

Our endpoint has launched! Let's look at which models are available to it.

By ‘available’, we mean the model artifacts currently stored under the S3 prefix we defined when setting up the MultiDataModel above, i.e. model_data_prefix.

Since we haven't stored any artifacts (i.e. tar.gz files) under that prefix yet, our endpoint has no models ‘available’ to serve inference requests.

We will demonstrate how to make models ‘available’ to our endpoint below.

[ ]:
# No models visible!
list(mme.list_models())
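MultiDataModel.list_models() lists the objects stored under model_data_prefix and reports their paths relative to that prefix. Purely for illustration, the snippet below mimics that behavior on a plain list of S3 keys; relative_model_paths is hypothetical, and the example keys are made up to match this notebook's layout.

```python
def relative_model_paths(s3_keys, model_data_prefix):
    """Mimic (for illustration only) the relative paths reported by MultiDataModel.list_models().

    Keeps only keys under model_data_prefix and strips that prefix from each one.
    """
    if model_data_prefix and not model_data_prefix.endswith("/"):
        model_data_prefix += "/"
    return [
        key[len(model_data_prefix):]
        for key in s3_keys
        if key.startswith(model_data_prefix) and len(key) > len(model_data_prefix)
    ]


keys = [
    "mnist/data/train/train-images-idx3-ubyte.gz",  # not under the MME prefix, so ignored
    "mnist/multi-model/ModelA",
    "mnist/multi-model/ModelB",
]
print(relative_model_paths(keys, "mnist/multi-model/"))  # ['ModelA', 'ModelB']
```

These relative paths are exactly the names you pass as target_model when invoking the endpoint.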

Dynamically deploying models to the endpoint

The .add_model() method of the MultiDataModel copies our model artifacts from where training stored them to the S3 prefix from which our endpoint sources model artifacts for inference requests.

Note that we can keep using this method, as shown below, to dynamically deploy more models to our live endpoint as required!

model_data_source refers to the location of our model artifact (i.e. where it was deposited on S3 after training completed).

model_data_path is the path, relative to the S3 prefix we specified above (i.e. model_data_prefix), where our endpoint will source the model for inference requests. Since this is a relative path, we can simply pass the name we want to use for the model at inference time.

Note: To directly use training job model.tar.gz outputs as we do here, you’ll need to make sure your training job produces results that:

  • Already include any required inference code in a code/ subfolder, and

  • (If you’re using SageMaker PyTorch containers v1.6+) have been packaged to be compatible with TorchServe.

See the enable_sm_oneclick_deploy() and enable_torchserve_multi_model() functions in code/train.py for notes on this. Alternatively, you can perform the same steps after training, to produce a new, serving-ready model.tar.gz from your raw training job output.
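To check the first requirement before calling add_model(), you could inspect a local copy of the artifact (for example one downloaded with aws s3 cp) for a code/ subfolder. This is a partial, hypothetical check: has_inference_code only looks for the entry point file inside code/ and does not verify TorchServe compatibility. The default entry point name matches this notebook's inference.py.

```python
import tarfile


def has_inference_code(model_tar_gz_path: str, entry_point: str = "inference.py") -> bool:
    """Return True if the archive contains code/{entry_point} (hypothetical helper).

    Partial check only: TorchServe compatibility is not verified.
    """
    with tarfile.open(model_tar_gz_path, "r:gz") as tar:
        # Normalize member names that were stored with a leading "./".
        names = {n[2:] if n.startswith("./") else n for n in tar.getnames()}
    return f"code/{entry_point}" in names
```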

[ ]:
for name, est in {"ModelA": estimatorA, "ModelB": estimatorB}.items():
    artifact_path = est.latest_training_job.describe()["ModelArtifacts"]["S3ModelArtifacts"]
    # This is copying over the model artifact to the S3 location for the MME.
    mme.add_model(model_data_source=artifact_path, model_data_path=name)

Our models are ready to invoke!

We can see that the S3 prefix we specified when setting up MultiDataModel now has model artifacts listed. As such, the endpoint can now serve up inference requests for these models.

[ ]:
list(mme.list_models())

Get predictions from the endpoint

Recall that mme.deploy() returns a Predictor (called RealTimePredictor in version 1 of the SageMaker Python SDK), which we saved in a variable called predictor.

That predictor can now be used as usual to request inference, with the target_model argument specifying which model to call:

[ ]:
dummy_data = get_dummy_request()

start_time = time.time()
predicted_value = predictor.predict(dummy_data, target_model="ModelA")
duration = time.time() - start_time

print(f"Model took {int(duration * 1000):,d} ms")
np.array(predicted_value)[0]
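Under the hood, predictor.predict(..., target_model=...) sends an InvokeEndpoint request with the TargetModel parameter set. As a sketch of the equivalent low-level request, the keyword arguments for boto3's sagemaker-runtime invoke_endpoint() could be built as below. build_invoke_request is hypothetical and "my-mme-endpoint" is a placeholder (in this notebook, predictor.endpoint_name holds the real name); the network call is left commented out so the snippet runs without AWS access.

```python
import json


def build_invoke_request(endpoint_name: str, target_model: str, payload: dict) -> dict:
    """Build keyword arguments for SageMaker Runtime InvokeEndpoint (hypothetical helper).

    TargetModel selects which model, relative to the MME's S3 prefix, serves the request.
    """
    return {
        "EndpointName": endpoint_name,
        "TargetModel": target_model,
        "ContentType": "application/json",
        "Accept": "application/json",
        "Body": json.dumps(payload).encode("utf-8"),
    }


# One 1x28x28 all-zero image, nested the same way as get_dummy_request() output.
request = build_invoke_request("my-mme-endpoint", "ModelB", {"inputs": [[[[0.0] * 28] * 28]]})

# With AWS credentials and a live endpoint, you could then send it with:
#   import boto3
#   response = boto3.client("sagemaker-runtime").invoke_endpoint(**request)
#   result = json.loads(response["Body"].read())
```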

Updating a model

To update a model, follow the same approach as above and add it as a new model, for example ModelA-2.

You should avoid overwriting model artifacts in Amazon S3, because the old version of the model might still be loaded in the endpoint’s running container(s) or cached on the storage volume of the endpoint’s instances, so invocations could keep using the old version.

Alternatively, you could stop the endpoint and re-deploy a fresh set of models.
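When following the add-as-new-model approach, a small helper can pick the next unused name instead of overwriting an existing artifact. The "{base}-{n}" naming scheme simply follows the ModelA-2 example in the text and is otherwise an assumption; next_version_name is hypothetical.

```python
import re


def next_version_name(base: str, existing: list) -> str:
    """Return the next unused "{base}-{n}" name (hypothetical naming scheme).

    A bare base name counts as version 1; if the base is unused, it is returned as-is.
    """
    pattern = re.compile(rf"^{re.escape(base)}(?:-(\d+))?$")
    versions = []
    for name in existing:
        match = pattern.match(name)
        if match:
            versions.append(int(match.group(1)) if match.group(1) else 1)
    return f"{base}-{max(versions) + 1}" if versions else base


print(next_version_name("ModelA", ["ModelA", "ModelB"]))  # ModelA-2
```

For example, the result of next_version_name("ModelA", list(mme.list_models())) could be passed as model_data_path to mme.add_model(), together with the new artifact's S3 location as model_data_source.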

Clean up

Endpoints should be deleted when no longer in use, since (per the SageMaker pricing page) they're billed for as long as they are deployed. Here we'll also delete the endpoint configuration, to keep things tidy.

[ ]:
predictor.delete_endpoint(delete_endpoint_config=True)

Notebook CI Test Results

This notebook was tested in multiple regions: us-west-2 (noted at the top of the notebook), us-east-1, us-east-2, us-west-1, ca-central-1, sa-east-1, eu-west-1, eu-west-2, eu-west-3, eu-central-1, eu-north-1, ap-southeast-1, ap-southeast-2, ap-northeast-1, ap-northeast-2, and ap-south-1.