Train an TensorFlow model with a SageMaker Training Job and track it using SageMaker Experiments


This notebook’s CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

This us-west-2 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable


This notebook shows how you can use the SageMaker SDK to track a Machine Learning experiment.

We introduce two concepts in this notebook -

  • Experiment: An experiment is a collection of runs. When you initialize a run in your training loop, you include the name of the experiment that the run belongs to. Experiment names must be unique within your AWS account.

  • Run: A run consists of all the inputs, parameters, configurations, and results for one iteration of model training. Initialize an experiment run for tracking a training job with Run().

In this notebook we train a Keras model using the MNIST dataset on a remote SageMaker instance using a training job.

[ ]:
import sys
[ ]:
# update boto3 and sagemaker to ensure latest SDK version
!{sys.executable} -m pip install --upgrade pip
!{sys.executable} -m pip install --upgrade boto3
!{sys.executable} -m pip install --upgrade sagemaker
!{sys.executable} -m pip install --upgrade tensorflow
[ ]:
import os
import boto3
import json
import sagemaker
from sagemaker.session import Session
from sagemaker import get_execution_role
from sagemaker.experiments.run import Run
from sagemaker.utils import unique_name_from_base
[ ]:
sagemaker_session = Session()
boto_sess = boto3.Session()

role = get_execution_role()
default_bucket = sagemaker_session.default_bucket()


sm = boto_sess.client("sagemaker")
region = boto_sess.region_name

Prepare the training script

Here we use a SageMaker Training job to train the model on a remote instance.

[ ]:
!mkdir -p script
[ ]:
%%writefile ./script/train.py

import os

os.system("pip install -U sagemaker")

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
import argparse

from sagemaker.session import Session
from sagemaker.experiments import load_run

import boto3

boto_session = boto3.session.Session(region_name=os.environ["REGION"])
sagemaker_session = Session(boto_session=boto_session)
s3 = boto3.client("s3")


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--dropout", type=float, default=0.01)

    return parser.parse_known_args()


class ExperimentCallback(keras.callbacks.Callback):
    """ """

    def __init__(self, run, model, x_test, y_test):
        """Save params in constructor"""
        self.run = run
        self.model = model
        self.x_test = x_test
        self.y_test = y_test

    def on_epoch_end(self, epoch, logs=None):
        """ """
        keys = list(logs.keys())
        for key in keys:
            self.run.log_metric(name=key, value=logs[key], step=epoch)
            print("{} -> {}".format(key, logs[key]))


def load_data():
    num_classes = 10
    input_shape = (28, 28, 1)

    train_path = "input_train.npy"
    test_path = "input_test.npy"
    train_labels_path = "input_train_labels.npy"
    test_labels_path = "input_test_labels.npy"

    # Load the data and split it between train and test sets
    s3.download_file(
        f"sagemaker-example-files-prod-{os.environ['REGION']}", "datasets/image/MNIST/numpy/input_train.npy", train_path
    )
    s3.download_file(
        f"sagemaker-example-files-prod-{os.environ['REGION']}", "datasets/image/MNIST/numpy/input_test.npy", test_path
    )
    s3.download_file(
        f"sagemaker-example-files-prod-{os.environ['REGION']}",
        "datasets/image/MNIST/numpy/input_train_labels.npy",
        train_labels_path,
    )
    s3.download_file(
        f"sagemaker-example-files-prod-{os.environ['REGION']}",
        "datasets/image/MNIST/numpy/input_test_labels.npy",
        test_labels_path,
    )

    x_train = np.load(train_path)
    x_test = np.load(test_path)
    y_train = np.load(train_labels_path)
    y_test = np.load(test_labels_path)

    # Reshape the arrays
    x_train = np.reshape(x_train, (60000, 28, 28))
    x_test = np.reshape(x_test, (10000, 28, 28))
    y_train = np.reshape(y_train, (60000,))
    y_test = np.reshape(y_test, (10000,))

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255

    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)

    return x_train, x_test, y_train, y_test


def main():
    """ """
    args, _ = parse_args()
    print("Args are : ", args)
    num_classes = 10
    input_shape = (28, 28, 1)
    x_train, x_test, y_train, y_test = load_data()

    model = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(args.dropout),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )

    model.summary()

    batch_size = args.batch_size
    epochs = args.epochs

    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

    ###
    # `load_run` will use the run defined when calling the estimator
    ###
    with load_run(sagemaker_session=sagemaker_session) as run:
        model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=0.1,
            callbacks=[ExperimentCallback(run, model, x_test, y_test)],
        )

        score = model.evaluate(x_test, y_test, verbose=0)
        print("Test loss:", score[0])
        print("Test accuracy:", score[1])

        run.log_metric(name="Final Test Loss", value=score[0])
        run.log_metric(name="Final Test Accuracy", value=score[1])

        model.save("/opt/ml/model")


if __name__ == "__main__":
    main()

Create an Experiment and launch a training job

[ ]:
from sagemaker.tensorflow.estimator import TensorFlow
from sagemaker.experiments.run import Run

exp_name = "tensorflow-script-mode-experiment"

batch_size = 256
epochs = 5
dropout = 0.1

with Run(
    experiment_name=exp_name,
    sagemaker_session=sagemaker_session,
) as run:
    run.log_parameter("batch_size", batch_size)
    run.log_parameter("epochs", epochs)
    run.log_parameter("dropout", dropout)

    est = TensorFlow(
        entry_point="./script/train.py",
        role=role,
        model_dir=False,
        hyperparameters={"epochs": epochs, "batch_size": batch_size, "dropout": dropout},
        framework_version="2.8",
        py_version="py39",
        instance_type="ml.m5.xlarge",
        instance_count=1,
        keep_alive_period_in_seconds=3600,
        environment={"REGION": region},
    )

    est.fit()
[ ]:
est.model_data

Register the trained model in the Model Registry

This is an optional step users can take if they want to keep track of their models in a central model catalog.

[ ]:
from sagemaker.tensorflow.model import TensorFlowModel


model = TensorFlowModel(model_data=est.model_data, role=role, framework_version="2.8")
[ ]:
model.register(
    model_package_group_name="tensorflow-script-mode-model",
    content_types=["text/csv"],
    inference_instances=["ml.m5.xlarge"],
    transform_instances=["ml.m5.xlarge"],
    response_types=["text/csv"],
    approval_status="PendingManualApproval",
)

Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

This us-east-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This us-east-2 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This us-west-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This ca-central-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This sa-east-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This eu-west-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This eu-west-2 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This eu-west-3 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This eu-central-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This eu-north-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This ap-southeast-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This ap-southeast-2 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This ap-northeast-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This ap-northeast-2 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable

This ap-south-1 badge failed to load. Check your device’s internet connectivity, otherwise the service is currently unavailable