[1]:
!pip uninstall -y torchvision
!pip install -qU torchvision
MNIST Training using PyTorch
This notebook’s CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.
Background
MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28-pixel grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images. There are 10 classes, one for each digit. This tutorial shows how to train and test an MNIST model on SageMaker using PyTorch.
For more information about PyTorch in SageMaker, see the sagemaker-pytorch-containers and sagemaker-python-sdk GitHub repositories.
Setup
This notebook was created and tested on an ml.m4.xlarge notebook instance.
Let’s start by creating a SageMaker session and specifying:
The S3 bucket and prefix that you want to use for training and model data. This should be in the same region as the notebook instance, training, and hosting.
The IAM role ARN used to give training and hosting access to your data. See the documentation for how to create these. Note: if more than one role is required for notebook instances, training, and/or hosting, replace sagemaker.get_execution_role() with the appropriate full IAM role ARN string(s).
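The note above describes supplying a role ARN explicitly. A small helper can do that. It validates the string before the string reaches the estimator, and it falls back to sagemaker.get_execution_role() when no ARN is given. This is a minimal sketch. resolve_role and its ARN pattern are our own assumptions; the pattern covers common role ARNs, including ones with a path such as service-role/. The account ID 111122223333 is a placeholder.

```python
import re
from typing import Callable, Optional

# Assumed pattern for IAM role ARNs, e.g. arn:aws:iam::111122223333:role/MyRole.
# Allows partitions such as aws-cn / aws-us-gov and role paths like service-role/.
_ROLE_ARN = re.compile(r"^arn:aws[a-z-]*:iam::\d{12}:role/[\w+=,.@/-]+$")


def resolve_role(explicit_arn: Optional[str] = None,
                 fallback: Optional[Callable[[], str]] = None) -> str:
    """Return a validated explicit role ARN, or ask SageMaker for the execution role."""
    if explicit_arn is not None:
        if not _ROLE_ARN.match(explicit_arn):
            raise ValueError(f"not an IAM role ARN: {explicit_arn!r}")
        return explicit_arn
    if fallback is None:
        # Imported lazily so the helper can be used without the SDK installed.
        import sagemaker
        fallback = sagemaker.get_execution_role
    return fallback()
```

For example, role = resolve_role("arn:aws:iam::111122223333:role/MySageMakerRole") would replace the get_execution_role() call in the setup cell.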
[2]:
import sagemaker
sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name
bucket = sagemaker_session.default_bucket()
prefix = "sagemaker/DEMO-pytorch-mnist"
role = sagemaker.get_execution_role()
Data
Getting the data
[3]:
from torchvision.datasets import MNIST
from torchvision import transforms
MNIST.mirrors = [
    f"https://sagemaker-example-files-prod-{region}.s3.amazonaws.com/datasets/image/MNIST/"
]
MNIST(
    "data",
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    ),
)
[3]:
Dataset MNIST
Number of datapoints: 60000
Root location: data
Split: Train
StandardTransform
Transform: Compose(
ToTensor()
Normalize(mean=(0.1307,), std=(0.3081,))
)
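The transform in the cell above does two things:
1. It converts each image to a float tensor in [0, 1].
2. It standardizes the result with mean 0.1307 and standard deviation 0.3081. These are the commonly quoted statistics of the MNIST training pixels.

The NumPy sketch below reproduces that arithmetic for a single 28x28 grayscale image so the resulting value range is easy to check. It is an approximation for illustration, not torchvision's implementation.

```python
import numpy as np

MEAN, STD = 0.1307, 0.3081  # the values passed to transforms.Normalize above


def to_tensor_and_normalize(image_u8: np.ndarray) -> np.ndarray:
    """NumPy approximation of Compose([ToTensor(), Normalize(...)]) for one
    HxW uint8 grayscale image; returns a 1xHxW float32 array."""
    x = image_u8.astype(np.float32) / 255.0  # ToTensor: scale [0, 255] to [0, 1]
    x = x[np.newaxis, :, :]                  # ToTensor: HxW to channel-first 1xHxW
    return (x - MEAN) / STD                  # Normalize: (x - mean) / std per channel
```

With these constants a black pixel maps to about -0.4242 and a white pixel to about 2.8215.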
Uploading the data to S3
We use the sagemaker.Session.upload_data function to upload our datasets to an S3 location. The return value, inputs, identifies the location; we will use it later when we start the training job.
[4]:
inputs = sagemaker_session.upload_data(path="data", bucket=bucket, key_prefix=prefix)
print("input spec (in this case, just an S3 path): {}".format(inputs))
input spec (in this case, just an S3 path): s3://sagemaker-us-west-2-688520471316/sagemaker/DEMO-pytorch-mnist
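The training job later sees this prefix as its training channel (/opt/ml/input/data/training in the training log), so it helps to know which object keys the upload produces. The sketch below only walks the local directory and calls no AWS API. It assumes that upload_data keeps each file's path, relative to the data directory, under key_prefix. preview_s3_keys is a hypothetical helper, not part of the SageMaker SDK.

```python
import os
from pathlib import Path
from typing import List


def preview_s3_keys(local_dir: str, key_prefix: str) -> List[str]:
    """Return the object keys that mirroring local_dir under key_prefix would create.

    Assumption: the upload keeps each file's path relative to local_dir.
    Nothing is uploaded; only the local filesystem is inspected.
    """
    root = Path(local_dir)
    keys = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            relative = (Path(dirpath) / name).relative_to(root)
            keys.append(f"{key_prefix.rstrip('/')}/{relative.as_posix()}")
    return sorted(keys)
```

In this notebook, preview_s3_keys("data", prefix) should list keys such as sagemaker/DEMO-pytorch-mnist/MNIST/raw/train-images-idx3-ubyte.gz.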
Train
Training script
The mnist.py script provides all the code we need for training and hosting a SageMaker model, including the model_fn function that loads a model. The training script is very similar to one you might run outside of SageMaker, but you can access useful properties of the training environment through various environment variables, such as:
SM_MODEL_DIR: A string representing the path of the directory to write model artifacts to. These artifacts are uploaded to S3 for model hosting.
SM_NUM_GPUS: The number of GPUs available in the current container.
SM_CURRENT_HOST: The name of the current container on the container network.
SM_HOSTS: A JSON-encoded list containing all the hosts.
Suppose one input channel, ‘training’, was used in the call to the PyTorch estimator’s fit() method. Then the following variable will be set, using the format SM_CHANNEL_[channel_name]:
SM_CHANNEL_TRAINING: A string representing the path to the directory containing data in the ‘training’ channel.
For more information about training environment variables, please visit SageMaker Containers.
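The variables above are plain strings, so SM_HOSTS must be JSON-decoded before use. The sketch below gathers them into one dict and derives a rank and world size for distributed training. Taking the rank as the host's position in SM_HOSTS is our assumption. It is consistent with the training log further down, where algo-1 reports rank 0 and algo-2 reports rank 1. read_training_env is a hypothetical helper. Passing a dict instead of os.environ makes it easy to try locally.

```python
import json
import os
from typing import Mapping


def read_training_env(env: Mapping[str, str] = os.environ) -> dict:
    """Collect the SageMaker environment variables described above into one dict."""
    hosts = json.loads(env["SM_HOSTS"])  # JSON-encoded list of host names
    current_host = env["SM_CURRENT_HOST"]
    return {
        "model_dir": env["SM_MODEL_DIR"],
        "data_dir": env["SM_CHANNEL_TRAINING"],
        "num_gpus": int(env.get("SM_NUM_GPUS", "0")),
        "hosts": hosts,
        "world_size": len(hosts),
        # Assumption: a host's rank is its position in SM_HOSTS.
        "rank": hosts.index(current_host),
    }
```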
A typical training script loads data from the input channels, configures training with hyperparameters, trains a model, and saves the model to model_dir so that it can be hosted later. Hyperparameters are passed to your script as arguments and can be retrieved with an argparse.ArgumentParser instance.
Because SageMaker imports the training script, put your training code in a main guard (if __name__ == '__main__':) if you use the same script to host your model, as we do in this example. Otherwise SageMaker could inadvertently run your training code at the wrong point in execution.
For example, the script run by this notebook:
[ ]:
!pygmentize mnist.py
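The pattern described above puts hyperparameters in argparse and training code behind a main guard. It might look like the skeleton below. This is a hypothetical outline, not the contents of mnist.py; use the pygmentize cell to read the real script. The --model-dir and --data-dir flag names, their local fallbacks, and the default of 10 epochs are our own choices.

```python
import argparse
import os
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Hyperparameters arrive as command-line flags, e.g. --epochs 1 --backend gloo.
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--backend", type=str, default=None)
    # Paths default to the SageMaker environment variables, with local fallbacks.
    parser.add_argument("--model-dir", type=str,
                        default=os.environ.get("SM_MODEL_DIR", "model"))
    parser.add_argument("--data-dir", type=str,
                        default=os.environ.get("SM_CHANNEL_TRAINING", "data"))
    return parser.parse_args(argv)


def train(args: argparse.Namespace) -> None:
    # Placeholder for the real training loop.
    print(f"training for {args.epochs} epoch(s) on data in {args.data_dir}")


if __name__ == "__main__":
    # Runs only when executed as a script, not when the module is imported
    # at hosting time to find model_fn.
    train(parse_args())
```

Because parse_args accepts an explicit argument list, you can try it in a notebook without touching sys.argv. For example: parse_args(["--epochs", "1", "--backend", "gloo"]).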
Run training in SageMaker
The PyTorch class allows us to run our training function as a training job on SageMaker infrastructure. We configure it with our training script, an IAM role, the number of training instances, the training instance type, and hyperparameters. In this case we run our training job on 2 ml.c5.2xlarge instances, but this example can be run on one or more CPU or GPU instances (see the full list of available instances). The hyperparameters parameter is a dict of values that is passed to your training script. You can see how to access these values in the mnist.py script above.
[8]:
from sagemaker.pytorch import PyTorch
estimator = PyTorch(
    entry_point="mnist.py",
    role=role,
    py_version="py38",
    framework_version="1.11.0",
    instance_count=2,
    instance_type="ml.c5.2xlarge",
    hyperparameters={"epochs": 1, "backend": "gloo"},
)
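The training log below shows how the hyperparameters dict reaches the script. {"epochs": 1, "backend": "gloo"} appears as SM_USER_ARGS=["--backend","gloo","--epochs","1"], and the script is invoked as mnist.py --backend gloo --epochs 1. The sketch below reproduces that mapping for this simple case: keys are sorted and values are converted with str. It is an illustration, not the sagemaker-training-toolkit code. The toolkit may treat other value types, such as booleans or nested structures, differently.

```python
from typing import Any, Dict, List


def hyperparameters_to_cli(hps: Dict[str, Any]) -> List[str]:
    """Flatten a hyperparameters dict into --key value pairs, keys in sorted order."""
    args: List[str] = []
    for key in sorted(hps):
        args += [f"--{key}", str(hps[key])]
    return args
```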
After we’ve constructed our PyTorch object, we can fit it using the data we uploaded to S3. SageMaker makes sure our data is available in the local filesystem, so our training script can simply read the data from disk.
[9]:
estimator.fit({"training": inputs})
2021-06-04 21:20:49 Starting - Starting the training job...
2021-06-04 21:20:50 Starting - Launching requested ML instances
ProfilerReport-1622841649: InProgress
......
2021-06-04 21:22:17 Starting - Preparing the instances for training.........
2021-06-04 21:23:48 Downloading - Downloading input data
2021-06-04 21:23:48 Training - Downloading the training image...
2021-06-04 21:24:20 Uploading - Uploading generated training model
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2021-06-04 21:24:05,228 sagemaker-training-toolkit INFO Imported framework sagemaker_pytorch_container.training
2021-06-04 21:24:05,230 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
2021-06-04 21:24:05,239 sagemaker_pytorch_container.training INFO Block until all host DNS lookups succeed.
2021-06-04 21:24:05,246 sagemaker_pytorch_container.training INFO Invoking user training script.
2021-06-04 21:24:05,636 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
2021-06-04 21:24:05,647 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
2021-06-04 21:24:05,658 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
2021-06-04 21:24:05,667 sagemaker-training-toolkit INFO Invoking user script
Training Env:
{
"additional_framework_parameters": {},
"channel_input_dirs": {
"training": "/opt/ml/input/data/training"
},
"current_host": "algo-1",
"framework_module": "sagemaker_pytorch_container.training:main",
"hosts": [
"algo-1",
"algo-2"
],
"hyperparameters": {
"backend": "gloo",
"epochs": 1
},
"input_config_dir": "/opt/ml/input/config",
"input_data_config": {
"training": {
"TrainingInputMode": "File",
"S3DistributionType": "FullyReplicated",
"RecordWrapperType": "None"
}
},
"input_dir": "/opt/ml/input",
"is_master": true,
"job_name": "pytorch-training-2021-06-04-21-20-48-860",
"log_level": 20,
"master_hostname": "algo-1",
"model_dir": "/opt/ml/model",
"module_dir": "s3://sagemaker-us-west-2-688520471316/pytorch-training-2021-06-04-21-20-48-860/source/sourcedir.tar.gz",
"module_name": "mnist",
"network_interface_name": "eth0",
"num_cpus": 8,
"num_gpus": 0,
"output_data_dir": "/opt/ml/output/data",
"output_dir": "/opt/ml/output",
"output_intermediate_dir": "/opt/ml/output/intermediate",
"resource_config": {
"current_host": "algo-1",
"hosts": [
"algo-1",
"algo-2"
],
"network_interface_name": "eth0"
},
"user_entry_point": "mnist.py"
}
Environment variables:
SM_HOSTS=["algo-1","algo-2"]
SM_NETWORK_INTERFACE_NAME=eth0
SM_HPS={"backend":"gloo","epochs":1}
SM_USER_ENTRY_POINT=mnist.py
SM_FRAMEWORK_PARAMS={}
SM_RESOURCE_CONFIG={"current_host":"algo-1","hosts":["algo-1","algo-2"],"network_interface_name":"eth0"}
SM_INPUT_DATA_CONFIG={"training":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}}
SM_OUTPUT_DATA_DIR=/opt/ml/output/data
SM_CHANNELS=["training"]
SM_CURRENT_HOST=algo-1
SM_MODULE_NAME=mnist
SM_LOG_LEVEL=20
SM_FRAMEWORK_MODULE=sagemaker_pytorch_container.training:main
SM_INPUT_DIR=/opt/ml/input
SM_INPUT_CONFIG_DIR=/opt/ml/input/config
SM_OUTPUT_DIR=/opt/ml/output
SM_NUM_CPUS=8
SM_NUM_GPUS=0
SM_MODEL_DIR=/opt/ml/model
SM_MODULE_DIR=s3://sagemaker-us-west-2-688520471316/pytorch-training-2021-06-04-21-20-48-860/source/sourcedir.tar.gz
SM_TRAINING_ENV={"additional_framework_parameters":{},"channel_input_dirs":{"training":"/opt/ml/input/data/training"},"current_host":"algo-1","framework_module":"sagemaker_pytorch_container.training:main","hosts":["algo-1","algo-2"],"hyperparameters":{"backend":"gloo","epochs":1},"input_config_dir":"/opt/ml/input/config","input_data_config":{"training":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}},"input_dir":"/opt/ml/input","is_master":true,"job_name":"pytorch-training-2021-06-04-21-20-48-860","log_level":20,"master_hostname":"algo-1","model_dir":"/opt/ml/model","module_dir":"s3://sagemaker-us-west-2-688520471316/pytorch-training-2021-06-04-21-20-48-860/source/sourcedir.tar.gz","module_name":"mnist","network_interface_name":"eth0","num_cpus":8,"num_gpus":0,"output_data_dir":"/opt/ml/output/data","output_dir":"/opt/ml/output","output_intermediate_dir":"/opt/ml/output/intermediate","resource_config":{"current_host":"algo-1","hosts":["algo-1","algo-2"],"network_interface_name":"eth0"},"user_entry_point":"mnist.py"}
SM_USER_ARGS=["--backend","gloo","--epochs","1"]
SM_OUTPUT_INTERMEDIATE_DIR=/opt/ml/output/intermediate
SM_CHANNEL_TRAINING=/opt/ml/input/data/training
SM_HP_BACKEND=gloo
SM_HP_EPOCHS=1
PYTHONPATH=/opt/ml/code:/opt/conda/bin:/opt/conda/lib/python36.zip:/opt/conda/lib/python3.6:/opt/conda/lib/python3.6/lib-dynload:/opt/conda/lib/python3.6/site-packages
Invoking script with the following command:
/opt/conda/bin/python3.6 mnist.py --backend gloo --epochs 1
Distributed training - True
Number of gpus available - 0
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2021-06-04 21:24:05,062 sagemaker-training-toolkit INFO Imported framework sagemaker_pytorch_container.training
2021-06-04 21:24:05,064 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
2021-06-04 21:24:05,073 sagemaker_pytorch_container.training INFO Block until all host DNS lookups succeed.
2021-06-04 21:24:05,081 sagemaker_pytorch_container.training INFO Invoking user training script.
2021-06-04 21:24:05,485 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
2021-06-04 21:24:05,496 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
2021-06-04 21:24:05,506 sagemaker-training-toolkit INFO No GPUs detected (normal if no gpus installed)
2021-06-04 21:24:05,515 sagemaker-training-toolkit INFO Invoking user script
Training Env:
{
"additional_framework_parameters": {},
"channel_input_dirs": {
"training": "/opt/ml/input/data/training"
},
"current_host": "algo-2",
"framework_module": "sagemaker_pytorch_container.training:main",
"hosts": [
"algo-1",
"algo-2"
],
"hyperparameters": {
"backend": "gloo",
"epochs": 1
},
"input_config_dir": "/opt/ml/input/config",
"input_data_config": {
"training": {
"TrainingInputMode": "File",
"S3DistributionType": "FullyReplicated",
"RecordWrapperType": "None"
}
},
"input_dir": "/opt/ml/input",
"is_master": false,
"job_name": "pytorch-training-2021-06-04-21-20-48-860",
"log_level": 20,
"master_hostname": "algo-1",
"model_dir": "/opt/ml/model",
"module_dir": "s3://sagemaker-us-west-2-688520471316/pytorch-training-2021-06-04-21-20-48-860/source/sourcedir.tar.gz",
"module_name": "mnist",
"network_interface_name": "eth0",
"num_cpus": 8,
"num_gpus": 0,
"output_data_dir": "/opt/ml/output/data",
"output_dir": "/opt/ml/output",
"output_intermediate_dir": "/opt/ml/output/intermediate",
"resource_config": {
"current_host": "algo-2",
"hosts": [
"algo-1",
"algo-2"
],
"network_interface_name": "eth0"
},
"user_entry_point": "mnist.py"
}
Environment variables:
SM_HOSTS=["algo-1","algo-2"]
SM_NETWORK_INTERFACE_NAME=eth0
SM_HPS={"backend":"gloo","epochs":1}
SM_USER_ENTRY_POINT=mnist.py
SM_FRAMEWORK_PARAMS={}
SM_RESOURCE_CONFIG={"current_host":"algo-2","hosts":["algo-1","algo-2"],"network_interface_name":"eth0"}
SM_INPUT_DATA_CONFIG={"training":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}}
SM_OUTPUT_DATA_DIR=/opt/ml/output/data
SM_CHANNELS=["training"]
SM_CURRENT_HOST=algo-2
SM_MODULE_NAME=mnist
SM_LOG_LEVEL=20
SM_FRAMEWORK_MODULE=sagemaker_pytorch_container.training:main
SM_INPUT_DIR=/opt/ml/input
SM_INPUT_CONFIG_DIR=/opt/ml/input/config
SM_OUTPUT_DIR=/opt/ml/output
SM_NUM_CPUS=8
SM_NUM_GPUS=0
SM_MODEL_DIR=/opt/ml/model
SM_MODULE_DIR=s3://sagemaker-us-west-2-688520471316/pytorch-training-2021-06-04-21-20-48-860/source/sourcedir.tar.gz
SM_TRAINING_ENV={"additional_framework_parameters":{},"channel_input_dirs":{"training":"/opt/ml/input/data/training"},"current_host":"algo-2","framework_module":"sagemaker_pytorch_container.training:main","hosts":["algo-1","algo-2"],"hyperparameters":{"backend":"gloo","epochs":1},"input_config_dir":"/opt/ml/input/config","input_data_config":{"training":{"RecordWrapperType":"None","S3DistributionType":"FullyReplicated","TrainingInputMode":"File"}},"input_dir":"/opt/ml/input","is_master":false,"job_name":"pytorch-training-2021-06-04-21-20-48-860","log_level":20,"master_hostname":"algo-1","model_dir":"/opt/ml/model","module_dir":"s3://sagemaker-us-west-2-688520471316/pytorch-training-2021-06-04-21-20-48-860/source/sourcedir.tar.gz","module_name":"mnist","network_interface_name":"eth0","num_cpus":8,"num_gpus":0,"output_data_dir":"/opt/ml/output/data","output_dir":"/opt/ml/output","output_intermediate_dir":"/opt/ml/output/intermediate","resource_config":{"current_host":"algo-2","hosts":["algo-1","algo-2"],"network_interface_name":"eth0"},"user_entry_point":"mnist.py"}
SM_USER_ARGS=["--backend","gloo","--epochs","1"]
SM_OUTPUT_INTERMEDIATE_DIR=/opt/ml/output/intermediate
SM_CHANNEL_TRAINING=/opt/ml/input/data/training
SM_HP_BACKEND=gloo
SM_HP_EPOCHS=1
PYTHONPATH=/opt/ml/code:/opt/conda/bin:/opt/conda/lib/python36.zip:/opt/conda/lib/python3.6:/opt/conda/lib/python3.6/lib-dynload:/opt/conda/lib/python3.6/site-packages
Invoking script with the following command:
/opt/conda/bin/python3.6 mnist.py --backend gloo --epochs 1
Distributed training - True
Number of gpus available - 0
Initialized the distributed environment: 'gloo' backend on 2 nodes. Current host rank is 0. Number of gpus: 0
Get train data loader
Get test data loader
Processes 30000/60000 (50%) of train data
Processes 10000/10000 (100%) of test data
[2021-06-04 21:24:08.348 algo-1:25 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
Initialized the distributed environment: 'gloo' backend on 2 nodes. Current host rank is 1. Number of gpus: 0
Get train data loader
Get test data loader
Processes 30000/60000 (50%) of train data
Processes 10000/10000 (100%) of test data
[2021-06-04 21:24:08.339 algo-2:25 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2021-06-04 21:24:08.751 algo-2:25 INFO profiler_config_parser.py:102] User has disabled profiler.
[2021-06-04 21:24:08.752 algo-2:25 INFO json_config.py:91] Creating hook from json_config at /opt/ml/input/config/debughookconfig.json.
[2021-06-04 21:24:08.752 algo-2:25 INFO hook.py:199] tensorboard_dir has not been set for the hook. SMDebug will not be exporting tensorboard summaries.
[2021-06-04 21:24:08.752 algo-2:25 INFO hook.py:253] Saving to /opt/ml/output/tensors
[2021-06-04 21:24:08.752 algo-2:25 INFO state_store.py:77] The checkpoint config file /opt/ml/input/config/checkpointconfig.json does not exist.
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:584] name:module.conv1.weight count_params:250
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:584] name:module.conv1.bias count_params:10
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:584] name:module.conv2.weight count_params:5000
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:584] name:module.conv2.bias count_params:20
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:584] name:module.fc1.weight count_params:16000
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:584] name:module.fc1.bias count_params:50
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:584] name:module.fc2.weight count_params:500
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:584] name:module.fc2.bias count_params:10
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:586] Total Trainable Params: 21840
[2021-06-04 21:24:08.808 algo-2:25 INFO hook.py:413] Monitoring the collections: losses
[2021-06-04 21:24:08.744 algo-1:25 INFO profiler_config_parser.py:102] User has disabled profiler.
[2021-06-04 21:24:08.744 algo-1:25 INFO json_config.py:91] Creating hook from json_config at /opt/ml/input/config/debughookconfig.json.
[2021-06-04 21:24:08.745 algo-1:25 INFO hook.py:199] tensorboard_dir has not been set for the hook. SMDebug will not be exporting tensorboard summaries.
[2021-06-04 21:24:08.745 algo-1:25 INFO hook.py:253] Saving to /opt/ml/output/tensors
[2021-06-04 21:24:08.745 algo-1:25 INFO state_store.py:77] The checkpoint config file /opt/ml/input/config/checkpointconfig.json does not exist.
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:584] name:module.conv1.weight count_params:250
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:584] name:module.conv1.bias count_params:10
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:584] name:module.conv2.weight count_params:5000
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:584] name:module.conv2.bias count_params:20
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:584] name:module.fc1.weight count_params:16000
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:584] name:module.fc1.bias count_params:50
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:584] name:module.fc2.weight count_params:500
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:584] name:module.fc2.bias count_params:10
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:586] Total Trainable Params: 21840
[2021-06-04 21:24:08.795 algo-1:25 INFO hook.py:413] Monitoring the collections: losses
Train Epoch: 1 [6400/30000 (21%)] Loss: 2.075230
Train Epoch: 1 [6400/30000 (21%)] Loss: 2.076306
Train Epoch: 1 [12800/30000 (43%)] Loss: 1.056925
Train Epoch: 1 [12800/30000 (43%)] Loss: 1.216741
Train Epoch: 1 [19200/30000 (64%)] Loss: 0.911026
Train Epoch: 1 [19200/30000 (64%)] Loss: 0.942084
Train Epoch: 1 [25600/30000 (85%)] Loss: 0.671317
Train Epoch: 1 [25600/30000 (85%)] Loss: 0.841636
Test set: Average loss: 0.3237, Accuracy: 9094/10000 (91%)
Saving the model.
INFO:__main__:Initialized the distributed environment: 'gloo' backend on 2 nodes. Current host rank is 0. Number of gpus: 0
INFO:__main__:Get train data loader
INFO:__main__:Get test data loader
DEBUG:__main__:Processes 30000/60000 (50%) of train data
DEBUG:__main__:Processes 10000/10000 (100%) of test data
/opt/conda/lib/python3.6/site-packages/torch/distributed/distributed_c10d.py:144: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
warnings.warn("torch.distributed.reduce_op is deprecated, please use "
INFO:__main__:Train Epoch: 1 [6400/30000 (21%)] Loss: 2.076306
INFO:__main__:Train Epoch: 1 [12800/30000 (43%)] Loss: 1.056925
INFO:__main__:Train Epoch: 1 [19200/30000 (64%)] Loss: 0.942084
INFO:__main__:Train Epoch: 1 [25600/30000 (85%)] Loss: 0.841636
/opt/conda/lib/python3.6/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
INFO:__main__:Test set: Average loss: 0.3237, Accuracy: 9094/10000 (91%)
INFO:__main__:Saving the model.
2021-06-04 21:24:18,084 sagemaker-training-toolkit INFO Reporting training SUCCESS
Test set: Average loss: 0.3237, Accuracy: 9094/10000 (91%)
Saving the model.
INFO:__main__:Initialized the distributed environment: 'gloo' backend on 2 nodes. Current host rank is 1. Number of gpus: 0
INFO:__main__:Get train data loader
INFO:__main__:Get test data loader
DEBUG:__main__:Processes 30000/60000 (50%) of train data
DEBUG:__main__:Processes 10000/10000 (100%) of test data
/opt/conda/lib/python3.6/site-packages/torch/distributed/distributed_c10d.py:144: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead
warnings.warn("torch.distributed.reduce_op is deprecated, please use "
INFO:__main__:Train Epoch: 1 [6400/30000 (21%)] Loss: 2.075230
INFO:__main__:Train Epoch: 1 [12800/30000 (43%)] Loss: 1.216741
INFO:__main__:Train Epoch: 1 [19200/30000 (64%)] Loss: 0.911026
INFO:__main__:Train Epoch: 1 [25600/30000 (85%)] Loss: 0.671317
/opt/conda/lib/python3.6/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
INFO:__main__:Test set: Average loss: 0.3237, Accuracy: 9094/10000 (91%)
INFO:__main__:Saving the model.
2021-06-04 21:24:18,140 sagemaker-training-toolkit INFO Reporting training SUCCESS
2021-06-04 21:24:46 Completed - Training job completed
Training seconds: 124
Billable seconds: 124
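In the log above, each host reports processing 30000/60000 (50%) of the training data but all 10000 test images. That is what you would expect if mnist.py splits the training set across hosts with something like torch.utils.data.distributed.DistributedSampler. We assume this here; the script itself is not shown. The sketch below imitates that sampler's unshuffled split in plain Python. It pads the index list so it divides evenly, then gives each rank every world_size-th index.

```python
from typing import List


def shard_indices(num_samples: int, rank: int, world_size: int) -> List[int]:
    """Strided, unshuffled split of sample indices across hosts, in the spirit of
    DistributedSampler(shuffle=False, drop_last=False)."""
    per_rank = -(-num_samples // world_size)  # ceiling division
    total = per_rank * world_size
    indices = list(range(num_samples))
    # Pad by wrapping around so every rank gets the same number of samples
    # (assumes the padding needed is no larger than num_samples).
    indices += indices[: total - num_samples]
    return indices[rank:total:world_size]
```

With 60000 samples and 2 hosts, each rank receives 30000 disjoint indices, matching the counts in the log.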
Host
Create endpoint
After training, we use the PyTorch estimator object to build and deploy a PyTorchPredictor. This creates a SageMaker endpoint, a hosted prediction service that we can use to perform inference.
As mentioned above, the mnist.py script contains the required implementation of model_fn. We use the default implementations of input_fn, predict_fn, output_fn, and transform_fn defined in sagemaker-pytorch-containers.
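As far as we know, the inference toolkit behind these containers loads the model once at startup with model_fn. It then serves each request by passing it through input_fn, predict_fn, and output_fn in turn; the default transform_fn does this chaining. The sketch below mimics that flow with NumPy and JSON so it can run without PyTorch. Every handler body, the JSON-only content type, and the toy model are hypothetical stand-ins. The real defaults also accept other formats such as NPY.

```python
import json
from typing import Callable

import numpy as np

Model = Callable[[np.ndarray], np.ndarray]


def model_fn(model_dir: str) -> Model:
    # Hypothetical stand-in: a real model_fn would load trained weights from model_dir.
    # This toy "model" just sums the pixels of each input in the batch.
    return lambda batch: batch.reshape(len(batch), -1).sum(axis=1, keepdims=True)


def input_fn(request_body: str, request_content_type: str) -> np.ndarray:
    # Deserialize the request body; this sketch accepts JSON only.
    if request_content_type != "application/json":
        raise ValueError(f"unsupported content type: {request_content_type}")
    return np.asarray(json.loads(request_body), dtype=np.float32)


def predict_fn(input_data: np.ndarray, model: Model) -> np.ndarray:
    # Run the model on the deserialized batch.
    return model(input_data)


def output_fn(prediction: np.ndarray, accept: str) -> str:
    # Serialize the prediction; this sketch always returns JSON.
    return json.dumps(prediction.tolist())


def transform_fn(model: Model, request_body: str, content_type: str, accept: str) -> str:
    # Chain the handlers: deserialize, predict, serialize.
    return output_fn(predict_fn(input_fn(request_body, content_type), model), accept)
```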
The arguments to the deploy function set the number and type of instances used for the endpoint. These do not need to match the values we used for the training job. For example, you can train a model on GPU-based instances and then deploy the endpoint to a fleet of CPU-based instances. In that case, make sure you return or save your model as a CPU model, as we do in mnist.py. Here we deploy the model to a single ml.m4.xlarge instance.
[10]:
predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")
-------!
Evaluate
You can use the test images to evaluate the endpoint. The model's accuracy depends on how long it is trained; here it was trained for a single epoch.
[14]:
!ls data/MNIST/raw
t10k-images-idx3-ubyte train-images-idx3-ubyte
t10k-images-idx3-ubyte.gz train-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte train-labels-idx1-ubyte
t10k-labels-idx1-ubyte.gz train-labels-idx1-ubyte.gz
[15]:
import gzip
import numpy as np
import random
import os
data_dir = "data/MNIST/raw"
with gzip.open(os.path.join(data_dir, "t10k-images-idx3-ubyte.gz"), "rb") as f:
    images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28).astype(np.float32)
mask = random.sample(range(len(images)), 16) # randomly select some of the test images
mask = np.array(mask, dtype=np.int32)
data = images[mask]
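The offset=16 in the cell above skips the IDX file header. That header is four big-endian 32-bit integers:
1. a magic number (2051 for image files)
2. the image count
3. the row count
4. the column count

Label files use an 8-byte header: magic number 2049 and the item count. The sketch below parses the headers explicitly instead of hard-coding offsets, which also catches a wrong or truncated file early. parse_idx_images and parse_idx_labels are our own helper names. They expect uncompressed bytes, such as what f.read() returns inside a gzip.open block.

```python
import struct

import numpy as np

IDX_IMAGES_MAGIC = 2051  # 0x00000803: unsigned bytes, 3 dimensions
IDX_LABELS_MAGIC = 2049  # 0x00000801: unsigned bytes, 1 dimension


def parse_idx_images(raw: bytes) -> np.ndarray:
    """Decode an uncompressed IDX image file into an (N, rows, cols) uint8 array."""
    magic, count, rows, cols = struct.unpack(">IIII", raw[:16])  # big-endian header
    if magic != IDX_IMAGES_MAGIC:
        raise ValueError(f"unexpected magic number {magic}")
    return np.frombuffer(raw, np.uint8, offset=16).reshape(count, rows, cols)


def parse_idx_labels(raw: bytes) -> np.ndarray:
    """Decode an uncompressed IDX label file into an (N,) uint8 array."""
    magic, count = struct.unpack(">II", raw[:8])
    if magic != IDX_LABELS_MAGIC:
        raise ValueError(f"unexpected magic number {magic}")
    return np.frombuffer(raw, np.uint8, offset=8, count=count)
```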
[18]:
response = predictor.predict(np.expand_dims(data, axis=1))
print("Raw prediction result:")
print(response)
print()
labeled_predictions = list(zip(range(10), response[0]))
print("Labeled predictions: ")
print(labeled_predictions)
print()
labeled_predictions.sort(key=lambda label_and_prob: label_and_prob[1], reverse=True)
print("Most likely answer: {}".format(labeled_predictions[0]))
Raw prediction result:
[[ -841.81384277 -809.42578125 -686.82019043 -514.44500732
-308.11486816 -555.79718018 -855.24853516 -68.2052002
-401.67419434 0. ]
[ -646.59735107 -895.97802734 0. -544.23052979
-755.7565918 -922.04547119 -625.824646 -812.57006836
-500.33776855 -805.3817749 ]
[ -887.1842041 -817.05535889 -741.71582031 -870.87322998
0. -643.92150879 -453.70858765 -701.31427002
-686.24975586 -333.54837036]
[ -891.87609863 -857.50964355 -627.76599121 -970.46557617
0. -768.94299316 -505.81027222 -653.79400635
-676.14233398 -344.46173096]
[ -660.29199219 -792.37634277 -664.95281982 0.
-958.40881348 -385.23153687 -958.44628906 -779.3081665
-500.71435547 -699.39404297]
[ -834.11621094 -886.70550537 -529.53649902 -998.31347656
-532.21209717 -538.67773438 0. -1198.16943359
-689.94897461 -867.32391357]
[ -465.91723633 0. -236.69277954 -139.23117065
-266.07495117 -185.28543091 -275.40328979 -267.25372314
-170.0657959 -231.46386719]
[ -211.82693481 -656.51843262 -441.55142212 -156.15490723
-581.20227051 0. -247.48072815 -792.90441895
-189.44906616 -535.47607422]
[ -823.66131592 -686.35046387 0. -666.04376221
-725.9387207 -951.12542725 -613.12091064 -868.25402832
-583.81872559 -809.61730957]
[ -716.01312256 0. -457.86019897 -407.98397827
-428.36694336 -492.61264038 -519.45214844 -401.35784912
-434.07836914 -442.5569458 ]
[ -614.79174805 -342.24572754 -124.56967163 0.
-503.47006226 -328.65908813 -694.87536621 -285.80020142
-37.24746704 -272.1781311 ]
[ -693.29498291 -568.50708008 -590.91729736 0.
-661.19091797 -358.86340332 -782.01672363 -653.08392334
-463.67196655 -516.52587891]
[ -377.30566406 -512.46130371 -608.6083374 0.
-637.54870605 -59.34466553 -394.70697021 -676.92193604
-573.77575684 -521.92242432]
[ -357.77236938 -1050.88598633 -937.78881836 -340.1932373
-673.72698975 0. -491.30606079 -905.3470459
-411.84393311 -518.42871094]
[ -852.84472656 -966.48913574 -761.59924316 -685.07775879
-177.07427979 -618.85943604 -756.14312744 -102.18136597
-460.47125244 0. ]
[ -712.10375977 -574.57720947 -563.62750244 0.
-855.70446777 -436.41946411 -968.52105713 -534.12976074
-455.9407959 -543.6784668 ]]
Labeled predictions:
[(0, -841.8138427734375), (1, -809.42578125), (2, -686.8201904296875), (3, -514.4450073242188), (4, -308.1148681640625), (5, -555.7971801757812), (6, -855.24853515625), (7, -68.2052001953125), (8, -401.6741943359375), (9, 0.0)]
Most likely answer: (9, 0.0)
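The cell above interprets only the first row of the response. To score the whole batch, take the argmax of each row and compare it with the true labels. Every row in the raw result has a maximum of exactly 0. That suggests the model returns log-probabilities, for example from log_softmax; the confidence field below relies on that assumption. Also, the images were sent as raw 0 to 255 pixel values, not through the ToTensor/Normalize transform shown earlier. If mnist.py normalizes its training data the same way, that mismatch may explain the very large score magnitudes. summarize_predictions is a hypothetical helper.

```python
import numpy as np


def summarize_predictions(log_probs: np.ndarray, labels: np.ndarray) -> dict:
    """Turn a (batch, num_classes) array of per-class scores into labels and accuracy."""
    scores = np.asarray(log_probs)
    predicted = scores.argmax(axis=1)  # most likely digit for each image
    # Assumption: scores are log-probabilities, so exp() recovers probabilities.
    confidence = np.exp(scores.max(axis=1))
    return {
        "predicted": predicted.tolist(),
        "confidence": confidence.tolist(),
        "accuracy": float((predicted == np.asarray(labels)).mean()),
    }
```

To use it in this notebook, first read test_labels from t10k-labels-idx1-ubyte.gz with np.frombuffer(f.read(), np.uint8, offset=8). Then summarize_predictions(response, test_labels[mask]) gives the accuracy on the 16 sampled images.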
Cleanup
After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it.
[24]:
sagemaker_session.delete_endpoint(endpoint_name=predictor.endpoint_name)
Notebook CI Test Results
This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.