Training and Deploying ML Models using JAX on SageMaker
Amazon SageMaker gives you the flexibility to train models using our pre-built machine learning containers or your own bespoke container. We’ll refer to these strategies as Bring-Your-Own-Script (BYOS) and Bring-Your-Own-Container (BYOC) in this tutorial.
Bring Your Own JAX Script
In this notebook, we’ll show how to extend our optimized TensorFlow containers to train machine learning models using the increasingly popular JAX library. We’ll train a fashion MNIST classification model using vanilla JAX, another using jax.experimental.stax, and a final model using the higher-level Trax library.
For all three patterns, we’ll show how the JAX models can be serialized in the standard TensorFlow SavedModel format. This enables us to seamlessly deploy them using the managed and optimized SageMaker TensorFlow inference containers.
Bring Your Own JAX Container
We’ve included a Dockerfile in this repo directory to show how you can build your own bespoke JAX container with GPU support on SageMaker. Unfortunately, the NVIDIA/CUDA containers on Docker Hub are subject to a deletion policy, so we can’t guarantee that this exact container will remain buildable over time. Nonetheless, you can trivially adapt the Dockerfile to a newer base image if your workload requires a custom container. For more information on running BYOC on SageMaker, see the documentation.
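As a rough illustration, a minimal BYOC Dockerfile might look like the sketch below. The base image tag and JAX pins are assumptions (NVIDIA rotates its published tags), so substitute currently available versions; the Dockerfile shipped in this repo is the authoritative reference.

# Illustrative sketch only: pick a currently published CUDA base tag
# and a matching JAX wheel before building.
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

RUN apt-get update && apt-get install -y --no-install-recommends \
        python3 python3-pip && \
    rm -rf /var/lib/apt/lists/*

# sagemaker-training lets SageMaker drive this container like a managed one
RUN pip3 install --no-cache-dir sagemaker-training && \
    pip3 install --no-cache-dir "jax[cuda11_pip]" \
        -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Entry point script that SageMaker will execute from /opt/ml/code
ENV SAGEMAKER_PROGRAM train_jax.py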
[ ]:
%pip install --upgrade sagemaker
[ ]:
import sagemaker
from sagemaker import get_execution_role
from sagemaker.tensorflow import TensorFlow
role = get_execution_role()
Installing JAX in SageMaker TensorFlow Containers
When using BYOS with managed SageMaker containers, you can trivially install extra dependencies by providing a requirements.txt within the source_dir that contains your training scripts. These dependencies are installed at runtime, before the training script executes, so we can use our optimized TensorFlow GPU container to run JAX with CUDA support.
To be specific, any container that includes the sagemaker-training-toolkit supports installing additional dependencies from a requirements.txt.
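For illustration, a requirements.txt along these lines would pull in a CUDA-enabled JAX build; the pins and extras here are assumptions, and the actual file shipped in training_scripts/ is authoritative.

--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda11_pip]
trax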
Serializing models as SavedModel format
In the upcoming training jobs we’ll train a vanilla JAX model, a Stax model, and a Trax model on the fashion MNIST dataset. The full details of the models can be found in the training_scripts/ directory, but the serialization methods are worth calling out.
The JAX/Stax models utilize the new jax2tf converter: https://github.com/google/jax/tree/master/jax/experimental/jax2tf
def save_model_tf(prediction_function, params_to_save):
    # Convert the pure JAX prediction function into a TensorFlow function
    tf_fun = jax2tf.convert(prediction_function, enable_xla=False)
    # Wrap the trained parameters in tf.Variables so the SavedModel tracks them
    param_vars = tf.nest.map_structure(lambda param: tf.Variable(param), params_to_save)
    tf_graph = tf.function(
        lambda inputs: tf_fun(param_vars, inputs),
        autograph=False,
        jit_compile=False,
    )
    # Export via a tf.Module; the (1, 28, 28, 1) spec assumes fashion MNIST inputs
    module = tf.Module()
    module.param_vars = param_vars
    module.predict = tf_graph
    signature = tf_graph.get_concrete_function(tf.TensorSpec([1, 28, 28, 1], tf.float32))
    tf.saved_model.save(module, "/opt/ml/model/1", signatures=signature)
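For context, here is a minimal sketch of the kind of prediction_function and parameter pytree this converter expects. The names and architecture are illustrative; the real definitions live in the training_scripts/ directory.

import jax
import jax.numpy as jnp

def init_params(rng, input_dim=784, hidden=128, classes=10):
    # Randomly initialized two-layer MLP parameters (illustrative)
    k1, k2 = jax.random.split(rng)
    return [
        (jax.random.normal(k1, (input_dim, hidden)) * 0.01, jnp.zeros(hidden)),
        (jax.random.normal(k2, (hidden, classes)) * 0.01, jnp.zeros(classes)),
    ]

def prediction_function(params, inputs):
    # inputs: (batch, 28, 28, 1) images -> flatten -> MLP logits
    x = inputs.reshape((inputs.shape[0], -1))
    (w1, b1), (w2, b2) = params
    x = jax.nn.relu(x @ w1 + b1)
    return x @ w2 + b2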
The Trax model utilizes the new trax2keras functionality: https://github.com/google/trax/blob/master/trax/trax2keras.py
def save_model_tf(model_to_save):
    """
    Serialize a TensorFlow graph from a trained Trax model.

    :param model_to_save: Trax Model
    """
    # Wrap the Trax model as a Keras layer and trace it into a Keras model
    keras_layer = trax.AsKeras(model_to_save, batch_size=1)
    inputs = tf.keras.Input(shape=(28, 28, 1))
    hidden = keras_layer(inputs)
    keras_model = tf.keras.Model(inputs=inputs, outputs=hidden)
    # Save to a numbered version directory, as TensorFlow Serving expects
    keras_model.save("/opt/ml/model/1", save_format="tf")
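As a quick sanity check inside the training container, you could confirm the export loads back with a serving signature. This is a hedged sketch; it assumes TensorFlow can read /opt/ml/model/1 at that point in the script.

loaded = tf.saved_model.load("/opt/ml/model/1")
# "serving_default" is the signature the TensorFlow Serving container will invoke
print(loaded.signatures["serving_default"].structured_input_signature)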
Train Using Vanilla JAX
Note: Our source_dir directory contains a requirements.txt that will install JAX with CUDA support.
[ ]:
vanilla_jax_estimator = TensorFlow(
role=role,
instance_count=1,
base_job_name="jax",
framework_version="2.10",
py_version="py39",
source_dir="training_scripts",
entry_point="train_jax.py",
instance_type="ml.p3.2xlarge",
hyperparameters={"num_epochs": 3},
)
vanilla_jax_estimator.fit(logs=False)
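When the job completes, SageMaker packs everything written to /opt/ml/model into a model.tar.gz in S3; the artifact location is available on the estimator:
[ ]:
# S3 URI of the packed SavedModel produced by this training job
print(vanilla_jax_estimator.model_data)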
Train Using the JAX Medium-Level API, Stax
[ ]:
stax_estimator = TensorFlow(
role=role,
instance_count=1,
base_job_name="stax",
framework_version="2.10",
py_version="py39",
source_dir="training_scripts",
entry_point="train_stax.py",
instance_type="ml.p3.2xlarge",
hyperparameters={"num_epochs": 3},
)
stax_estimator.fit(logs=False)
Train Using the JAX High-Level API, Trax
[ ]:
trax_estimator = TensorFlow(
role=role,
instance_count=1,
base_job_name="trax",
framework_version="2.10",
py_version="py39",
source_dir="training_scripts",
entry_point="train_trax.py",
instance_type="ml.p3.2xlarge",
hyperparameters={"train_steps": 1000},
)
trax_estimator.fit(logs=False)
Deploy Models to Managed TensorFlow Containers
Since we’ve serialized the models in the TensorFlow SavedModel format, deploying them as endpoints is a trivial call to the estimator.deploy() method.
[ ]:
vanilla_jax_predictor = vanilla_jax_estimator.deploy(
initial_instance_count=1, instance_type="ml.m4.xlarge"
)
[ ]:
trax_predictor = trax_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")
[ ]:
stax_predictor = stax_estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")
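If the estimator objects are no longer in memory, you can also deploy straight from the S3 artifact with TensorFlowModel; a minimal sketch, reusing the role defined above:
[ ]:
from sagemaker.tensorflow import TensorFlowModel

# Rebuild a deployable model from the training artifact
model = TensorFlowModel(
    model_data=vanilla_jax_estimator.model_data,
    role=role,
    framework_version="2.10",
)
predictor = model.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")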
Test Inference Endpoints
This requires TensorFlow to be installed in your notebook kernel, as it is used to load the test data.
[ ]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
[ ]:
def test_image(predictor, test_images, test_labels, image_number):
    # Add batch and channel dims: (28, 28) -> (1, 28, 28, 1), matching the serving signature
    np_img = np.expand_dims(np.expand_dims(test_images[image_number], axis=-1), axis=0)
    result = predictor.predict(np_img)
    # TensorFlow Serving returns {"predictions": [...]}; take the argmax as the class label
    pred_y = np.argmax(result["predictions"])
    print("True Label:", test_labels[image_number])
    print("Predicted Label:", pred_y)
    plt.imshow(test_images[image_number])
[ ]:
test_image(vanilla_jax_predictor, x_test, y_test, 0)
[ ]:
test_image(stax_predictor, x_test, y_test, 0)
[ ]:
test_image(trax_predictor, x_test, y_test, 0)
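Beyond spot-checking single images, here is a hedged sketch that estimates endpoint accuracy over a handful of test images, sending one request per image since the signatures above were traced with batch size 1:
[ ]:
# Estimate accuracy of the Stax endpoint over the first 20 test images
correct = 0
n = 20
for i in range(n):
    img = np.expand_dims(np.expand_dims(x_test[i], axis=-1), axis=0).astype("float32")
    pred = np.argmax(stax_predictor.predict(img)["predictions"])
    correct += int(pred == y_test[i])
print(f"Accuracy over {n} images: {correct / n:.2f}")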
Optional: Delete the running endpoints
[ ]:
# Clean-Up
vanilla_jax_predictor.delete_endpoint()
stax_predictor.delete_endpoint()
trax_predictor.delete_endpoint()