Training and Deploying ML Models using JAX on SageMaker

Amazon SageMaker gives you the flexibility to train models using any framework that can run in a Docker container. In this example we’ll show how to use the Bring-Your-Own-Container (BYOC) paradigm to train machine learning models with the increasingly popular JAX library from Google. We’ll train a Fashion-MNIST classification model using vanilla JAX, another using jax.experimental.stax, and a final model using the higher-level Trax library from Google.

For all three of these demos, we’ll show how the trained models can be serialized in TensorFlow’s standard SavedModel format. This lets us train the models in a custom container, but then deploy them using the managed and optimized SageMaker TensorFlow inference containers.

[ ]:
import os
import json

import boto3
import sagemaker
from sagemaker import get_execution_role

from sagemaker_jax import JaxEstimator

client = boto3.client("sts")
account = client.get_caller_identity()["Account"]
role = get_execution_role()
my_session = boto3.session.Session()
region = my_session.region_name

container_name = "sagemaker-jax"
ecr_image = "{}.dkr.ecr.{}.amazonaws.com/{}".format(account, region, container_name)

Custom Framework Estimator

Since we’ll be saving our JAX and Trax models in the SavedModel format, we can create a subclass of the base SageMaker Framework estimator. This lets us specify a custom create_model method that leverages the existing TensorFlowModel class to launch inference containers.

[ ]:
!pygmentize sagemaker_jax.py
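
For reference, a minimal sketch of what such a subclass can look like is below (the real implementation is in sagemaker_jax.py above; the framework_version here is an assumption and should match the TensorFlow version used to write the SavedModel):

from sagemaker.estimator import Framework
from sagemaker.tensorflow.model import TensorFlowModel


class JaxEstimator(Framework):
    """Custom Framework estimator whose artifacts are served by TensorFlow Serving."""

    def create_model(self, role=None, **kwargs):
        # The training script exported a SavedModel, so the artifacts can be
        # handed straight to the managed TensorFlow inference container
        return TensorFlowModel(
            model_data=self.model_data,
            role=role or self.role,
            framework_version="2.3",  # assumption: match your export version
            sagemaker_session=self.sagemaker_session,
        )

With create_model defined this way, estimator.deploy() transparently constructs a TensorFlowModel and launches the managed serving container.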

Training Docker Container

Our custom training container is straightforward, though there are a few things worth mentioning that can be seen in the comments.

[ ]:
!cat docker/Dockerfile
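
As a rough sketch, the key ingredients are a CUDA base image, the JAX/Trax dependencies, and the sagemaker-training toolkit, which implements the source_dir/entry_point contract used by Framework estimators. The base image and package versions below are illustrative, not the actual file:

# Illustrative sketch only -- the real Dockerfile is printed above
FROM nvidia/cuda:11.0-cudnn8-runtime-ubuntu18.04

RUN apt-get update && apt-get install -y --no-install-recommends \
        python3 python3-pip \
 && rm -rf /var/lib/apt/lists/*

# JAX GPU wheels are hosted outside PyPI, hence the extra find-links URL
RUN pip3 install --upgrade pip \
 && pip3 install jax jaxlib trax tensorflow \
        -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Implements the SageMaker training entry-point contract
RUN pip3 install sagemaker-training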

Building and Publishing the Image

The shell script below must be run if the Docker image has not already been pushed to the Elastic Container Registry.

NOTE: Since SageMaker Studio itself runs inside a Docker container, this script cannot be run inside SageMaker Studio. Push your container using the AWS CLI from another environment, or use this toolkit: https://github.com/aws-samples/sagemaker-studio-image-build-cli

[ ]:
# %%sh

# container_name=sagemaker-jax
# account=$(aws sts get-caller-identity --query Account --output text)

# # Get the region defined in the current configuration (default to us-west-2 if none defined)
# region=$(aws configure get region)
# region=${region:-us-west-2}

# fullname="${account}.dkr.ecr.${region}.amazonaws.com/${container_name}"

# # If the repository doesn't exist in ECR, create it.
# aws ecr describe-repositories --repository-names "${container_name}" > /dev/null 2>&1
# if [ $? -ne 0 ]
# then
#     aws ecr create-repository --repository-name "${container_name}" > /dev/null
# fi

# # Authenticate the Docker CLI with ECR (the old get-login command was removed in AWS CLI v2)
# aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin "${account}.dkr.ecr.${region}.amazonaws.com"

# # Build the docker image locally with the image name and then push it to ECR
# # with the full name.
# docker build  -t ${container_name} docker/
# docker tag ${container_name} ${fullname}

# docker push ${fullname}
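
Alternatively, from inside SageMaker Studio, the sagemaker-studio-image-build-cli linked above can build and push the image via CodeBuild. A sketch of a typical invocation (the repository tag and --file path are assumptions about your layout):

[ ]:
# !pip install sagemaker-studio-image-build
# !sm-docker build . --file docker/Dockerfile --repository sagemaker-jax:latest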

Serializing Models in the SavedModel Format

In the upcoming training jobs we’ll train a vanilla JAX model, a Stax model, and a Trax model on the Fashion-MNIST dataset. The full details of the models can be seen in the training_scripts/ directory, but it is worth calling out the methods used for serialization.

The JAX model utilizes the new experimental jax2tf converter: https://github.com/google/jax/tree/master/jax/experimental/jax2tf

The Trax model utilizes the new trax2keras functionality: https://github.com/google/trax/blob/master/trax/trax2keras.py
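
As an illustration, the jax2tf export path looks roughly like the sketch below; the predict_fn/params names and the input shape are assumptions, so see training_scripts/train_jax.py for the real code. The Trax model follows the same idea, except the network is first wrapped as a Keras layer via trax2keras:

import tensorflow as tf
from jax.experimental import jax2tf


def save_as_saved_model(predict_fn, params, model_dir):
    """Export a pure JAX predict_fn(params, images) as a TF SavedModel."""
    # Convert the JAX function into a TF-traceable one, closing over the
    # trained parameters so the serving signature only takes the image batch
    tf_predict = tf.function(
        lambda images: jax2tf.convert(predict_fn)(params, images),
        autograph=False,
        input_signature=[tf.TensorSpec([None, 28, 28, 1], tf.float32)],
    )

    module = tf.Module()
    module.predict = tf_predict
    tf.saved_model.save(module, model_dir, signatures={"serving_default": module.predict})

In a SageMaker training script, model_dir would typically be os.environ["SM_MODEL_DIR"] so the exported artifacts are uploaded to S3 automatically when the job completes.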

Train Using Vanilla JAX

[ ]:
vanilla_jax_estimator = JaxEstimator(
    image_uri=ecr_image,
    role=role,
    instance_count=1,
    base_job_name=container_name + "-jax",
    source_dir="training_scripts",
    entry_point="train_jax.py",
    instance_type="ml.p2.xlarge",
    hyperparameters={"num_epochs": 3},
)
vanilla_jax_estimator.fit(logs=False)

Train Using the Mid-level JAX API Stax

[ ]:
stax_estimator = JaxEstimator(
    image_uri=ecr_image,
    role=role,
    instance_count=1,
    base_job_name=container_name + "-stax",
    source_dir="training_scripts",
    entry_point="train_stax.py",
    instance_type="ml.p2.xlarge",
    hyperparameters={"num_epochs": 3},
)

stax_estimator.fit(logs=False)

Train Using the High-level Trax Library

[ ]:
trax_estimator = JaxEstimator(
    image_uri=ecr_image,
    role=role,
    instance_count=1,
    base_job_name=container_name + "-trax",
    source_dir="training_scripts",
    entry_point="train_trax.py",
    instance_type="ml.p2.xlarge",
    hyperparameters={"train_steps": 1000},
)

trax_estimator.fit(logs=False)

Deploy All Three Models to Prebuilt TF Containers

Since our custom Framework estimator knows the models are to be served with TensorFlowModel, deploying these endpoints is a trivial call to the estimator.deploy() method.

[ ]:
vanilla_jax_predictor = vanilla_jax_estimator.deploy(
    initial_instance_count=1, instance_type="ml.m4.xlarge"
)
[ ]:
trax_predictor = trax_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")
[ ]:
stax_predictor = stax_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")

Test Inference Endpoints

This requires TensorFlow to be installed in your notebook’s kernel, as it is used to load the test data.

[ ]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
[ ]:
def test_image(predictor, test_images, test_labels, image_number):
    # Reshape the (28, 28) image into the (1, 28, 28, 1) NHWC batch
    # expected by the SavedModel serving signature
    np_img = np.expand_dims(np.expand_dims(test_images[image_number], axis=-1), axis=0)

    result = predictor.predict(np_img)
    pred_y = np.argmax(result["predictions"])

    print("True Label:", test_labels[image_number])
    print("Predicted Label:", pred_y)
    plt.imshow(test_images[image_number])
[ ]:
test_image(vanilla_jax_predictor, x_test, y_test, 0)
[ ]:
test_image(stax_predictor, x_test, y_test, 0)
[ ]:
test_image(trax_predictor, x_test, y_test, 0)

Optional: Delete the running endpoints

[ ]:
# Clean-Up
vanilla_jax_predictor.delete_endpoint()
stax_predictor.delete_endpoint()
trax_predictor.delete_endpoint()