Train an MNIST model with PyTorch
This notebook’s CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.
MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). This tutorial shows how to train and test an MNIST model on SageMaker using PyTorch.
Runtime
This notebook takes approximately 5 minutes to run.
[ ]:
import os
import json
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker import get_execution_role

# Set up the SageMaker session, current region, execution role, and S3 output location
sess = sagemaker.Session()
region = sess.boto_region_name
role = get_execution_role()
output_path = "s3://" + sess.default_bucket() + "/DEMO-mnist"
PyTorch Estimator
The PyTorch class allows you to run your training script on SageMaker infrastructure in a containerized environment. In this notebook, we refer to this container as the training container.
You need to configure it with the following parameters to set up the environment:
- entry_point: A user-defined Python file used by the training container as the instructions for training. We further discuss this file in the next subsection.
- role: An IAM role to make AWS service requests.
- instance_type: The type of SageMaker instance on which to run your training script. Set it to local if you want to run the training job on the SageMaker instance you are using to run this notebook.
- instance_count: The number of instances to run your training job on. Multiple instances are needed for distributed training.
- output_path: The S3 bucket URI to save training output (model artifacts and output files).
- framework_version: The version of PyTorch to use.
- py_version: The Python version to use.
For more information, see the EstimatorBase API reference.
Implement the entry point for training
The entry point for training is a Python script that provides all the code for training a PyTorch model. It is used by the SageMaker PyTorch estimator (the PyTorch class above) as the entry point for running the training job.
Under the hood, the SageMaker PyTorch estimator creates a Docker image with the runtime environment specified by the parameters you provide when you initialize the estimator class, and it injects the training script into the Docker image as the entry point to run the container.
In the rest of the notebook, we use training image to refer to the Docker image specified by the PyTorch estimator and training container to refer to the container that runs the training image.
This means your training script is very similar to a training script you might run outside Amazon SageMaker, but it can access the useful environment variables provided by the training image. See the complete list of environment variables for a description of all the variables your training script can access.
In this example, we use the training script code/train.py as the entry point for our PyTorch estimator.
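Before inspecting the actual script in the next cell, here is a minimal sketch of the general shape such an entry point takes. This is not the contents of code/train.py: the SM_* environment variables are the real ones provided by the training image, but the placeholder model and save logic are assumptions for illustration only.

import os
import torch

def main():
    # Paths provided by the training container via environment variables
    model_dir = os.environ["SM_MODEL_DIR"]  # artifacts saved here are uploaded to output_path
    train_dir = os.environ["SM_CHANNEL_TRAINING"]  # data from the "training" channel
    test_dir = os.environ["SM_CHANNEL_TESTING"]  # data from the "testing" channel

    # ... load data from train_dir and test_dir, then build and train a model ...
    model = torch.nn.Linear(28 * 28, 10)  # placeholder model (an assumption)

    # Anything written under SM_MODEL_DIR is packaged as model.tar.gz in output_path
    torch.save(model.state_dict(), os.path.join(model_dir, "model.pth"))

if __name__ == "__main__":
    main()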
[ ]:
!pygmentize 'code/train.py'
Set hyperparameters
In addition, the PyTorch estimator allows you to pass command-line arguments to your training script via hyperparameters.
Note: local mode is not supported in SageMaker Studio.
[ ]:
# Set local_mode to True to run the training script on the machine that runs this notebook
local_mode = False

if local_mode:
    instance_type = "local"
else:
    instance_type = "ml.c4.xlarge"

est = PyTorch(
    entry_point="train.py",
    source_dir="code",  # directory of your training script
    role=role,
    framework_version="1.5.0",
    py_version="py3",
    instance_type=instance_type,
    instance_count=1,
    volume_size=250,
    output_path=output_path,
    hyperparameters={"batch-size": 128, "epochs": 1, "learning-rate": 1e-3, "log-interval": 100},
)
The training container executes your training script like this:

python train.py --batch-size 128 --epochs 1 --learning-rate 1e-3 --log-interval 100
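Inside the training script, these flags are typically consumed with argparse. The following is a hedged sketch rather than the actual parsing code in code/train.py; the flag names mirror the hyperparameters dictionary above, and the defaults are assumptions.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--learning-rate", type=float, default=1e-3)
parser.add_argument("--log-interval", type=int, default=100)
args = parser.parse_args()

# argparse converts dashes to underscores: args.batch_size, args.learning_rate, etc.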
Set up channels for the training and testing data
Tell the PyTorch estimator where to find the training and testing data. It can be a path to an S3 bucket, or a path in your local file system if you use local mode. In this example, we download the MNIST data from a public S3 bucket and upload it to your default bucket.
[ ]:
import logging
import boto3
from botocore.exceptions import ClientError


# Download training and testing data from a public S3 bucket
def download_from_s3(data_dir="./data", train=True):
    """Download the MNIST dataset files from a public S3 bucket.

    Args:
        data_dir (str): directory to save the data
        train (bool): download training set

    Returns:
        None
    """
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    if train:
        images_file = "train-images-idx3-ubyte.gz"
        labels_file = "train-labels-idx1-ubyte.gz"
    else:
        images_file = "t10k-images-idx3-ubyte.gz"
        labels_file = "t10k-labels-idx1-ubyte.gz"

    # Download the objects only if they are not already present locally
    s3 = boto3.client("s3")
    bucket = f"sagemaker-example-files-prod-{region}"
    for obj in [images_file, labels_file]:
        key = os.path.join("datasets/image/MNIST", obj)
        dest = os.path.join(data_dir, obj)
        if not os.path.exists(dest):
            s3.download_file(bucket, key, dest)


download_from_s3("./data", True)
download_from_s3("./data", False)
[ ]:
# Upload to the default bucket
prefix = "DEMO-mnist"
bucket = sess.default_bucket()
loc = sess.upload_data(path="./data", bucket=bucket, key_prefix=prefix)
channels = {"training": loc, "testing": loc}
The keys of the channels dictionary are passed to the training image, and SageMaker creates an environment variable SM_CHANNEL_<key name> for each key.
In this example, SM_CHANNEL_TRAINING and SM_CHANNEL_TESTING are created in the training image (see how code/train.py accesses these variables). For more information, see SM_CHANNEL_{channel_name}.
If you want, you can create a channel for validation:
channels = {
    'training': train_data_loc,
    'validation': val_data_loc,
    'test': test_data_loc
}
You can then access this channel within your training script via SM_CHANNEL_VALIDATION.
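For example, inside the training script the local directory holding the channel data can be read straight from the environment. A short illustration; SM_CHANNEL_VALIDATION exists only when a 'validation' channel is passed to fit():

import os

val_dir = os.environ["SM_CHANNEL_VALIDATION"]  # local path where the channel data is mounted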
Run the training script on SageMaker
Now, the training container has everything it needs to execute your training script. Start the container by calling the fit() method.
[ ]:
est.fit(inputs=channels)
Inspect and store model data
The training is now finished, and the model artifact has been saved in the output_path.
[ ]:
pt_mnist_model_data = est.model_data
print("Model artifact saved at:\n", pt_mnist_model_data)
We store the variable pt_mnist_model_data in the current notebook kernel.
[ ]:
%store pt_mnist_model_data
Test and debug the entry point before executing the training container
The entry point code/train.py can be executed in the training container. When you develop your own training script, it is a good practice to simulate the container environment in the local shell and test it before sending it to SageMaker, because debugging in a containerized environment is rather cumbersome. The following script shows how you can test your training script:
[ ]:
!pygmentize code/test_train.py
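As a rough illustration of the idea (this is not the contents of code/test_train.py, and the paths below are assumptions), simulating the container boils down to exporting the SM_* variables the script reads and then invoking it directly:

import os
import subprocess

# Mimic the variables the training container would provide
env = dict(os.environ)
env.update(
    {
        "SM_MODEL_DIR": "/tmp/model",  # assumption: where artifacts get written
        "SM_CHANNEL_TRAINING": "./data",  # local stand-in for the training channel
        "SM_CHANNEL_TESTING": "./data",  # local stand-in for the testing channel
    }
)
os.makedirs("/tmp/model", exist_ok=True)

# Run the entry point exactly as the container would
subprocess.run(
    ["python", "code/train.py", "--batch-size", "128", "--epochs", "1"],
    env=env,
    check=True,
)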
Conclusion
In this notebook, we trained a PyTorch model on the MNIST dataset by fitting a SageMaker estimator. For next steps on how to deploy the trained model and perform inference, see Deploy a Trained PyTorch Model.
Notebook CI Test Results
This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.