Train an MNIST model with TensorFlow

MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28-pixel grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images. There are 10 classes, one for each digit. This tutorial shows how to train a TensorFlow V2 model on MNIST on SageMaker.
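The raw MNIST files used later in this notebook (for example train-images-idx3-ubyte.gz) are stored in the IDX format. The header is made of big-endian 32-bit integers. An image file's header holds a magic number, the image count, and the row and column counts (16 bytes). A label file's header holds a magic number and the label count (8 bytes). This is why code/train.py skips 16 and 8 bytes through the offset argument of np.frombuffer. The sketch below parses these headers with the standard library. The helper name read_idx_header is hypothetical, and the example uses a tiny synthetic buffer instead of the real files. For the real training image file, the header would report 60,000 images of 28x28 pixels.

```python
import gzip
import struct


def read_idx_header(buf):
    """Parse the header of an uncompressed IDX buffer and return (magic, dims).

    The low byte of the magic number is the number of dimensions; each
    dimension size follows as a big-endian unsigned 32-bit integer.
    """
    magic = struct.unpack(">I", buf[:4])[0]
    ndim = magic & 0xFF  # low byte = number of dimensions
    dims = struct.unpack(">" + "I" * ndim, buf[4 : 4 + 4 * ndim])
    return magic, dims


# Synthetic stand-in for an image file: 2 all-zero images of 28x28 pixels.
header = struct.pack(">IIII", 0x00000803, 2, 28, 28)
images_gz = gzip.compress(header + bytes(2 * 28 * 28))

magic, dims = read_idx_header(gzip.decompress(images_gz))
print(hex(magic), dims)  # 0x803 (2, 28, 28)

header_len = 4 + 4 * len(dims)  # 16 bytes for image files, 8 for label files
print(header_len)  # 16
```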

Runtime

This notebook takes approximately 5 minutes to run.

Contents

  1. TensorFlow Estimator

  2. Implement the training entry point

  3. Set hyperparameters

  4. Set up channels for training and testing data

  5. Run the training script on SageMaker

  6. Inspect and store model data

  7. Test and debug the entry point before running the training container

[2]:
import os
import json

import sagemaker
from sagemaker.tensorflow import TensorFlow
from sagemaker import get_execution_role

sess = sagemaker.Session()

role = get_execution_role()

output_path = "s3://" + sess.default_bucket() + "/DEMO-tensorflow/mnist"

TensorFlow Estimator

The TensorFlow class allows you to run your training script on SageMaker infrastructure in a containerized environment. In this notebook, we refer to this container as the “training container.”

Configure it with the following parameters to set up the environment:

  • entry_point: A user-defined Python file that the training container runs as the instructions for training. This file is discussed further in the next section.

  • role: An IAM role that SageMaker uses to make AWS service requests on your behalf.

  • instance_type: The type of SageMaker instance on which to run your training script. Set it to "local" if you want to run the training job on the SageMaker instance you are using to run this notebook.

  • model_dir: The S3 URI where checkpoint data and models can be exported during training (default: None). To stop model_dir from being passed to your training script, set model_dir=False.

  • instance_count: The number of instances to run your training job on. Multiple instances are needed for distributed training.

  • output_path: The S3 bucket URI where training output (model artifacts and output files) is saved.

  • framework_version: The TensorFlow version to use.

  • py_version: The Python version to use.

For more information, see the EstimatorBase API reference.

Implement the training entry point

The entry point for training is a Python script that provides all the code for training a TensorFlow model. It is used by the SageMaker TensorFlow Estimator (TensorFlow class above) as the entry point for running the training job.

Under the hood, the SageMaker TensorFlow Estimator pulls a Docker image whose runtime environment is determined by the parameters used to initialize the estimator, and injects the training script into that image as the entry point for running the container.

In the rest of the notebook, we use training image to refer to the Docker image specified by the TensorFlow Estimator and training container to refer to the container that runs the training image.

This means your training script is very similar to one you might run outside Amazon SageMaker, except that it can also read useful environment variables provided by the training image. See the complete list of environment variables for a description of every variable your training script can access.
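Besides individual variables such as SM_MODEL_DIR and SM_CHANNEL_TRAINING, the training image also sets SM_TRAINING_ENV, a single JSON document that bundles the same information (both appear in the training log later in this notebook). The sketch below reads channel directories and hyperparameters from SM_TRAINING_ENV using only the standard library. The helper name load_training_env is hypothetical, and returning an empty dict when the variable is missing (for example, when running outside the container) is an assumption of this sketch.

```python
import json
import os


def load_training_env():
    """Return the parsed SM_TRAINING_ENV dict, or {} outside the training container."""
    raw = os.environ.get("SM_TRAINING_ENV")
    return json.loads(raw) if raw else {}


# Simulate a small subset of what the training image sets
# (values copied from this notebook's training log).
os.environ["SM_TRAINING_ENV"] = json.dumps(
    {
        "channel_input_dirs": {
            "training": "/opt/ml/input/data/training",
            "testing": "/opt/ml/input/data/testing",
        },
        "hyperparameters": {"batch-size": 512, "epochs": 1},
        "model_dir": "/opt/ml/model",
    }
)

env = load_training_env()
train_dir = env["channel_input_dirs"]["training"]
batch_size = env["hyperparameters"]["batch-size"]
print(train_dir, batch_size)  # /opt/ml/input/data/training 512
```

Reading the individual SM_* variables, as code/train.py does, is simpler for a handful of values. The JSON form is convenient when a script needs many fields at once.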

In this example, we use the training script code/train.py as the entry point for our TensorFlow Estimator.

[3]:
!pygmentize 'code/train.py'
from __future__ import print_function

import argparse
import gzip
import json
import logging
import os
import traceback

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten

logging.basicConfig(level=logging.DEBUG)

# Define the model object


class SmallConv(Model):
    def __init__(self):
        super(SmallConv, self).__init__()
        self.conv1 = Conv2D(32, 3, activation="relu")
        self.flatten = Flatten()
        self.d1 = Dense(128, activation="relu")
        self.d2 = Dense(10)

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


# Decode and preprocess data
def convert_to_numpy(data_dir, images_file, labels_file):
    """Byte string to numpy arrays"""
    with gzip.open(os.path.join(data_dir, images_file), "rb") as f:
        images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)

    with gzip.open(os.path.join(data_dir, labels_file), "rb") as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)

    return (images, labels)


def mnist_to_numpy(data_dir, train):
    """Load raw MNIST data into numpy array

    Args:
        data_dir (str): directory of MNIST raw data.
            This argument can be accessed via SM_CHANNEL_TRAINING

        train (bool): use training data

    Returns:
        tuple of images and labels as numpy array
    """

    if train:
        images_file = "train-images-idx3-ubyte.gz"
        labels_file = "train-labels-idx1-ubyte.gz"
    else:
        images_file = "t10k-images-idx3-ubyte.gz"
        labels_file = "t10k-labels-idx1-ubyte.gz"

    return convert_to_numpy(data_dir, images_file, labels_file)


def normalize(x, axis):
    eps = np.finfo(float).eps

    mean = np.mean(x, axis=axis, keepdims=True)
    # avoid division by zero
    std = np.std(x, axis=axis, keepdims=True) + eps
    return (x - mean) / std


# Training logic


def train(args):
    # create data loader from the train / test channels
    x_train, y_train = mnist_to_numpy(data_dir=args.train, train=True)
    x_test, y_test = mnist_to_numpy(data_dir=args.test, train=False)

    x_train, x_test = x_train.astype(np.float32), x_test.astype(np.float32)

    # normalize the inputs to mean 0 and std 1
    x_train, x_test = normalize(x_train, (1, 2)), normalize(x_test, (1, 2))

    # expand channel axis
    # tf uses depth minor convention
    x_train, x_test = np.expand_dims(x_train, axis=3), np.expand_dims(x_test, axis=3)

    # build tf.data input pipelines: shuffled training batches, ordered test batches
    train_loader = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(len(x_train))
        .batch(args.batch_size)
    )

    test_loader = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(args.batch_size)

    model = SmallConv()
    model.compile()
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=args.learning_rate, beta_1=args.beta_1, beta_2=args.beta_2
    )

    train_loss = tf.keras.metrics.Mean(name="train_loss")
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="train_accuracy")

    test_loss = tf.keras.metrics.Mean(name="test_loss")
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = loss_fn(labels, predictions)
        grad = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grad, model.trainable_variables))

        train_loss(loss)
        train_accuracy(labels, predictions)
        return

    @tf.function
    def test_step(images, labels):
        predictions = model(images, training=False)
        t_loss = loss_fn(labels, predictions)
        test_loss(t_loss)
        test_accuracy(labels, predictions)
        return

    print("Training starts ...")
    for epoch in range(args.epochs):
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

        for batch, (images, labels) in enumerate(train_loader):
            train_step(images, labels)

        for images, labels in test_loader:
            test_step(images, labels)

        print(
            f"Epoch {epoch + 1}, "
            f"Loss: {train_loss.result()}, "
            f"Accuracy: {train_accuracy.result() * 100}, "
            f"Test Loss: {test_loss.result()}, "
            f"Test Accuracy: {test_accuracy.result() * 100}"
        )

    # Save the model
    # A version number is needed for the serving container
    # to load the model
    version = "00000000"
    ckpt_dir = os.path.join(args.model_dir, version)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    model.save(ckpt_dir)
    return


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--beta_1", type=float, default=0.9)
    parser.add_argument("--beta_2", type=float, default=0.999)

    # Environment variables given by the training image
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--train", type=str, default=os.environ["SM_CHANNEL_TRAINING"])
    parser.add_argument("--test", type=str, default=os.environ["SM_CHANNEL_TESTING"])

    parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
    parser.add_argument("--hosts", type=json.loads, default=json.loads(os.environ["SM_HOSTS"]))

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    train(args)
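Two argparse details matter when adapting parse_args. First, argparse converts a dashed option such as --batch-size into the attribute args.batch_size, which is how train() reads it. Second, argparse applies the type callable to values given on the command line (and to string defaults, but not to non-string defaults), so type=list would split a host string such as '["algo-1"]' into single characters. The sketch below is a trimmed-down, illustrative parser that parses --hosts with json.loads, which returns a proper list. The function name build_parser is hypothetical.

```python
import argparse
import json


def build_parser():
    # Trimmed-down version of the argument definitions in code/train.py (illustrative only).
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--beta_1", type=float, default=0.9)
    # json.loads turns a JSON string such as '["algo-1"]' into a Python list.
    parser.add_argument("--hosts", type=json.loads, default=["algo-1"])
    return parser


args = build_parser().parse_args(
    ["--batch-size", "512", "--learning-rate", "0.001", "--hosts", '["algo-1", "algo-2"]']
)
# Dashes in option names become underscores in attribute names.
print(args.batch_size, args.learning_rate, args.beta_1, args.hosts)
# 512 0.001 0.9 ['algo-1', 'algo-2']

# For comparison: what type=list would produce for the same kind of string.
print(list('["algo-1"]')[:3])  # ['[', '"', 'a']
```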

Set hyperparameters

In addition, the TensorFlow estimator lets you pass command-line arguments to your training script through its hyperparameters parameter.

Note: local mode is not supported in SageMaker Studio.

[4]:
# Set local_mode to be True if you want to run the training script on the machine that runs this notebook

local_mode = False

if local_mode:
    instance_type = "local"
else:
    instance_type = "ml.c4.xlarge"

est = TensorFlow(
    entry_point="train.py",
    source_dir="code",  # directory of your training script
    role=role,
    framework_version="2.3.1",
    model_dir=False,  # don't pass --model_dir to your training script
    py_version="py37",
    instance_type=instance_type,
    instance_count=1,
    volume_size=250,
    output_path=output_path,
    hyperparameters={
        "batch-size": 512,
        "epochs": 1,
        "learning-rate": 1e-3,
        "beta_1": 0.9,
        "beta_2": 0.999,
    },
)

With the hyperparameters above, the training container runs your training script with a command like:

python train.py --batch-size 512 --beta_1 0.9 --beta_2 0.999 --epochs 1 --learning-rate 0.001
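Conceptually, each entry in hyperparameters becomes a --key value pair on the command line. The sketch below reproduces the argument list that this job reported in its log (SM_USER_ARGS), where keys appear in sorted order. The helper to_cli_args is hypothetical and only illustrates the mapping. It is not the SageMaker training toolkit's own implementation.

```python
def to_cli_args(hyperparameters):
    """Flatten a hyperparameter dict into ["--key", "value", ...], sorted by key."""
    args = []
    for key in sorted(hyperparameters):
        args.extend([f"--{key}", str(hyperparameters[key])])
    return args


hps = {"batch-size": 512, "epochs": 1, "learning-rate": 1e-3, "beta_1": 0.9, "beta_2": 0.999}
print(" ".join(["python", "train.py"] + to_cli_args(hps)))
# python train.py --batch-size 512 --beta_1 0.9 --beta_2 0.999 --epochs 1 --learning-rate 0.001
```

Because every value reaches the script as a string, the type= declarations in the script's argparse setup are what turn them back into int and float values.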

Set up channels for training and testing data

Tell the TensorFlow estimator where to find the training and testing data. This can be an S3 URI, or a path on your local file system if you use local mode. In this example, we download the MNIST data from a public S3 bucket and upload it to your default bucket.

[5]:
import boto3

# Download training and testing data from a public S3 bucket


def download_from_s3(data_dir="./data", train=True):
    """Download the MNIST image and label files for one split from a public S3 bucket

    Args:
        data_dir (str): directory to save the data
        train (bool): download the training set if True, otherwise the test set

    Returns:
        None
    """

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    if train:
        images_file = "train-images-idx3-ubyte.gz"
        labels_file = "train-labels-idx1-ubyte.gz"
    else:
        images_file = "t10k-images-idx3-ubyte.gz"
        labels_file = "t10k-labels-idx1-ubyte.gz"

    # download objects
    s3 = boto3.client("s3")
    bucket = "sagemaker-sample-files"
    for obj in [images_file, labels_file]:
        key = os.path.join("datasets/image/MNIST", obj)
        dest = os.path.join(data_dir, obj)
        if not os.path.exists(dest):
            s3.download_file(bucket, key, dest)
    return


download_from_s3("./data", True)
download_from_s3("./data", False)
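A small design note on download_from_s3: S3 object keys always use forward slashes, while os.path.join uses the operating system's separator (a backslash on Windows). A minimal sketch of building the same keys with posixpath, which always joins with "/", is shown below. The helper mnist_keys is hypothetical and only covers key construction, not the download itself.

```python
import posixpath

PREFIX = "datasets/image/MNIST"  # key prefix used with the sagemaker-sample-files bucket above


def mnist_keys(train=True):
    """Return the S3 keys of the MNIST image and label files for one split."""
    split = "train" if train else "t10k"
    files = [f"{split}-images-idx3-ubyte.gz", f"{split}-labels-idx1-ubyte.gz"]
    return [posixpath.join(PREFIX, name) for name in files]


print(mnist_keys(train=False))
# ['datasets/image/MNIST/t10k-images-idx3-ubyte.gz', 'datasets/image/MNIST/t10k-labels-idx1-ubyte.gz']
```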
[6]:
# Upload to the default bucket

prefix = "DEMO-mnist"
bucket = sess.default_bucket()
loc = sess.upload_data(path="./data", bucket=bucket, key_prefix=prefix)

channels = {"training": loc, "testing": loc}

The keys of the channels dictionary are passed to the training image, which creates an environment variable SM_CHANNEL_<KEY NAME> for each one, with the key name in uppercase.

In this example, SM_CHANNEL_TRAINING and SM_CHANNEL_TESTING are created in the training image (see how code/train.py accesses these variables). For more information, see SM_CHANNEL_{channel_name}.

If you want, you can also create a channel for validation:

channels = {
    'training': train_data_loc,
    'validation': val_data_loc,
    'testing': test_data_loc
}

You can then access this channel within your training script via SM_CHANNEL_VALIDATION.
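The naming rule can be illustrated with a short sketch. In File mode, each channel is downloaded to /opt/ml/input/data/<channel name> (see channel_input_dirs in the training log), and the variable name is SM_CHANNEL_ followed by the upper-cased channel name. The helper channel_env_vars is hypothetical. It only mirrors the mapping visible in this job's log and does not contact SageMaker. The S3 URIs are placeholders.

```python
def channel_env_vars(channels):
    """Map channel names to the SM_CHANNEL_* variables a File-mode job would see."""
    return {
        f"SM_CHANNEL_{name.upper()}": f"/opt/ml/input/data/{name}"
        for name in channels
    }


channels = {
    "training": "s3://bucket/train",  # placeholder S3 URIs
    "validation": "s3://bucket/val",
    "testing": "s3://bucket/test",
}
for var, path in channel_env_vars(channels).items():
    print(f"{var}={path}")
# SM_CHANNEL_TRAINING=/opt/ml/input/data/training
# SM_CHANNEL_VALIDATION=/opt/ml/input/data/validation
# SM_CHANNEL_TESTING=/opt/ml/input/data/testing
```

The channel key must therefore match what the script reads. A key named test would produce SM_CHANNEL_TEST, while code/train.py reads SM_CHANNEL_TESTING.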

Run the training script on SageMaker

The estimator now has everything it needs to run your training script. Start the training job by calling the fit() method.

[7]:
est.fit(inputs=channels)
2022-04-18 00:20:12 Starting - Starting the training job...
2022-04-18 00:20:38 Starting - Preparing the instances for trainingProfilerReport-1650241212: InProgress
......
2022-04-18 00:21:54 Downloading - Downloading input data...
2022-04-18 00:22:21 Training - Downloading the training image...
2022-04-18 00:23:07 Training - Training image download completed. Training in progress...2022-04-18 00:23:10.600096: W tensorflow/core/profiler/internal/smprofiler_timeline.cc:460] Initializing the SageMaker Profiler.
2022-04-18 00:23:10.611582: W tensorflow/core/profiler/internal/smprofiler_timeline.cc:105] SageMaker Profiler is not enabled. The timeline writer thread will not be started, future recorded events will be dropped.
2022-04-18 00:23:11.007108: W tensorflow/core/profiler/internal/smprofiler_timeline.cc:460] Initializing the SageMaker Profiler.
2022-04-18 00:23:15,397 sagemaker-training-toolkit INFO     Imported framework sagemaker_tensorflow_container.training
2022-04-18 00:23:15,405 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)
2022-04-18 00:23:15,857 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)
2022-04-18 00:23:15,878 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)
2022-04-18 00:23:15,895 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)
2022-04-18 00:23:15,908 sagemaker-training-toolkit INFO     Invoking user script
Training Env:
{
    "additional_framework_parameters": {},
    "channel_input_dirs": {
        "testing": "/opt/ml/input/data/testing",
        "training": "/opt/ml/input/data/training"
    },
    "current_host": "algo-1",
    "framework_module": "sagemaker_tensorflow_container.training:main",
    "hosts": [
        "algo-1"
    ],
    "hyperparameters": {
        "batch-size": 512,
        "beta_1": 0.9,
        "beta_2": 0.999,
        "epochs": 1,
        "learning-rate": 0.001
    },
    "input_config_dir": "/opt/ml/input/config",
    "input_data_config": {
        "testing": {
            "TrainingInputMode": "File",
            "S3DistributionType": "FullyReplicated",
            "RecordWrapperType": "None"
        },
        "training": {
            "TrainingInputMode": "File",
            "S3DistributionType": "FullyReplicated",
            "RecordWrapperType": "None"
        }
    },
    "input_dir": "/opt/ml/input",
    "is_master": true,
    "job_name": "tensorflow-training-2022-04-18-00-20-12-056",
    "log_level": 20,
    "master_hostname": "algo-1",
    "model_dir": "/opt/ml/model",
    "module_dir": "s3://sagemaker-us-west-2-000000000000/tensorflow-training-2022-04-18-00-20-12-056/source/sourcedir.tar.gz",
    "module_name": "train",
    "network_interface_name": "eth0",
    "num_cpus": 4,
    "num_gpus": 0,
    "output_data_dir": "/opt/ml/output/data",
    "output_dir": "/opt/ml/output",
    "output_intermediate_dir": "/opt/ml/output/intermediate",
    "resource_config": {
        "current_host": "algo-1",
        "current_instance_type": "ml.c4.xlarge",
        "current_group_name": "homogeneousCluster",
        "hosts": [
            "algo-1"
        ],
        "instance_groups": [
            {
                "instance_group_name": "homogeneousCluster",
                "instance_type": "ml.c4.xlarge",
                "hosts": [
                    "algo-1"
                ]
            }
        ],
        "network_interface_name": "eth0"
    },
    "user_entry_point": "train.py"
}
Environment variables:
SM_HOSTS=["algo-1"]
SM_NETWORK_INTERFACE_NAME=eth0
SM_HPS={"batch-size":512,"beta_1":0.9,"beta_2":0.999,"epochs":1,"learning-rate":0.001}
SM_USER_ENTRY_POINT=train.py
SM_FRAMEWORK_PARAMS={}
SM_RESOURCE_CONFIG={"current_group_name":"homogeneousCluster","current_host":"algo-1","current_instance_type":"ml.c4.xlarge","hosts":["algo-1"],"instance_groups":[{"hosts":["algo-1"],"instance_group_name":"homogeneousCluster","instance_type":"ml.c4.xlarge"}],"network_interface_name":"eth0"}
SM_INPUT_DATA_CONFIG={"testing":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"},"training":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}}
SM_OUTPUT_DATA_DIR=/opt/ml/output/data
SM_CHANNELS=["testing","training"]
SM_CURRENT_HOST=algo-1
SM_MODULE_NAME=train
SM_LOG_LEVEL=20
SM_FRAMEWORK_MODULE=sagemaker_tensorflow_container.training:main
SM_INPUT_DIR=/opt/ml/input
SM_INPUT_CONFIG_DIR=/opt/ml/input/config
SM_OUTPUT_DIR=/opt/ml/output
SM_NUM_CPUS=4
SM_NUM_GPUS=0
SM_MODEL_DIR=/opt/ml/model
SM_MODULE_DIR=s3://sagemaker-us-west-2-000000000000/tensorflow-training-2022-04-18-00-20-12-056/source/sourcedir.tar.gz
SM_TRAINING_ENV={"additional_framework_parameters":{},"channel_input_dirs":{"testing":"/opt/ml/input/data/testing","training":"/opt/ml/input/data/training"},"current_host":"algo-1","framework_module":"sagemaker_tensorflow_container.training:main","hosts":["algo-1"],"hyperparameters":{"batch-size":512,"beta_1":0.9,"beta_2":0.999,"epochs":1,"learning-rate":0.001},"input_config_dir":"/opt/ml/input/config","input_data_config":{"testing":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"},"training":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}},"input_dir":"/opt/ml/input","is_master":true,"job_name":"tensorflow-training-2022-04-18-00-20-12-056","log_level":20,"master_hostname":"algo-1","model_dir":"/opt/ml/model","module_dir":"s3://sagemaker-us-west-2-000000000000/tensorflow-training-2022-04-18-00-20-12-056/source/sourcedir.tar.gz","module_name":"train","network_interface_name":"eth0","num_cpus":4,"num_gpus":0,"output_data_dir":"/opt/ml/output/data","output_dir":"/opt/ml/output","output_intermediate_dir":"/opt/ml/output/intermediate","resource_config":{"current_group_name":"homogeneousCluster","current_host":"algo-1","current_instance_type":"ml.c4.xlarge","hosts":["algo-1"],"instance_groups":[{"hosts":["algo-1"],"instance_group_name":"homogeneousCluster","instance_type":"ml.c4.xlarge"}],"network_interface_name":"eth0"},"user_entry_point":"train.py"}
SM_USER_ARGS=["--batch-size","512","--beta_1","0.9","--beta_2","0.999","--epochs","1","--learning-rate","0.001"]
SM_OUTPUT_INTERMEDIATE_DIR=/opt/ml/output/intermediate
SM_CHANNEL_TESTING=/opt/ml/input/data/testing
SM_CHANNEL_TRAINING=/opt/ml/input/data/training
SM_HP_BATCH-SIZE=512
SM_HP_BETA_1=0.9
SM_HP_BETA_2=0.999
SM_HP_EPOCHS=1
SM_HP_LEARNING-RATE=0.001
PYTHONPATH=/opt/ml/code:/usr/local/bin:/usr/local/lib/python37.zip:/usr/local/lib/python3.7:/usr/local/lib/python3.7/lib-dynload:/usr/local/lib/python3.7/site-packages
Invoking script with the following command:
/usr/local/bin/python3.7 train.py --batch-size 512 --beta_1 0.9 --beta_2 0.999 --epochs 1 --learning-rate 0.001
Training starts ...
[2022-04-18 00:23:19.979 ip-10-0-214-241.us-west-2.compute.internal:26 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-04-18 00:23:20.324 ip-10-0-214-241.us-west-2.compute.internal:26 INFO profiler_config_parser.py:102] User has disabled profiler.
[2022-04-18 00:23:20.325 ip-10-0-214-241.us-west-2.compute.internal:26 INFO json_config.py:91] Creating hook from json_config at /opt/ml/input/config/debughookconfig.json.
[2022-04-18 00:23:20.326 ip-10-0-214-241.us-west-2.compute.internal:26 INFO hook.py:199] tensorboard_dir has not been set for the hook. SMDebug will not be exporting tensorboard summaries.
[2022-04-18 00:23:20.327 ip-10-0-214-241.us-west-2.compute.internal:26 INFO hook.py:253] Saving to /opt/ml/output/tensors
[2022-04-18 00:23:20.327 ip-10-0-214-241.us-west-2.compute.internal:26 INFO state_store.py:75] The checkpoint config file /opt/ml/input/config/checkpointconfig.json does not exist.
[2022-04-18 00:23:20.327 ip-10-0-214-241.us-west-2.compute.internal:26 INFO hook.py:413] Monitoring the collections: sm_metrics, metrics, losses
Epoch 1, Loss: 0.2639467418193817, Accuracy: 91.82500457763672, Test Loss: 0.10326800495386124, Test Accuracy: 97.0
2022-04-18 00:23:16.325876: W tensorflow/core/profiler/internal/smprofiler_timeline.cc:460] Initializing the SageMaker Profiler.
2022-04-18 00:23:16.326045: W tensorflow/core/profiler/internal/smprofiler_timeline.cc:105] SageMaker Profiler is not enabled. The timeline writer thread will not be started, future recorded events will be dropped.
2022-04-18 00:23:16.359480: W tensorflow/core/profiler/internal/smprofiler_timeline.cc:460] Initializing the SageMaker Profiler.
WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
2022-04-18 00:23:40.983685: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /opt/ml/model/00000000/assets
INFO:tensorflow:Assets written to: /opt/ml/model/00000000/assets
2022-04-18 00:23:41,766 sagemaker-training-toolkit INFO     Reporting training SUCCESS

2022-04-18 00:23:58 Uploading - Uploading generated training model
2022-04-18 00:23:58 Completed - Training job completed
Training seconds: 133
Billable seconds: 133

Inspect and store model data

Now that training has finished, the model artifact is saved under output_path.

[8]:
tf_mnist_model_data = est.model_data
print("Model artifact saved at:\n", tf_mnist_model_data)
Model artifact saved at:
 s3://sagemaker-us-west-2-000000000000/DEMO-tensorflow/mnist/tensorflow-training-2022-04-18-00-20-12-056/output/model.tar.gz
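To look inside the artifact, you can download model.tar.gz (for example with the boto3 S3 client's download_file, which takes a bucket name and a key) and list its members. The sketch below splits the S3 URI into bucket and key and lists the members of a gzip-compressed tarball with tarfile. To stay self-contained, it builds a small stand-in archive in memory instead of downloading the real one. The member name is an assumption modeled on the 00000000/ directory that train.py writes and on the saved_model.pb file that a TensorFlow SavedModel directory contains. The helper names split_s3_uri and list_members are hypothetical.

```python
import io
import tarfile
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/key' into (bucket, key)."""
    parsed = urlparse(uri)
    return parsed.netloc, parsed.path.lstrip("/")


def list_members(fileobj):
    """Return the member names of a gzip-compressed tarball."""
    with tarfile.open(fileobj=fileobj, mode="r:gz") as tar:
        return tar.getnames()


bucket, key = split_s3_uri(
    "s3://sagemaker-us-west-2-000000000000/DEMO-tensorflow/mnist/"
    "tensorflow-training-2022-04-18-00-20-12-056/output/model.tar.gz"
)
print(bucket)  # sagemaker-us-west-2-000000000000
print(key)     # DEMO-tensorflow/mnist/tensorflow-training-2022-04-18-00-20-12-056/output/model.tar.gz

# In-memory stand-in for the downloaded archive (file contents are fake).
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
    data = b"placeholder"
    info = tarfile.TarInfo("00000000/saved_model.pb")
    info.size = len(data)
    tar.addfile(info, io.BytesIO(data))
buf.seek(0)
print(list_members(buf))  # ['00000000/saved_model.pb']
```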

We use the %store magic to persist the variable tf_mnist_model_data, so that it can be retrieved later, for example from another notebook, with %store -r.

[9]:
%store tf_mnist_model_data
Stored 'tf_mnist_model_data' (str)

Test and debug the entry point before running the training container

The entry point code/train.py provided here has been tested and can be run in the training container. When you develop your own training script, it is good practice to simulate the container environment in your local shell and test the script before sending it to SageMaker, because debugging in a containerized environment is cumbersome. The following script shows how you can test your training script:

[10]:
!pygmentize code/test_train.py
import json
import os
import sys

import boto3
from train import parse_args, train

dirname = os.path.dirname(os.path.abspath(__file__))

with open(os.path.join(dirname, "config.json"), "r") as f:
    CONFIG = json.load(f)


def download_from_s3(data_dir="/tmp/data", train=True):
    """Download the MNIST image and label files for one split from a public S3 bucket
    Args:
        data_dir (str): directory to save the data
        train (bool): download the training set if True, otherwise the test set
    Returns:
        None
    """

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    if train:
        images_file = "train-images-idx3-ubyte.gz"
        labels_file = "train-labels-idx1-ubyte.gz"
    else:
        images_file = "t10k-images-idx3-ubyte.gz"
        labels_file = "t10k-labels-idx1-ubyte.gz"

    # download objects
    s3 = boto3.client("s3")
    bucket = CONFIG["public_bucket"]
    for obj in [images_file, labels_file]:
        key = os.path.join("datasets/image/MNIST", obj)
        dest = os.path.join(data_dir, obj)
        if not os.path.exists(dest):
            s3.download_file(bucket, key, dest)
    return


class Env:
    def __init__(self):
        # simulate container env
        os.environ["SM_MODEL_DIR"] = "/tmp/tf/model"
        os.environ["SM_CHANNEL_TRAINING"] = "/tmp/data"
        os.environ["SM_CHANNEL_TESTING"] = "/tmp/data"
        os.environ["SM_HOSTS"] = '["algo-1"]'
        os.environ["SM_CURRENT_HOST"] = "algo-1"
        os.environ["SM_NUM_GPUS"] = "0"


if __name__ == "__main__":
    Env()
    download_from_s3()
    download_from_s3(train=False)
    args = parse_args()
    train(args)
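Env() in test_train.py writes directly into os.environ and never restores the previous values. This is fine for a one-off script, but the values can leak into other tests running in the same process. A minimal alternative, assuming you run the check from Python (for example under pytest), is a context manager that sets the same SM_* variables and restores the originals afterwards. The name simulated_sm_env is hypothetical. The standard library's unittest.mock.patch.dict(os.environ, ...) can do the same job.

```python
import contextlib
import os


@contextlib.contextmanager
def simulated_sm_env(**overrides):
    """Temporarily set SM_* variables (same defaults as test_train.Env) and restore them on exit."""
    values = {
        "SM_MODEL_DIR": "/tmp/tf/model",
        "SM_CHANNEL_TRAINING": "/tmp/data",
        "SM_CHANNEL_TESTING": "/tmp/data",
        "SM_HOSTS": '["algo-1"]',
        "SM_CURRENT_HOST": "algo-1",
        "SM_NUM_GPUS": "0",
    }
    values.update(overrides)
    saved = {key: os.environ.get(key) for key in values}  # None means "was not set"
    os.environ.update(values)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)  # variable did not exist before
            else:
                os.environ[key] = old


with simulated_sm_env(SM_MODEL_DIR="/tmp/other/model"):
    print(os.environ["SM_MODEL_DIR"])  # /tmp/other/model
print(os.environ.get("SM_MODEL_DIR"))  # None, unless it was already set beforehand
```

Inside the with block, you would call parse_args() and train(args) exactly as test_train.py does.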

Conclusion

In this notebook, we trained a TensorFlow model on the MNIST dataset by fitting a SageMaker estimator. For next steps on how to deploy the trained model and perform inference, see Deploy a Trained TensorFlow V2 Model.