Training with SageMaker Estimators using SageMaker Managed Spot Training

This notebook can run on SageMaker Managed Spot infrastructure, which uses EC2 Spot Instances to run training at a lower cost.

Please read the original notebook and try it out to understand the ML use case and how it is solved. This notebook does not cover that.


This notebook uses the Iris dataset from the UCI Machine Learning Repository.

Iris Data Set.

Dua, D. and Graff, C. (2019). UCI Machine Learning Repository. Irvine, CA: University of California, School of Information and Computer Science.

Set up variables and define functions

[ ]:
from sagemaker import get_execution_role
from sagemaker.session import Session

# S3 bucket for saving code and model artifacts.
# Feel free to specify a different bucket here if you wish.
bucket = Session().default_bucket()

# Location to save your custom code in tar.gz format.
custom_code_upload_location = "s3://{}/customcode/tensorflow_iris".format(bucket)

# Location where results of model training are saved.
model_artifacts_location = "s3://{}/artifacts".format(bucket)

# IAM execution role that gives SageMaker access to resources in your AWS account.
role = get_execution_role()

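import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x: tf.contrib and tf.estimator.inputs are not in TF 2.x

INPUT_TENSOR_NAME = "inputs"  # assumed feature key; must match in training and serving

def estimator(model_path, hyperparameters):
    # DNN with three hidden layers (10, 20, 10 units) over the four Iris features.
    feature_columns = [tf.feature_column.numeric_column(INPUT_TENSOR_NAME, shape=[4])]
    return tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 20, 10],
        n_classes=3,  # three Iris species
        model_dir=model_path,
    )

def train_input_fn(training_dir, hyperparameters):
    # The file name must match the CSV provided in the training channel.
    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=os.path.join(training_dir, "iris_training.csv"),
        target_dtype=np.int32,
        features_dtype=np.float32,
    )
    # Shuffle and repeat indefinitely; the trailing () yields the (features, labels) tensors.
    return tf.estimator.inputs.numpy_input_fn(
        x={INPUT_TENSOR_NAME: np.array(training_set.data)},
        y=np.array(training_set.target),
        num_epochs=None,
        shuffle=True,
    )()

def serving_input_fn(hyperparameters):
    # At inference time, parse serialized tf.Example protos carrying a 4-float feature.
    feature_spec = {INPUT_TENSOR_NAME: tf.FixedLenFeature(dtype=tf.float32, shape=[4])}
    return tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)()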

import boto3

region = boto3.Session().region_name
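The training input function above depends on `load_csv_with_header`, which reads a header row whose first two fields are the sample count and the feature count (the classic Iris files also list the class names there), followed by rows of feature values ending in an integer label. The sketch below mimics that layout with only the standard library and NumPy, so you can inspect the `{INPUT_TENSOR_NAME: features}` dict and label array that `numpy_input_fn` is given. It is an illustration, not TensorFlow's implementation; the function name `load_iris_csv`, the `"inputs"` key, and the three sample rows are hypothetical.

```python
import csv
import io

import numpy as np

INPUT_TENSOR_NAME = "inputs"  # hypothetical feature key, mirroring the training script


def load_iris_csv(fileobj):
    """Parse a CSV whose header starts with n_samples, n_features (sketch only)."""
    reader = csv.reader(fileobj)
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])
    data = np.zeros((n_samples, n_features), dtype=np.float32)
    target = np.zeros((n_samples,), dtype=np.int32)
    for i, row in enumerate(reader):
        target[i] = int(row.pop(-1))  # last column is the class label
        data[i] = [float(v) for v in row]  # remaining columns are features
    return {INPUT_TENSOR_NAME: data}, target


# Hypothetical three-row sample in the same layout as the Iris training file.
sample = io.StringIO(
    "3,4,setosa,versicolor,virginica\n"
    "5.1,3.5,1.4,0.2,0\n"
    "7.0,3.2,4.7,1.4,1\n"
    "6.3,3.3,6.0,2.5,2\n"
)
features, labels = load_iris_csv(sample)
print(features[INPUT_TENSOR_NAME].shape, labels)  # (3, 4) [0 1 2]
```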

Managed Spot Training with a TensorFlow Estimator

For Managed Spot Training using a TensorFlow Estimator, we need to configure two things:

1. Enable the use_spot_instances constructor argument, a self-explanatory boolean.
2. Set the max_wait constructor argument, an int giving the maximum total time, in seconds, that the job may take: the time spent waiting for Spot infrastructure to become available plus the training time itself. Some instance types are harder to get at Spot prices, so you may have to wait longer.

You are not charged for time spent waiting for Spot infrastructure to become available; you are charged only for the compute time used once Spot instances have been procured.

Normally, a third step would also be necessary: modifying your code to save checkpoints at a regular cadence. TensorFlow Estimators already do this, however, so no changes are needed here. Checkpointing is highly recommended for Managed Spot Training jobs because Spot instances can be interrupted at short notice, and resuming from the last checkpoint means you lose at most the progress made since that checkpoint.
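Because TensorFlow Estimators checkpoint for you, the following is only a framework-agnostic illustration of why a regular checkpoint cadence limits lost work. It is a toy loop, not SageMaker code: the function `train_with_checkpoints`, the JSON checkpoint file, and the `interrupt_at` argument that simulates a Spot interruption are all hypothetical. On SageMaker itself, the estimator's `checkpoint_s3_uri` argument is the usual way to have checkpoints written locally (by default under `/opt/ml/checkpoints`) synced to S3 so a restarted Spot job can pick them up.

```python
import json
import os
import tempfile


def train_with_checkpoints(checkpoint_path, total_steps, checkpoint_every, interrupt_at=None):
    """Toy training loop that resumes from the last saved step (hypothetical).

    Returns (start_step, stop_step). interrupt_at simulates a Spot interruption.
    """
    start = 0
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path) as f:
            start = json.load(f)["step"]  # resume where the last checkpoint left off
    for step in range(start, total_steps):
        if interrupt_at is not None and step == interrupt_at:
            return start, step  # instance reclaimed: work since the last checkpoint is lost
        # (a real job would run one optimization step here)
        if (step + 1) % checkpoint_every == 0:
            with open(checkpoint_path, "w") as f:
                json.dump({"step": step + 1}, f)
    return start, total_steps


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "checkpoint.json")
    print(train_with_checkpoints(ckpt, total_steps=100, checkpoint_every=10, interrupt_at=57))  # (0, 57)
    print(train_with_checkpoints(ckpt, total_steps=100, checkpoint_every=10))  # (50, 100)
```

After the simulated interruption, steps 50 to 56 are redone; a shorter checkpoint interval shrinks that repeated work at the cost of more frequent writes.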

Feel free to toggle the use_spot_instances variable to see the effect of running the same job using regular (a.k.a. “On Demand”) infrastructure.

Note that max_wait can be set if and only if use_spot_instances is enabled and must be greater than or equal to max_run.
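The rule stated above can be checked before a job is submitted. The sketch below is a hypothetical helper, `validate_spot_config`, that is not part of the SageMaker SDK; it simply encodes the two conditions: max_wait is only allowed with Spot enabled, and it must be at least max_run.

```python
def validate_spot_config(use_spot_instances, max_run, max_wait):
    """Hypothetical helper enforcing the max_wait / max_run rule (times in seconds)."""
    if max_wait is None:
        return  # nothing to check when max_wait is unset
    if not use_spot_instances:
        raise ValueError("max_wait can only be set when use_spot_instances is True")
    if max_wait < max_run:
        raise ValueError(f"max_wait ({max_wait}s) must be >= max_run ({max_run}s)")


validate_spot_config(True, 3600, 7200)   # OK: the values used in this notebook
validate_spot_config(False, 3600, None)  # OK: On-Demand job without max_wait
```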

[ ]:
use_spot_instances = True
max_run = 3600
max_wait = 7200 if use_spot_instances else None
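
[ ]:
from sagemaker.tensorflow import TensorFlow

# Copy the Iris CSV files from the public sample bucket into your own bucket.
s3 = boto3.client("s3")
s3.download_file("sagemaker-sample-files", "datasets/tabular/iris/iris_train.csv", "iris_train.csv")
s3.download_file("sagemaker-sample-files", "datasets/tabular/iris/iris_test.csv", "iris_test.csv")
s3.upload_file("iris_train.csv", bucket, "DEMO-tensorflow-iris/iris_train.csv")
s3.upload_file("iris_test.csv", bucket, "DEMO-tensorflow-iris/iris_test.csv")
train_data_location = "s3://{}/DEMO-tensorflow-iris/".format(bucket)

# entry_point, framework/Python versions, and instance settings are assumptions to adapt.
iris_estimator = TensorFlow(
    entry_point="iris_dnn_classifier.py",  # hypothetical script name
    role=role,
    framework_version="1.15.2", py_version="py3",  # assumed versions
    instance_count=1, instance_type="ml.c4.xlarge",  # assumed instance settings
    output_path=model_artifacts_location, code_location=custom_code_upload_location,
    use_spot_instances=use_spot_instances, max_run=max_run, max_wait=max_wait,
)
iris_estimator.fit(train_data_location)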


Towards the end of the job you should see two lines of output printed:

  • Training seconds: X : This is the actual compute time your training job spent.

  • Billable seconds: Y : This is the time you will be billed for after Spot discounting is applied.

If you enabled use_spot_instances, you should see a notable difference between X and Y, reflecting the cost savings from choosing Managed Spot Training. This is reported in an additional line:

  • Managed Spot Training savings: (1-Y/X)*100 %

For instance:

Training seconds: 42
Billable seconds: 8
Managed Spot Training savings: 81.0%
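As a quick check of the (1-Y/X)*100 formula, the sketch below parses the two lines in the format shown above and recomputes the savings. The function name `spot_savings_percent` and the regular expressions are assumptions about the log format, not part of the SageMaker SDK, which prints the savings line for you.

```python
import re


def spot_savings_percent(log_text):
    """Compute (1 - Y/X) * 100 from 'Training seconds' (X) and 'Billable seconds' (Y)."""
    training = int(re.search(r"Training seconds:\s*(\d+)", log_text).group(1))
    billable = int(re.search(r"Billable seconds:\s*(\d+)", log_text).group(1))
    return round((1 - billable / training) * 100, 1)


example_log = "Training seconds: 42\nBillable seconds: 8\n"
print(spot_savings_percent(example_log))  # 81.0
```

With On-Demand infrastructure, X and Y are equal and the computed savings are 0.0.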