Training and Deploying ML Models using JAX on SageMaker




Amazon SageMaker gives you the flexibility to train models using either the pre-built machine learning containers or your own custom container. This tutorial refers to these two strategies as Bring-Your-Own-Script (BYOS) and Bring-Your-Own-Container (BYOC).

Bring Your Own JAX Script

In this notebook, we’ll show how to extend the optimized SageMaker TensorFlow containers to train machine learning models with the JAX library. We’ll train three Fashion MNIST classification models: one using vanilla JAX, one using jax.experimental.stax, and one using the higher-level Trax library.

For all three patterns, we’ll show how the JAX models can be serialized in the standard TensorFlow SavedModel format. This lets us deploy the models with the managed, optimized SageMaker TensorFlow inference containers.

Bring Your Own JAX Container

We’ve included a Dockerfile in this repo directory to show how you can build your own custom JAX container with GPU support on SageMaker. The NVIDIA/CUDA images on Docker Hub are subject to a deletion policy, so we cannot guarantee that the container will remain buildable over time. If your workload requires a custom container, you can adapt the Dockerfile to a newer version of the base image. For more information on running BYOC on SageMaker, see the documentation.

[ ]:
%pip install --upgrade sagemaker
[ ]:
import sagemaker
from sagemaker import get_execution_role
from sagemaker.tensorflow import TensorFlow

role = get_execution_role()

Installing JAX in SageMaker TensorFlow Containers

When using BYOS with managed SageMaker containers, you can install extra dependencies by placing a requirements.txt in the source_dir that contains your training scripts. At runtime, these dependencies are installed before the training script executes, so we can use the optimized TensorFlow GPU container to run JAX with CUDA support.

More specifically, any container that includes the sagemaker-training-toolkit supports installing additional dependencies from requirements.txt.
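As a rough illustration of that layout, the sketch below creates a source directory holding a placeholder entry point with a requirements.txt beside it. The helper name make_source_dir, the placeholder script body, and the listed package names are assumptions made for this example; they are not the contents of this repo's training_scripts/ directory.

```python
import tempfile
from pathlib import Path


def make_source_dir(root, entry_point, requirements):
    """Create a source_dir containing an entry point and a requirements.txt."""
    source_dir = Path(root)
    source_dir.mkdir(parents=True, exist_ok=True)
    # Placeholder training script; a real script would hold the training loop.
    (source_dir / entry_point).write_text("print('training placeholder')\n")
    # One requirement per line; these are installed before the entry point runs.
    (source_dir / "requirements.txt").write_text("\n".join(requirements) + "\n")
    return source_dir


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Package names are examples only; pin versions that suit your container.
        src = make_source_dir(Path(tmp) / "training_scripts", "train_jax.py", ["jax", "trax"])
        print(sorted(p.name for p in src.iterdir()))
```

In practice you would write requirements.txt by hand and commit it next to the training scripts; the helper only makes the expected directory shape explicit.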

Serializing models as SavedModel format

In the upcoming training jobs, we’ll train a vanilla JAX model, a Stax model, and a Trax model on the Fashion MNIST dataset. The full details of the models are in the training_scripts/ directory, but the serialization methods are worth calling out here.

The JAX and Stax models use the jax2tf converter: https://github.com/google/jax/tree/master/jax/experimental/jax2tf

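import tensorflow as tf
from jax.experimental import jax2tf


def save_model_tf(prediction_function, params_to_save):
    # Convert the JAX prediction function into a TensorFlow-callable function.
    tf_fun = jax2tf.convert(prediction_function, enable_xla=False)
    # Wrap every parameter in a tf.Variable so it can be stored with the model.
    param_vars = tf.nest.map_structure(lambda param: tf.Variable(param), params_to_save)

    # Bind the parameters so the exported function takes only the inputs.
    tf_graph = tf.function(
        lambda inputs: tf_fun(param_vars, inputs),
        autograph=False,
        jit_compile=False,
    )
    # Excerpt ends here; the full function is in training_scripts/.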
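The converter is called with a function of the form prediction_function(params, inputs), where params is a nested structure of arrays. The sketch below shows one way such a function could look for a small dense classifier. The layer sizes, the initialization scale, and the names init_params and predict are assumptions for illustration; the models actually trained in this notebook are defined in training_scripts/.

```python
import jax
import jax.numpy as jnp


def init_params(key, layer_sizes):
    """Create a list of (weights, biases) pairs for a small dense network."""
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, w_key = jax.random.split(key)
        weights = 0.01 * jax.random.normal(w_key, (n_in, n_out))
        params.append((weights, jnp.zeros(n_out)))
    return params


def predict(params, inputs):
    """Map a batch of (N, 28, 28, 1) images to (N, 10) log-probabilities."""
    activations = inputs.reshape((inputs.shape[0], -1))  # flatten each image
    for weights, biases in params[:-1]:
        activations = jax.nn.relu(activations @ weights + biases)
    final_w, final_b = params[-1]
    logits = activations @ final_w + final_b
    return jax.nn.log_softmax(logits, axis=-1)


# Hypothetical architecture: 784 inputs, one hidden layer, 10 classes.
params = init_params(jax.random.PRNGKey(0), [784, 128, 10])
batch = jnp.zeros((1, 28, 28, 1), dtype=jnp.float32)
print(predict(params, batch).shape)
```

Because the parameters are a plain list of tuples, tf.nest.map_structure can walk the same structure when each array is wrapped in a tf.Variable.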

The Trax model uses the trax2keras functionality: https://github.com/google/trax/blob/master/trax/trax2keras.py

def save_model_tf(model_to_save):
    """
    Serialize a TensorFlow graph from trained Trax Model
    :param model_to_save: Trax Model
    """
    keras_layer = trax.AsKeras(model_to_save, batch_size=1)
    inputs = tf.keras.Input(shape=(28, 28, 1))
    hidden = keras_layer(inputs)

    keras_model = tf.keras.Model(inputs=inputs, outputs=hidden)
    keras_model.save("/opt/ml/model/1", save_format="tf")

Train Using Vanilla JAX

Note: The source_dir directory contains a requirements.txt that installs JAX with CUDA support.

[ ]:
vanilla_jax_estimator = TensorFlow(
    role=role,
    instance_count=1,
    base_job_name="jax",
    framework_version="2.10",
    py_version="py39",
    source_dir="training_scripts",
    entry_point="train_jax.py",
    instance_type="ml.p3.2xlarge",
    hyperparameters={"num_epochs": 3},
)
vanilla_jax_estimator.fit(logs=False)
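In script mode, SageMaker passes each entry of the hyperparameters dictionary to the entry point as a command-line argument and exposes the model output directory through the SM_MODEL_DIR environment variable. The sketch below shows how an entry point might read them. It is not the actual train_jax.py: the learning_rate argument and the sm_model_dir argument name are hypothetical and appear only to illustrate the pattern.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters the way a script-mode entry point might."""
    parser = argparse.ArgumentParser()
    # Matches the key used in hyperparameters={"num_epochs": 3}.
    parser.add_argument("--num_epochs", type=int, default=3)
    # Hypothetical extra hyperparameter, shown only to illustrate the pattern.
    parser.add_argument("--learning_rate", type=float, default=0.001)
    # SageMaker sets SM_MODEL_DIR inside the training container.
    parser.add_argument(
        "--sm_model_dir",
        type=str,
        default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"),
    )
    # parse_known_args tolerates arguments that this sketch does not declare.
    args, _unknown = parser.parse_known_args(argv)
    return args


if __name__ == "__main__":
    args = parse_args(["--num_epochs", "3"])
    print(args.num_epochs, args.learning_rate)
```

Using parse_known_args rather than parse_args keeps the script from failing if the estimator passes additional arguments that the script does not define.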

Train Using JAX Medium-level API Stax

[ ]:
stax_estimator = TensorFlow(
    role=role,
    instance_count=1,
    base_job_name="stax",
    framework_version="2.10",
    py_version="py39",
    source_dir="training_scripts",
    entry_point="train_stax.py",
    instance_type="ml.p3.2xlarge",
    hyperparameters={"num_epochs": 3},
)

stax_estimator.fit(logs=False)

Train Using JAX High-level API Trax

[ ]:
trax_estimator = TensorFlow(
    role=role,
    instance_count=1,
    base_job_name="trax",
    framework_version="2.10",
    py_version="py39",
    source_dir="training_scripts",
    entry_point="train_trax.py",
    instance_type="ml.p3.2xlarge",
    hyperparameters={"train_steps": 1000},
)


trax_estimator.fit(logs=False)

Deploy Models to managed TF Containers

Because the models are serialized in the TensorFlow SavedModel format, deploying each one as an endpoint takes a single call to the estimator.deploy() method.

[ ]:
vanilla_jax_predictor = vanilla_jax_estimator.deploy(
    initial_instance_count=1, instance_type="ml.m4.xlarge"
)
[ ]:
trax_predictor = trax_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")
[ ]:
stax_predictor = stax_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")

Test Inference Endpoints

This step requires TensorFlow to be installed in your notebook’s kernel, because it is used to load the test data.

[ ]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
[ ]:
def test_image(predictor, test_images, test_labels, image_number):
    np_img = np.expand_dims(np.expand_dims(test_images[image_number], axis=-1), axis=0)

    result = predictor.predict(np_img)
    pred_y = np.argmax(result["predictions"])

    print("True Label:", test_labels[image_number])
    print("Predicted Label:", pred_y)
    plt.imshow(test_images[image_number])
[ ]:
test_image(vanilla_jax_predictor, x_test, y_test, 0)
[ ]:
test_image(stax_predictor, x_test, y_test, 0)
[ ]:
test_image(trax_predictor, x_test, y_test, 0)
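The labels printed by test_image are integers from 0 to 9. The sketch below maps them to the standard Fashion MNIST class names and repeats the input reshaping on its own, using a stand-in response dictionary instead of a live endpoint. The helper names and the fake_result values are assumptions for illustration; the response is assumed to have the same "predictions" layout that test_image reads.

```python
import numpy as np

# Standard Fashion MNIST class names, indexed by integer label.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def to_request_batch(image):
    """Turn one (28, 28) image into a (1, 28, 28, 1) batch, as test_image does."""
    return np.expand_dims(np.expand_dims(image, axis=-1), axis=0)


def decode_prediction(result):
    """Return the class name with the highest score in a prediction response."""
    scores = np.asarray(result["predictions"])
    return CLASS_NAMES[int(np.argmax(scores))]


# Stand-in response shaped like a one-image prediction (made-up scores).
fake_result = {"predictions": [[0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9]]}
print(to_request_batch(np.zeros((28, 28))).shape, decode_prediction(fake_result))
```

With a deployed endpoint, the same decode_prediction call could be applied to the dictionary returned by predictor.predict.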

Optional: Delete the running endpoints

[ ]:
# Clean-Up
vanilla_jax_predictor.delete_endpoint()
stax_predictor.delete_endpoint()
trax_predictor.delete_endpoint()
