MNIST Training with MXNet and Gluon
MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). This tutorial shows how to train and test an MNIST model on SageMaker using MXNet and the Gluon API.
Runtime
This notebook takes approximately 20 minutes to run.
[2]:
import os
import boto3
import sagemaker
from sagemaker.mxnet import MXNet
from mxnet import gluon
from sagemaker import get_execution_role
sagemaker_session = sagemaker.Session()
role = get_execution_role()
Download training and test data
[3]:
import os
for inner_dir in ["train", "test"]:
    data_dir = "./data/{}/".format(inner_dir)
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

s3 = boto3.client("s3")
s3.download_file(
    "sagemaker-sample-files",
    "datasets/image/MNIST/train/train-images-idx3-ubyte.gz",
    "./data/train/train-images-idx3-ubyte.gz",
)
s3.download_file(
    "sagemaker-sample-files",
    "datasets/image/MNIST/train/train-labels-idx1-ubyte.gz",
    "./data/train/train-labels-idx1-ubyte.gz",
)
s3.download_file(
    "sagemaker-sample-files",
    "datasets/image/MNIST/test/t10k-images-idx3-ubyte.gz",
    "./data/test/t10k-images-idx3-ubyte.gz",
)
s3.download_file(
    "sagemaker-sample-files",
    "datasets/image/MNIST/test/t10k-labels-idx1-ubyte.gz",
    "./data/test/t10k-labels-idx1-ubyte.gz",
)
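As an optional sanity check (not part of the original notebook), you can list the downloaded archives and their sizes before moving on:

# Optional check: list the downloaded MNIST archives and their sizes.
import os

for root, _, files in os.walk("./data"):
    for name in files:
        path = os.path.join(root, name)
        print(path, os.path.getsize(path), "bytes")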
Upload the data
We use the sagemaker.Session.upload_data function to upload our datasets to an S3 location. The return value inputs identifies the location; we use it later when we start the training job.
[4]:
inputs = sagemaker_session.upload_data(path="data", key_prefix="data/DEMO-mnist")
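The return value is a plain S3 URI string; printing it shows where the data was staged. The bucket name is whatever default bucket your SageMaker session uses, so the exact value will differ:

# inputs is an S3 URI string, e.g. "s3://<your-default-bucket>/data/DEMO-mnist"
print(inputs)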
Implement the training function
We need to provide a training script that can run on the SageMaker platform. The training script is essentially the same as one you would write for local training, except that you need to provide a train() function. The train() function checks the validation accuracy at the end of every epoch and checkpoints the best model so far, along with the optimizer state, in the folder /opt/ml/checkpoints if that folder exists; otherwise it skips checkpointing. When SageMaker calls your function, it passes in arguments that describe the training environment. Check the script below to see how this works.
The script here is an adaptation of the Gluon MNIST example provided by the Apache MXNet project.
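SageMaker describes the training environment to the script through SM_* environment variables (you can see them echoed in the training log later in this notebook), and the script's parse_args() function turns them into argument defaults. A minimal sketch of that pattern, using only variables the script below actually reads:

# Minimal sketch: how a SageMaker training script reads its environment.
# These SM_* variables are set by SageMaker inside the training container.
import json
import os

model_dir = os.environ["SM_MODEL_DIR"]            # where the final model artifacts go
training_dir = os.environ["SM_CHANNEL_TRAINING"]  # where the "training" channel data lives
hosts = json.loads(os.environ["SM_HOSTS"])        # all hosts participating in the job
current_host = os.environ["SM_CURRENT_HOST"]      # this container's host name
num_gpus = int(os.environ["SM_NUM_GPUS"])         # number of GPUs on this instance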
[5]:
!cat 'mnist.py'
from __future__ import print_function

import argparse
import json
import logging
import os
import time

import mxnet as mx
import numpy as np
from mxnet import autograd, gluon
from mxnet.gluon import nn

logging.basicConfig(level=logging.DEBUG)


# ------------------------------------------------------------ #
# Training methods                                              #
# ------------------------------------------------------------ #


def train(args):
    # SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to
    # the current container environment, but here we just use simple cpu context.
    ctx = mx.cpu()

    # retrieve the hyperparameters we set in notebook (with some defaults)
    batch_size = args.batch_size
    epochs = args.epochs
    learning_rate = args.learning_rate
    momentum = args.momentum
    log_interval = args.log_interval

    num_gpus = int(os.environ["SM_NUM_GPUS"])
    current_host = args.current_host
    hosts = args.hosts
    model_dir = args.model_dir
    CHECKPOINTS_DIR = "/opt/ml/checkpoints"
    checkpoints_enabled = os.path.exists(CHECKPOINTS_DIR)

    # load training and validation data
    # we use the gluon.data.vision.MNIST class because of its built-in MNIST pre-processing logic,
    # but point it at the location where SageMaker placed the data files, so it doesn't download them again.
    training_dir = args.train
    train_data = get_train_data(training_dir + "/train", batch_size)
    val_data = get_val_data(training_dir + "/test", batch_size)

    # define the network
    net = define_network()

    # Collect all parameters from net and its children, then initialize them.
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

    # Trainer is for updating parameters with gradient.
    if len(hosts) == 1:
        kvstore = "device" if num_gpus > 0 else "local"
    else:
        kvstore = "dist_device_sync" if num_gpus > 0 else "dist_sync"

    trainer = gluon.Trainer(
        net.collect_params(),
        "sgd",
        {"learning_rate": learning_rate, "momentum": momentum},
        kvstore=kvstore,
    )
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    # shard the training data in case we are doing distributed training. Alternatively to splitting in memory,
    # the data could be pre-split in S3 and use ShardedByS3Key to do distributed training.
    if len(hosts) > 1:
        train_data = [x for x in train_data]
        shard_size = len(train_data) // len(hosts)
        for i, host in enumerate(hosts):
            if host == current_host:
                start = shard_size * i
                end = start + shard_size
                break

        train_data = train_data[start:end]

    net.hybridize()

    best_val_score = 0.0
    for epoch in range(epochs):
        # reset data iterator and metric at beginning of epoch.
        metric.reset()
        btic = time.time()
        for i, (data, label) in enumerate(train_data):
            # Copy data to ctx if necessary
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            # Start recording computation graph with record() section.
            # Recorded graphs can then be differentiated with backward.
            with autograd.record():
                output = net(data)
                L = loss(output, label)
                L.backward()
            # take a gradient step with batch_size equal to data.shape[0]
            trainer.step(data.shape[0])
            # update metric at last.
            metric.update([label], [output])

            if i % log_interval == 0 and i > 0:
                name, acc = metric.get()
                print(
                    "[Epoch %d Batch %d] Training: %s=%f, %f samples/s"
                    % (epoch, i, name, acc, batch_size / (time.time() - btic))
                )
                btic = time.time()

        name, acc = metric.get()
        print("[Epoch %d] Training: %s=%f" % (epoch, name, acc))

        name, val_acc = test(ctx, net, val_data)
        print("[Epoch %d] Validation: %s=%f" % (epoch, name, val_acc))

        # checkpoint the model, params and optimizer states in the folder /opt/ml/checkpoints
        if checkpoints_enabled and val_acc > best_val_score:
            best_val_score = val_acc
            logging.info("Saving the model, params and optimizer state.")
            net.export(CHECKPOINTS_DIR + "/%.4f-gluon_mnist" % (best_val_score), epoch)
            trainer.save_states(
                CHECKPOINTS_DIR + "/%.4f-gluon_mnist-%d.states" % (best_val_score, epoch)
            )

    if current_host == hosts[0]:
        save(net, model_dir)


def save(net, model_dir):
    # save the model
    net.export("%s/model" % model_dir)


def define_network():
    net = nn.HybridSequential()
    with net.name_scope():
        net.add(nn.Dense(128, activation="relu"))
        net.add(nn.Dense(64, activation="relu"))
        net.add(nn.Dense(10))
    return net


def input_transformer(data, label):
    data = data.reshape((-1,)).astype(np.float32) / 255.0
    return data, label


def get_train_data(data_dir, batch_size):
    return gluon.data.DataLoader(
        gluon.data.vision.MNIST(data_dir, train=True, transform=input_transformer),
        batch_size=batch_size,
        shuffle=True,
        last_batch="rollover",
    )


def get_val_data(data_dir, batch_size):
    return gluon.data.DataLoader(
        gluon.data.vision.MNIST(data_dir, train=False, transform=input_transformer),
        batch_size=batch_size,
        shuffle=False,
    )


def test(ctx, net, val_data):
    metric = mx.metric.Accuracy()
    for data, label in val_data:
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        output = net(data)
        metric.update([label], [output])
    return metric.get()


# ------------------------------------------------------------ #
# Hosting methods                                               #
# ------------------------------------------------------------ #


def model_fn(model_dir):
    """
    Load the gluon model. Called once when hosting service starts.
    :param: model_dir The directory where model files are stored.
    :return: a model (in this case a Gluon network)
    """
    net = gluon.SymbolBlock.imports(
        "%s/model-symbol.json" % model_dir,
        ["data"],
        "%s/model-0000.params" % model_dir,
    )
    return net


def transform_fn(net, data, input_content_type, output_content_type):
    """
    Transform a request using the Gluon model. Called once per request.
    :param net: The Gluon model.
    :param data: The request payload.
    :param input_content_type: The request content type.
    :param output_content_type: The (desired) response content type.
    :return: response payload and content type.
    """
    # we can use content types to vary input/output handling, but
    # here we just assume json for both
    parsed = json.loads(data)
    nda = mx.nd.array(parsed)
    output = net(nda)
    prediction = mx.nd.argmax(output, axis=1)
    response_body = json.dumps(prediction.asnumpy().tolist()[0])
    return response_body, output_content_type


# ------------------------------------------------------------ #
# Training execution                                            #
# ------------------------------------------------------------ #


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--batch-size", type=int, default=100)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learning-rate", type=float, default=0.1)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--log-interval", type=float, default=100)

    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--train", type=str, default=os.environ["SM_CHANNEL_TRAINING"])
    parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
    parser.add_argument("--hosts", type=list, default=json.loads(os.environ["SM_HOSTS"]))

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    train(args)
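Because every argument has a default, you can also smoke-test the script outside SageMaker by supplying the one environment variable train() reads directly and a local model folder. The sketch below is illustrative (not part of the original notebook); it assumes the ./data directory downloaded earlier and runs a single epoch on CPU:

# Hypothetical local smoke test of mnist.py (assumes ./data/train and ./data/test exist).
import os
from argparse import Namespace

os.makedirs("./model", exist_ok=True)
os.environ["SM_NUM_GPUS"] = "0"  # train() reads this directly

import mnist

args = Namespace(
    batch_size=100,
    epochs=1,
    learning_rate=0.1,
    momentum=0.9,
    log_interval=100,
    model_dir="./model",
    train="./data",
    current_host="algo-1",
    hosts=["algo-1"],
)
mnist.train(args)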
Run the training script on SageMaker
The MXNet class allows us to run our training function on SageMaker infrastructure. We need to configure it with our training script, an IAM role, the number of training instances, and the training instance type. In this case, we run our training job on a single ml.c4.xlarge instance.
[6]:
m = MXNet(
    "mnist.py",
    role=role,
    instance_count=1,
    instance_type="ml.c4.xlarge",
    framework_version="1.6.0",
    py_version="py3",
    hyperparameters={
        "batch-size": 100,
        "epochs": 20,
        "learning-rate": 0.1,
        "momentum": 0.9,
        "log-interval": 100,
    },
)
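Note that the training script only writes checkpoints when the folder /opt/ml/checkpoints exists, and SageMaker creates that folder when checkpointing is configured on the estimator. If you want the best checkpoints persisted to S3, a variant of the estimator above could pass checkpoint_s3_uri; this is a sketch, and the S3 prefix below is a placeholder derived from the session's default bucket:

# Optional variant (sketch): enable checkpointing so /opt/ml/checkpoints exists in the
# container and its contents are synced to the given S3 prefix.
m_with_checkpoints = MXNet(
    "mnist.py",
    role=role,
    instance_count=1,
    instance_type="ml.c4.xlarge",
    framework_version="1.6.0",
    py_version="py3",
    hyperparameters={"batch-size": 100, "epochs": 20},
    checkpoint_s3_uri="s3://{}/DEMO-mnist-checkpoints".format(sagemaker_session.default_bucket()),
)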
After we’ve constructed our MXNet object, we fit it using the data we uploaded to S3. SageMaker makes sure our data is available in the local filesystem, so our training script can simply read the data from disk.
[7]:
m.fit(inputs)
2022-04-18 00:07:22 Starting - Starting the training job...
2022-04-18 00:07:45 Starting - Preparing the instances for trainingProfilerReport-1650240442: InProgress
......
2022-04-18 00:08:50 Downloading - Downloading input data......
2022-04-18 00:09:46 Training - Training image download completed. Training in progress..2022-04-18 00:09:48,174 sagemaker-training-toolkit INFO Imported framework sagemaker_mxnet_container.training
2022-04-18 00:09:48,176 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
2022-04-18 00:09:48,189 sagemaker_mxnet_container.training INFO MXNet training environment: {'SM_HOSTS': '["algo-1"]', 'SM_NETWORK_INTERFACE_NAME': 'eth0', 'SM_HPS': '{"batch-size":100,"epochs":20,"learning-rate":0.1,"log-interval":100,"momentum":0.9}', 'SM_USER_ENTRY_POINT': 'mnist.py', 'SM_FRAMEWORK_PARAMS': '{}', 'SM_RESOURCE_CONFIG': '{"current_group_name":"homogeneousCluster","current_host":"algo-1","current_instance_type":"ml.c4.xlarge","hosts":["algo-1"],"instance_groups":[{"hosts":["algo-1"],"instance_group_name":"homogeneousCluster","instance_type":"ml.c4.xlarge"}],"network_interface_name":"eth0"}', 'SM_INPUT_DATA_CONFIG': '{"training":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}}', 'SM_OUTPUT_DATA_DIR': '/opt/ml/output/data', 'SM_CHANNELS': '["training"]', 'SM_CURRENT_HOST': 'algo-1', 'SM_MODULE_NAME': 'mnist', 'SM_LOG_LEVEL': '20', 'SM_FRAMEWORK_MODULE': 'sagemaker_mxnet_container.training:main', 'SM_INPUT_DIR': '/opt/ml/input', 'SM_INPUT_CONFIG_DIR': '/opt/ml/input/config', 'SM_OUTPUT_DIR': '/opt/ml/output', 'SM_NUM_CPUS': '4', 'SM_NUM_GPUS': '0', 'SM_MODEL_DIR': '/opt/ml/model', 'SM_MODULE_DIR': 's3://sagemaker-us-west-2-000000000000/mxnet-training-2022-04-18-00-07-22-443/source/sourcedir.tar.gz', 'SM_TRAINING_ENV': '{"additional_framework_parameters":{},"channel_input_dirs":{"training":"/opt/ml/input/data/training"},"current_host":"algo-1","framework_module":"sagemaker_mxnet_container.training:main","hosts":["algo-1"],"hyperparameters":{"batch-size":100,"epochs":20,"learning-rate":0.1,"log-interval":100,"momentum":0.9},"input_config_dir":"/opt/ml/input/config","input_data_config":{"training":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}},"input_dir":"/opt/ml/input","is_master":true,"job_name":"mxnet-training-2022-04-18-00-07-22-443","log_level":20,"master_hostname":"algo-1","model_dir":"/opt/ml/model","module_dir":"s3://sagemaker-us-west-2-000000000000/mxnet-training-2022-04-18-00-07-22-443/source/sourcedir.tar.gz","module_name":"mnist","network_interface_name":"eth0","num_cpus":4,"num_gpus":0,"output_data_dir":"/opt/ml/output/data","output_dir":"/opt/ml/output","output_intermediate_dir":"/opt/ml/output/intermediate","resource_config":{"current_group_name":"homogeneousCluster","current_host":"algo-1","current_instance_type":"ml.c4.xlarge","hosts":["algo-1"],"instance_groups":[{"hosts":["algo-1"],"instance_group_name":"homogeneousCluster","instance_type":"ml.c4.xlarge"}],"network_interface_name":"eth0"},"user_entry_point":"mnist.py"}', 'SM_USER_ARGS': '["--batch-size","100","--epochs","20","--learning-rate","0.1","--log-interval","100","--momentum","0.9"]', 'SM_OUTPUT_INTERMEDIATE_DIR': '/opt/ml/output/intermediate', 'SM_CHANNEL_TRAINING': '/opt/ml/input/data/training', 'SM_HP_BATCH-SIZE': '100', 'SM_HP_EPOCHS': '20', 'SM_HP_LEARNING-RATE': '0.1', 'SM_HP_LOG-INTERVAL': '100', 'SM_HP_MOMENTUM': '0.9'}
2022-04-18 00:09:48,612 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
...
DEBUG:root:Writing metric: _RawMetricData(MetricName='softmaxcrossentropyloss0_output_0_GLOBAL',Value=0.017137402668595314,Timestamp=1650240989.8969676,IterationNumber=13500)
[Epoch 19 Batch 300] Training: accuracy=0.995880, 3291.121521 samples/s
[Epoch 19 Batch 400] Training: accuracy=0.995611, 2931.439754 samples/s
[Epoch 19 Batch 500] Training: accuracy=0.995369, 3627.066993 samples/s
[Epoch 19] Training: accuracy=0.994967
[Epoch 19] Validation: accuracy=0.976800
2022-04-18 00:16:44,533 sagemaker-training-toolkit INFO Reporting training SUCCESS
2022-04-18 00:16:58 Uploading - Uploading generated training model
2022-04-18 00:16:58 Completed - Training job completed
Training seconds: 488
Billable seconds: 488
After training, we use the MXNet estimator object to build and deploy an MXNetPredictor object. This creates a SageMaker endpoint that we can use to perform inference on JSON-encoded multi-dimensional arrays.
[8]:
predictor = m.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")
------!
We can now use this predictor to classify hand-written digits. Manually drawing into the image box loads the pixel data into a ‘data’ variable in this notebook, which we can then pass to the MXNet predictor.
[9]:
from IPython.display import HTML
HTML(open("input.html").read())
Fetch the first image from the training dataset and display it.
[10]:
import gzip
import numpy as np
import matplotlib.pyplot as plt
f = gzip.open("data/train/train-images-idx3-ubyte.gz", "r")
image_size = 28
f.read(16)
buf = f.read(image_size * image_size)
data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
data = data.reshape(1, image_size, image_size, 1)
image = np.asarray(data).squeeze()
plt.imshow(image)
plt.show()
[Output: plot of the first training image, a handwritten digit 5]
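As an optional cross-check (not part of the original notebook), the matching label can be read from the IDX label file, which stores an 8-byte header followed by one unsigned byte per label:

# Optional check: the first label in the label file should match the plotted digit.
with gzip.open("data/train/train-labels-idx1-ubyte.gz", "r") as lf:
    lf.read(8)  # skip the IDX header (magic number + item count)
    first_label = np.frombuffer(lf.read(1), dtype=np.uint8)[0]
print(first_label)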
The predictor runs inference on our input data and returns the predicted digit (as a float value, so we convert to int for display).
[11]:
response = predictor.predict(data)
print(int(response))
5
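The MXNetPredictor handles JSON serialization of the NumPy array and deserialization of the response for us. If you wanted to call the endpoint without the SageMaker Python SDK, a sketch using the low-level runtime client looks like this (it assumes the data array from the previous cell):

# Sketch: invoke the same endpoint through the low-level SageMaker runtime API.
import json

import boto3

runtime = boto3.client("sagemaker-runtime")
resp = runtime.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    ContentType="application/json",
    Body=json.dumps(data.tolist()),
)
print(json.loads(resp["Body"].read()))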
Cleanup
After you have finished with this example, delete the prediction endpoint to release the instance associated with it.
[12]:
predictor.delete_endpoint()
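If you no longer need the model resource created for the endpoint either, you can remove it as well (optional):

# Optionally delete the model resource backing the endpoint.
predictor.delete_model()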