Part 1: Distributed data parallel MNIST training with PyTorch and SageMaker distributed


This notebook’s CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.



Background

Amazon SageMaker’s distributed library can help you train deep learning models faster and at lower cost. Its data parallel feature, smdistributed.dataparallel, is a distributed data parallel training framework for PyTorch, TensorFlow, and MXNet. This notebook demonstrates how to use smdistributed.dataparallel with PyTorch in SageMaker to train a model on the MNIST dataset.

For more information:

  1. SageMaker distributed data parallel PyTorch API Specification

  2. Getting started with SageMaker distributed data parallel

  3. PyTorch in SageMaker

Dataset

This example uses the MNIST dataset, which is widely used for handwritten digit classification. It consists of 70,000 labeled 28x28-pixel grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images. There are 10 classes, one for each digit.

NOTE: This example requires SageMaker Python SDK v2.

[ ]:
!pip install sagemaker --upgrade

SageMaker role

The following code cell defines role, the IAM role ARN used to create and run SageMaker training and hosting jobs. This is the same IAM role used to create this SageMaker notebook instance.

The role must have permission to create a SageMaker training job and to launch an endpoint that hosts a model. For granular policies you can use to grant these permissions, see Amazon SageMaker Roles.

[ ]:
import sagemaker

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
# The role name is the last "/"-separated segment of the role ARN
role_name = role.split("/")[-1]
print(f"The Amazon Resource Name (ARN) of the role used for this demo is: {role}")
print(f"The name of the role used for this demo is: {role_name}")
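Execution role ARNs can include a path such as service-role/. For that reason, the role name is taken from the last "/"-separated segment of the ARN. The sketch below isolates that parsing step in plain Python. It assumes the standard IAM role ARN layout (arn:partition:iam::account-id:role/optional-path/role-name). The helper name role_name_from_arn and the example ARN are hypothetical.

```python
def role_name_from_arn(role_arn: str) -> str:
    """Return the IAM role name from a role ARN (hypothetical helper).

    Assumes the layout arn:<partition>:iam::<account-id>:role/<optional-path/>role-name,
    so the name is the final "/"-separated segment of the resource part.
    """
    resource = role_arn.split(":", 5)[-1]  # e.g. "role/service-role/MyRole"
    if not resource.startswith("role/"):
        raise ValueError(f"Not an IAM role ARN: {role_arn}")
    return resource.split("/")[-1]


# Hypothetical ARN with a service-role/ path
example_arn = "arn:aws:iam::111122223333:role/service-role/AmazonSageMaker-ExecutionRole-Example"
print(role_name_from_arn(example_arn))  # AmazonSageMaker-ExecutionRole-Example
```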

To verify that the role above has the required permissions:

  1. Go to the IAM console: https://console.aws.amazon.com/iam/home.

  2. Select Roles.

  3. Enter the role name in the search box to search for that role.

  4. Select the role.

  5. Use the Permissions tab to verify that the role has the required permissions attached.

Model training with SageMaker distributed data parallel

Training script

The MNIST dataset is downloaded using the torchvision.datasets PyTorch module; you can see how this is implemented in the train_pytorch_smdataparallel_mnist.py training script that is printed out in the next cell.

The training script provides the code you need for distributed data parallel (DDP) training using SageMaker’s distributed data parallel library (smdistributed.dataparallel). The training script is very similar to a PyTorch training script you might run outside SageMaker, but modified to run with the smdistributed.dataparallel library. This library’s PyTorch client provides an alternative to PyTorch’s native DDP.

For details about how to use smdistributed.dataparallel’s DDP in your native PyTorch script, see Modify a PyTorch Training Script Using SMD Data Parallel.
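The core idea behind data parallel training is that each worker computes gradients on its own shard of the batch. An allreduce then averages those gradients, so every model replica applies the same update. The following conceptual sketch illustrates this with a one-parameter model in plain Python. It is not how smdistributed.dataparallel is implemented. The function names are hypothetical, and the sketch assumes equal-sized shards.

```python
# Conceptual sketch of one data-parallel (DDP-style) step in plain Python.
# NOT the smdistributed.dataparallel implementation: it only shows that averaging
# per-worker gradients over equal-sized shards matches the full-batch gradient.

def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w for a 1-parameter model."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def allreduce_mean(values):
    """Stand-in for an allreduce: every worker receives the average of all values."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def data_parallel_step(w, xs, ys, world_size, lr=0.01):
    shard = len(xs) // world_size  # assumes len(xs) divides evenly by world_size
    local_grads = [
        grad_mse(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    synced = allreduce_mean(local_grads)
    # Each replica applies the same averaged gradient, so the weights stay identical.
    return [w - lr * g for g in synced]


xs = [float(i) for i in range(8)]
ys = [3.0 * x for x in xs]  # data generated with a true weight of 3.0
replicas = data_parallel_step(0.0, xs, ys, world_size=4)
single = 0.0 - 0.01 * grad_mse(0.0, xs, ys)  # one step on the full batch
print(replicas, single)
```

With unequal shard sizes, a plain average of per-shard means would weight the samples unevenly. This is one reason distributed samplers typically give every rank the same number of samples.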

[ ]:
!pygmentize code/train_pytorch_smdataparallel_mnist.py

Estimator function options

In the following code block, you can update the estimator to use a different instance type, instance count, and distribution strategy. The estimator also receives the training script you reviewed in the previous cell.

Instance types

smdistributed.dataparallel supports model training on SageMaker only with the following instance types. For best performance, use an instance type that supports Amazon Elastic Fabric Adapter (EFA): ml.p3dn.24xlarge or ml.p4d.24xlarge.

  1. ml.p3.16xlarge

  2. ml.p3dn.24xlarge [Recommended]

  3. ml.p4d.24xlarge [Recommended]

If you want to use an instance type that smdistributed.dataparallel does not support, you can simply change the distribution parameter to use another PyTorch launcher (as detailed here), with no changes required to your training code.
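The instance list and the fallback launcher can be combined into a small helper that picks the distribution argument for the estimator. This is a minimal sketch under stated assumptions. The helper choose_distribution and the set SMDDP_INSTANCE_TYPES are hypothetical. The set mirrors the list in this notebook, so check the SageMaker documentation for the current list before relying on it. Both returned dictionaries are the ones used in the estimator cell of this notebook.

```python
# Hypothetical helper: choose the `distribution` argument for the PyTorch estimator.
# The supported set mirrors the instance list in this notebook and may be out of date.
SMDDP_INSTANCE_TYPES = {"ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4d.24xlarge"}


def choose_distribution(instance_type: str) -> dict:
    if instance_type in SMDDP_INSTANCE_TYPES:
        # SageMaker distributed data parallel (smdistributed.dataparallel)
        return {"smdistributed": {"dataparallel": {"enabled": True}}}
    # Otherwise fall back to the torchrun launcher
    return {"torch_distributed": {"enabled": True}}


print(choose_distribution("ml.p4d.24xlarge"))
print(choose_distribution("ml.g5.48xlarge"))
```

You would pass the result as distribution=choose_distribution(instance_type) when constructing the PyTorch estimator.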

Instance count

For the best performance with smdistributed.dataparallel, use at least 2 instances. A single instance is enough for testing this example.
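Adding instances increases the world size and shrinks each worker's share of the data. The back-of-the-envelope sketch below makes that concrete for the 60,000 MNIST training images. It rests on three assumptions:

- Each listed instance type has 8 GPUs.
- The library runs one process (rank) per GPU.
- The training script shards data with a DistributedSampler-style split that rounds up so every rank gets the same count.

The function name samples_per_rank is hypothetical.

```python
# Back-of-the-envelope arithmetic for data parallelism in this example.
# Assumptions: 8 GPUs per listed instance type, one rank per GPU, and a
# DistributedSampler-style split that rounds up so all ranks get equal counts.
GPUS_PER_INSTANCE = {"ml.p3.16xlarge": 8, "ml.p3dn.24xlarge": 8, "ml.p4d.24xlarge": 8}
MNIST_TRAIN_IMAGES = 60_000


def samples_per_rank(instance_type: str, instance_count: int) -> tuple:
    """Return (world_size, training samples seen by each rank per epoch)."""
    world_size = GPUS_PER_INSTANCE[instance_type] * instance_count
    per_rank = -(-MNIST_TRAIN_IMAGES // world_size)  # ceiling division
    return world_size, per_rank


for count in (1, 2):
    print(count, samples_per_rank("ml.p4d.24xlarge", count))
```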

Distribution strategy

To use DDP mode, set the distribution strategy to use smdistributed.dataparallel.

[ ]:
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    base_job_name="pytorch-smdataparallel-mnist",
    source_dir="code",
    entry_point="train_pytorch_smdataparallel_mnist.py",
    role=role,
    framework_version="2.0.1",
    py_version="py310",
    # For training with multinode distributed training, set this count. Example: 2
    instance_count=1,
    # For p3dn instances use ml.p3dn.24xlarge; for p4d instances use ml.p4d.24xlarge
    instance_type="ml.p4d.24xlarge",
    # instance_type="ml.g5.48xlarge",
    sagemaker_session=sagemaker_session,
    # Training using SMDataParallel Distributed Training Framework
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    # Training with torchrun launcher, for instance types that do not support smddp
    # distribution={ "torch_distributed": { "enabled": True } },
    hyperparameters={"region": sagemaker_session.boto_region_name},
    debugger_hook_config=False,
)
[ ]:
estimator.fit()

Next steps

Now that you have a trained model, you can deploy it to an endpoint and then test the endpoint with inference requests. The following cell stores the model_data variable for use with the inference notebook.

[ ]:
model_data = estimator.model_data
print("Storing {} as model_data".format(model_data))
%store model_data

Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.
