Migrating scripts from Framework Mode to Script Mode



This notebook shows how to migrate a training script from Framework Mode to Script Mode.

Set up the environment

[ ]:
import os
import subprocess
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

role = get_execution_role()

Download the MNIST dataset

[ ]:
! mkdir -p data
import utils
import numpy as np
import tensorflow as tf

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = np.expand_dims(train_images, axis=-1)
test_images = np.expand_dims(test_images, axis=-1)

utils.convert_to(train_images, train_labels, "train", "data")
utils.convert_to(test_images, test_labels, "test", "data")

Upload the data

We use the sagemaker.Session.upload_data function to upload our datasets to an S3 location. The return value, inputs, identifies that location; we use it later when we start the training job.

[ ]:
inputs = sagemaker_session.upload_data(path="data", key_prefix="data/mnist")
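
For a directory upload like this one, the returned value is typically an S3 URI of the form s3://<bucket>/<key_prefix>. The following is a minimal sketch of how such a URI could be split into its bucket and key prefix with the standard library. The helper name split_s3_uri and the bucket name are made up for illustration; the real bucket depends on your account and region.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split an S3 URI into (bucket, key_prefix)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError("not an S3 URI: {}".format(uri))
    return parsed.netloc, parsed.path.lstrip("/")


# Hypothetical URI standing in for the value returned by upload_data.
example_inputs = "s3://example-bucket/data/mnist"
bucket, key_prefix = split_s3_uri(example_inputs)
print(bucket, key_prefix)
```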

Construct an entry point script for training

In this example, we assume that you already have a Framework Mode training script named mnist.py:

[ ]:
!pygmentize 'mnist.py'

The training script mnist.py includes the Framework Mode functions model_fn, train_input_fn, eval_input_fn, and serving_input_fn. We need to create an entry point script that uses these functions to create a tf.estimator.Estimator:

[ ]:
%%writefile train.py

import argparse

# import original framework mode script
import mnist

import tensorflow as tf

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # read hyperparameters as script arguments
    parser.add_argument("--training_steps", type=int)
    parser.add_argument("--evaluation_steps", type=int)

    args, _ = parser.parse_known_args()

    # creates a tf.Estimator using `model_fn` that saves models to /opt/ml/model
    estimator = tf.estimator.Estimator(model_fn=mnist.model_fn, model_dir="/opt/ml/model")

    # creates the parameterless training input function required by the estimator
    def train_input_fn():
        return mnist.train_input_fn(training_dir="/opt/ml/input/data/training", params=None)

    train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=args.training_steps)

    # creates parameterless serving_input_receiver_fn function required by the exporter
    def serving_input_receiver_fn():
        return mnist.serving_input_fn(params=None)

    exporter = tf.estimator.LatestExporter(
        "Servo", serving_input_receiver_fn=serving_input_receiver_fn
    )

    # creates the parameterless evaluation input function required by the EvalSpec
    def eval_input_fn():
        return mnist.eval_input_fn(training_dir="/opt/ml/input/data/training", params=None)

    eval_spec = tf.estimator.EvalSpec(
        eval_input_fn, steps=args.evaluation_steps, exporters=exporter
    )

    # start training and evaluation
    tf.estimator.train_and_evaluate(estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
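
The script above hard-codes the container paths. SageMaker training containers also expose these locations through environment variables such as SM_MODEL_DIR and SM_CHANNEL_TRAINING, so one alternative is to read them as argument defaults. The sketch below assumes the input channel is named training (matching the /opt/ml/input/data/training path used above) and falls back to the same hard-coded paths when the variables are not set. The function name parse_args and the --local_model_dir and --train flags are illustrative and not part of the original script.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and data locations for a Script Mode entry point."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--training_steps", type=int)
    parser.add_argument("--evaluation_steps", type=int)
    # Illustrative flags: default to the SageMaker environment variables,
    # then to the paths hard-coded in train.py.
    parser.add_argument(
        "--local_model_dir",
        default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"),
    )
    parser.add_argument(
        "--train",
        default=os.environ.get("SM_CHANNEL_TRAINING", "/opt/ml/input/data/training"),
    )
    # parse_known_args ignores any extra arguments the container passes.
    args, _ = parser.parse_known_args(argv)
    return args


args = parse_args(["--training_steps", "10", "--evaluation_steps", "10"])
print(args.local_model_dir, args.train)
```

With this in place, the estimator and input functions could take args.local_model_dir and args.train instead of string literals, which also makes the script easier to run outside a container.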

Changes in the SageMaker TensorFlow estimator

We need to create a TensorFlow estimator pointing to train.py as the entrypoint:

[ ]:
from sagemaker.tensorflow import TensorFlow

mnist_estimator = TensorFlow(
    entry_point="train.py",
    dependencies=["mnist.py"],
    role=role,
    framework_version="1.15.2",
    hyperparameters={"training_steps": 10, "evaluation_steps": 10},
    py_version="py3",
    train_instance_count=1,
    train_instance_type="local",
)

mnist_estimator.fit(inputs)
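
In Script Mode, the hyperparameters dictionary reaches train.py as command-line arguments, which is why the script reads them with argparse. The sketch below imitates that hand-off locally. The helper hyperparameters_to_argv is a hypothetical stand-in for what the container does, and the extra --model_dir argument with a made-up S3 path is an assumption used to show why train.py calls parse_known_args rather than parse_args.

```python
import argparse


def hyperparameters_to_argv(hyperparameters):
    """Illustrative helper: flatten a dict into ['--key', 'value', ...]."""
    argv = []
    for name, value in hyperparameters.items():
        argv.extend(["--{}".format(name), str(value)])
    return argv


argv = hyperparameters_to_argv({"training_steps": 10, "evaluation_steps": 10})
# Simulate an extra argument that train.py does not declare.
argv += ["--model_dir", "s3://example-bucket/model"]

parser = argparse.ArgumentParser()
parser.add_argument("--training_steps", type=int)
parser.add_argument("--evaluation_steps", type=int)

# Declared arguments are parsed; undeclared ones are returned separately.
args, unknown = parser.parse_known_args(argv)
print(args.training_steps, args.evaluation_steps, unknown)
```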

Deploy the trained model to prepare for predictions

The deploy() method creates an endpoint (in this case a local one) that serves prediction requests in real time.

[ ]:
mnist_predictor = mnist_estimator.deploy(initial_instance_count=1, instance_type="local")

Invoking the endpoint

[ ]:
import numpy as np

random_sample_idx = np.random.choice(test_images.shape[0], size=10)
mnist_images = test_images[random_sample_idx]
mnist_labels = test_labels[random_sample_idx]

for i in range(10):
    data = mnist_images[i]

    predict_response = mnist_predictor.predict(data)

    print("========================================")
    label = mnist_labels[i]
    print("label is {}".format(label))
    print("prediction is {}".format(predict_response["predictions"][0]["classes"]))
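
The loop above reads the predicted class from predict_response["predictions"][0]["classes"]. As a minimal sketch, responses shaped that way can be compared against the labels to get a quick accuracy figure. The helper names and the sample responses below are hand-written stand-ins, not real endpoint output, and real responses may carry additional fields.

```python
def predicted_class(predict_response):
    """Return the class from the first prediction in a response dict."""
    return predict_response["predictions"][0]["classes"]


def accuracy(responses, labels):
    """Fraction of responses whose predicted class matches the label."""
    if not labels:
        raise ValueError("labels must not be empty")
    correct = sum(1 for r, y in zip(responses, labels) if predicted_class(r) == y)
    return correct / len(labels)


# Hand-written stand-ins for endpoint responses.
fake_responses = [
    {"predictions": [{"classes": 7}]},
    {"predictions": [{"classes": 2}]},
    {"predictions": [{"classes": 1}]},
    {"predictions": [{"classes": 0}]},
]
fake_labels = [7, 2, 1, 4]
print("accuracy: {:.2f}".format(accuracy(fake_responses, fake_labels)))
```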

Clean-up

Deleting the local endpoint when you’re finished is important since you can only run one local endpoint at a time.

[ ]:
mnist_estimator.delete_endpoint()
