Gluon CIFAR-10 Hyperparameter Tuning

*ResNet model in Gluon trained with SageMaker Automatic Model Tuning and Random Search Tuning*



This notebook was created and tested on an ml.m4.xlarge notebook instance. However, the tuning jobs use multiple ml.p3.8xlarge instances, meaning re-running this test could cost approximately $400. Please do not use Cell -> Run All. Certain cell outputs have not been cleared so that you can see results without having to run the notebook yourself.

Outline

  1. Background

  2. Setup

  3. Data

  4. Script

  5. Train: Initial

  6. Tune: Random

  7. Tune: Automatic Model Tuning

  8. Wrap-up

Background

Selecting the right hyperparameter values for your machine learning model can be difficult. The right answer is dependent on your data; some algorithms have many different hyperparameters that can be tweaked; some are very sensitive to the hyperparameter values selected; and most have a non-linear relationship between model fit and hyperparameter values.

There are a variety of strategies to select hyperparameter values. Some scientists use domain knowledge, heuristics, intuition, or manual experimentation; others use brute force searches; and some build meta models to predict what performant hyperparameter values may be. But regardless of the method, it usually requires a specialized skill set. Meanwhile, most scientists themselves would prefer to be creating new models rather than endlessly refining an old one.

Amazon SageMaker can ease this process with Automatic Model Tuning. This technique uses Gaussian Process regression to predict which hyperparameter values may be most effective at improving fit, and Bayesian optimization to balance exploring the hyperparameter space (so that a better predictive model for hyperparameters can be built) and exploiting specific hyperparameter values when needed.
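
To build some intuition for this explore/exploit balance, the toy sketch below runs a few rounds of Gaussian Process-based optimization on a made-up one-dimensional objective, using scikit-learn's GaussianProcessRegressor and an expected-improvement criterion. It is purely illustrative and is not SageMaker's actual implementation.

[ ]:
# Conceptual sketch of Bayesian optimization with a Gaussian Process surrogate.
# This is an illustration only, not SageMaker's implementation.
import numpy as np
from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessRegressor


def toy_objective(lr):
    # Stand-in for "validation accuracy as a function of learning rate"; peaks near lr = 0.1.
    return float(np.exp(-((np.log10(lr) + 1.0) ** 2)))


observed_x = np.array([[0.001], [0.5]])  # hyperparameter values tried so far
observed_y = np.array([toy_objective(x[0]) for x in observed_x])

for _ in range(5):
    # Fit a surrogate model to the (hyperparameter, objective) pairs seen so far.
    gp = GaussianProcessRegressor(normalize_y=True).fit(observed_x, observed_y)
    candidates = np.random.uniform(0.001, 0.5, size=(1000, 1))
    mu, sigma = gp.predict(candidates, return_std=True)
    # Expected improvement trades off exploring uncertain regions (high sigma)
    # against exploiting promising regions (high mu).
    best = observed_y.max()
    z = (mu - best) / np.maximum(sigma, 1e-9)
    ei = (mu - best) * norm.cdf(z) + sigma * norm.pdf(z)
    next_x = candidates[np.argmax(ei)]
    observed_x = np.vstack([observed_x, next_x])
    observed_y = np.append(observed_y, toy_objective(next_x[0]))

print("best learning_rate found:", observed_x[np.argmax(observed_y)][0])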

Other popular methods of hyperparameter optimization include brute force methods like random search. Despite sounding naive, this is often very competitive. However, we’ve found SageMaker’s Automatic Model Tuning to provide better fits in fewer job runs, resulting in a better model with less time spent and at a lower cost. This notebook will compare the two methods in more detail.

SageMaker’s Automatic Model Tuning works with SageMaker’s built-in algorithms, pre-built deep learning frameworks, and the bring your own algorithm container options. But, for this example, let’s stick with the MXNet framework, a ResNet-34 convolutional neural network, and the CIFAR-10 image dataset. For more background, please see the MXNet CIFAR-10 example notebook.

Setup

Install scikit-image==0.14.2

[ ]:
# Install a scikit-image package in the current Jupyter kernel
import sys

!{sys.executable} -m pip install scikit-image==0.14.2

Specify the IAM role for permission to access the dataset in S3 and SageMaker functionality.

[ ]:
import sagemaker

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

Let’s import the necessary libraries.

[ ]:
from sagemaker.mxnet import MXNet
from sagemaker.tuner import (
    IntegerParameter,
    CategoricalParameter,
    ContinuousParameter,
    HyperparameterTuner,
)
import random_tuner as rt
import pandas as pd
import matplotlib.pyplot as plt

Data

We’ll use a helper script to download CIFAR-10 training data and sample images. CIFAR-10 consists of 60K 32x32 pixel color images (50K train, 10K test) evenly distributed across 10 classes.

[ ]:
from cifar10_utils import download_training_data

download_training_data()

Next we’ll use the sagemaker.Session.upload_data function to upload our datasets to an S3 location. The return value inputs identifies the location – we will use this later when we start the training and tuning jobs.

[ ]:
inputs = sagemaker_session.upload_data(path="data", key_prefix="data/DEMO-gluon-cifar10")
print("input spec (in this case, just an S3 path): {}".format(inputs))

Script

We need to provide a training script that can run on the SageMaker platform. This is idiomatic MXNet code arranged into a few key functions:

  • A train() function that takes in hyperparameters, defines our neural net architecture, and trains our network.

  • A save() function that saves our trained network as an MXNet model.

  • Helper functions get_data(), get_train_data(), and get_test_data(), which prepare the CIFAR-10 image data for our train() function.

  • A helper function called test(), which calculates our accuracy on the holdout dataset.

  • Hosting functions (which we keep for alignment with other MXNet CIFAR-10 notebooks, but won't dig into, since the focus of this notebook is only on training).

The network itself is a ResNet-34 architecture imported from the Gluon Model Zoo.

[ ]:
!cat 'cifar10.py'
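
For orientation, pulling ResNet-34 from the Gluon Model Zoo and pairing it with an SGD Trainer looks roughly like the sketch below. This is a simplified illustration; the actual cifar10.py listed by the cell above may use a different ResNet variant, initializer, or context handling.

[ ]:
# Rough sketch of the Gluon Model Zoo piece of a script like cifar10.py.
# Illustrative only; see the full script listing above for the real code.
import mxnet as mx
from mxnet.gluon.model_zoo import vision

ctx = mx.cpu()  # the training job itself would use the GPUs on an ml.p3.8xlarge
net = vision.resnet34_v2(classes=10)  # ResNet-34 with 10 output classes for CIFAR-10
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
net.hybridize()  # compile the graph for faster training

trainer = mx.gluon.Trainer(
    net.collect_params(),
    "sgd",
    {"learning_rate": 0.01, "momentum": 0.0, "wd": 0.0},
)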

Train: Initial

Now that we’ve written our training script, we can submit it as a job to SageMaker training. Normally, we might test locally to make sure our script works (see the MXNet CIFAR-10 local mode example), but since we already know it does, we’ll skip that step.

Let’s see how our model performs with a naive guess for hyperparameter values. We’re training our network with stochastic gradient descent (SGD), an iterative method that minimizes training loss by finding the direction in which changing the network weights reduces the loss, taking a small step in that direction, and repeating. Since we’re using SGD, the three hyperparameters we’ll focus on (sketched in the update rule following the list) will be:

  • learning_rate: which controls how large of steps we take.

  • momentum: which uses information from the direction of our previous step to inform our current step.

  • wd: which penalizes weights when they grow too large.
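
To make the roles of these three hyperparameters concrete, here’s a rough sketch of a single SGD update with momentum and weight decay; exact conventions vary slightly between frameworks and optimizer implementations.

[ ]:
# Generic sketch of one SGD step with momentum and weight decay (wd).
# Conventions differ slightly across frameworks; this is illustrative only.
def sgd_step(weight, grad, velocity, learning_rate, momentum, wd):
    grad = grad + wd * weight                    # wd penalizes large weights
    velocity = momentum * velocity + grad        # momentum reuses the previous direction
    weight = weight - learning_rate * velocity   # learning_rate scales the step size
    return weight, velocity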

In this case, we’ll set the hyperparameters to MXNet’s default values.

[ ]:
m = MXNet(
    "cifar10.py",
    role=role,
    train_instance_count=1,
    train_instance_type="ml.p3.8xlarge",
    framework_version="1.4.1",
    py_version="py3",
    hyperparameters={
        "batch_size": 1024,
        "epochs": 50,
        "learning_rate": 0.01,
        "momentum": 0.0,
        "wd": 0.0,
    },
)

Now that we’ve constructed our MXNet object, we can fit it using the data we uploaded to S3.

[ ]:
m.fit(inputs)

As we can see, our accuracy is only about 53% on the validation dataset. CIFAR-10 can be challenging, but we’d want accuracy much better than just over half if users are depending on accurate predictions.


Tune: Random

One method of hyperparameter tuning that performs surprisingly well given its simplicity is to randomly try a variety of hyperparameter values within set ranges. For this example, we’ve created a helper script, random_tuner.py, to do this.

We’ll need to supply:

  • A function that trains our MXNet model given a job name and a dictionary of hyperparameters. Note that ``wait`` is set to ``False`` in the fit() call so that we can train multiple jobs at once.

  • A dictionary of hyperparameters in which the ones we want to tune are defined as one of three types (ContinuousParameter, IntegerParameter, or CategoricalParameter), along with appropriate minimum and maximum ranges or a list of possible values.

[ ]:
def fit_random(job_name, hyperparameters):
    m = MXNet(
        "cifar10.py",
        role=role,
        train_instance_count=1,
        train_instance_type="ml.p3.8xlarge",
        framework_version="1.4.1",
        py_version="py3",
        hyperparameters=hyperparameters,
    )
    m.fit(inputs, wait=False, job_name=job_name)
[ ]:
hyperparameters = {
    "batch_size": 1024,
    "epochs": 50,
    "learning_rate": rt.ContinuousParameter(0.001, 0.5),
    "momentum": rt.ContinuousParameter(0.0, 0.99),
    "wd": rt.ContinuousParameter(0.0, 0.001),
}
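
To build intuition for what a single random draw from these ranges looks like, here’s a small hypothetical sketch; the actual sampling logic lives in random_tuner.py and may differ.

[ ]:
# Hypothetical sketch: draw one random candidate from (min, max) ranges.
# random_tuner.py's actual classes and sampling code may differ.
import random

ranges = {"learning_rate": (0.001, 0.5), "momentum": (0.0, 0.99), "wd": (0.0, 0.001)}
candidate = {name: random.uniform(lo, hi) for name, (lo, hi) in ranges.items()}
candidate.update({"batch_size": 1024, "epochs": 50})  # fixed values pass through unchanged
print(candidate)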

Next, we can kick off our random search. We’ve specified the total number of training jobs to be only 5 and the maximum number of parallel jobs to be 2. Below we have printed the results from a run that used 120 total training jobs with a maximum of 8 parallel jobs. That is a large amount of compute and could make this notebook cost as much as $400 to run. (The output shown will be overwritten when you run the tuning job with the new parameter values.) The larger settings also exceed the default concurrent instance limit for ml.p3.8xlarge instances; the smaller values control costs and allow you to complete the notebook without requiring a service limit increase.

Note that this step may take up to an hour to complete. Even if you lose connection with the notebook partway through, as long as the notebook instance continues to run, ``jobs`` should still be successfully created for future use.

[ ]:
%%time

jobs = rt.random_search(fit_random, hyperparameters, max_jobs=5, max_parallel_jobs=2)

Once our random search completes, we’ll want to compare our training jobs (which may take a few extra minutes to finish) in order to understand how our objective metric (% accuracy on our validation dataset) varies by hyperparameter values. In this case, our helper script includes two functions:

  • get_metrics() scrapes the CloudWatch logs for our training jobs and uses a regex to return any reported values of our objective metric.

  • table_metrics() joins on the hyperparameter values for each job, grabs the ending objective value, and converts the result to a Pandas DataFrame.
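
For the curious, scraping one job’s CloudWatch log stream with boto3 could look roughly like the sketch below; the actual get_metrics() implementation in random_tuner.py may differ.

[ ]:
# Hypothetical sketch of scraping one training job's CloudWatch logs for the metric.
# The actual get_metrics() helper may be implemented differently.
import re
import boto3

logs = boto3.client("logs")


def scrape_accuracy(job_name, regex="validation: accuracy=([0-9\\.]+)"):
    # Pagination via nextToken is omitted for brevity.
    events = logs.filter_log_events(
        logGroupName="/aws/sagemaker/TrainingJobs",  # log group for SageMaker training jobs
        logStreamNamePrefix=job_name,
    )
    values = []
    for event in events.get("events", []):
        match = re.search(regex, event["message"])
        if match:
            values.append(float(match.group(1)))
    return values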

[31]:
random_metrics = rt.table_metrics(jobs, rt.get_metrics(jobs, "validation: accuracy=([0-9\\.]+)"))
random_metrics.sort_values(["objective"], ascending=False)
[31]:
epochs wd batch_size objective learning_rate job_number momentum
random-hp-2018-07-08-20-06-10-189-17 50 0.000539 1024 0.736938 0.346114 17 0.231219
random-hp-2018-07-08-20-06-10-189-106 50 0.000658 1024 0.736572 0.203518 106 0.808102
random-hp-2018-07-08-20-06-10-189-78 50 0.000955 1024 0.735352 0.044036 78 0.962561
random-hp-2018-07-08-20-06-10-189-15 50 0.000304 1024 0.733887 0.187376 15 0.954231
random-hp-2018-07-08-20-06-10-189-16 50 0.000849 1024 0.733643 0.381012 16 0.049903
random-hp-2018-07-08-20-06-10-189-117 50 0.000708 1024 0.732544 0.314118 117 0.817854
random-hp-2018-07-08-20-06-10-189-70 50 0.000115 1024 0.732178 0.396326 70 0.510912
random-hp-2018-07-08-20-06-10-189-28 50 0.000593 1024 0.731689 0.398318 28 0.394819
random-hp-2018-07-08-20-06-10-189-2 50 0.000155 1024 0.731689 0.144946 2 0.924371
random-hp-2018-07-08-20-06-10-189-81 50 0.000810 1024 0.730103 0.351694 81 0.731069
random-hp-2018-07-08-20-06-10-189-104 50 0.000001 1024 0.729980 0.450279 104 0.100731
random-hp-2018-07-08-20-06-10-189-91 50 0.000324 1024 0.729980 0.289774 91 0.705745
random-hp-2018-07-08-20-06-10-189-80 50 0.000307 1024 0.729858 0.242544 80 0.544332
random-hp-2018-07-08-20-06-10-189-4 50 0.000478 1024 0.729736 0.158774 4 0.793798
random-hp-2018-07-08-20-06-10-189-20 50 0.000526 1024 0.729736 0.394191 20 0.826365
random-hp-2018-07-08-20-06-10-189-25 50 0.000636 1024 0.729126 0.293751 25 0.304069
random-hp-2018-07-08-20-06-10-189-40 50 0.000565 1024 0.728271 0.150810 40 0.795445
random-hp-2018-07-08-20-06-10-189-44 50 0.000305 1024 0.728271 0.222650 44 0.744344
random-hp-2018-07-08-20-06-10-189-77 50 0.000308 1024 0.727783 0.412024 77 0.049680
random-hp-2018-07-08-20-06-10-189-92 50 0.000830 1024 0.727295 0.401129 92 0.514144
random-hp-2018-07-08-20-06-10-189-101 50 0.000999 1024 0.726807 0.299126 101 0.545094
random-hp-2018-07-08-20-06-10-189-51 50 0.000371 1024 0.726074 0.413220 51 0.472064
random-hp-2018-07-08-20-06-10-189-10 50 0.000663 1024 0.726074 0.266806 10 0.671921
random-hp-2018-07-08-20-06-10-189-79 50 0.000176 1024 0.726074 0.204075 79 0.921536
random-hp-2018-07-08-20-06-10-189-66 50 0.000502 1024 0.724731 0.269668 66 0.951994
random-hp-2018-07-08-20-06-10-189-21 50 0.000166 1024 0.724487 0.262985 21 0.726774
random-hp-2018-07-08-20-06-10-189-75 50 0.000609 1024 0.724243 0.239251 75 0.336703
random-hp-2018-07-08-20-06-10-189-85 50 0.000403 1024 0.723267 0.140760 85 0.790066
random-hp-2018-07-08-20-06-10-189-105 50 0.000472 1024 0.723145 0.276357 105 0.608669
random-hp-2018-07-08-20-06-10-189-65 50 0.000160 1024 0.722900 0.147761 65 0.745701
... ... ... ... ... ... ... ...
random-hp-2018-07-08-20-06-10-189-96 50 0.000215 1024 0.676025 0.142288 96 0.153832
random-hp-2018-07-08-20-06-10-189-0 50 0.000354 1024 0.671997 0.030509 0 0.861382
random-hp-2018-07-08-20-06-10-189-112 50 0.000010 1024 0.671997 0.135736 112 0.232466
random-hp-2018-07-08-20-06-10-189-103 50 0.000895 1024 0.671265 0.094724 103 0.477036
random-hp-2018-07-08-20-06-10-189-27 50 0.000418 1024 0.670898 0.051749 27 0.619780
random-hp-2018-07-08-20-06-10-189-53 50 0.000646 1024 0.667236 0.064741 53 0.519927
random-hp-2018-07-08-20-06-10-189-39 50 0.000864 1024 0.665771 0.147716 39 0.013635
random-hp-2018-07-08-20-06-10-189-8 50 0.000274 1024 0.665649 0.093428 8 0.490986
random-hp-2018-07-08-20-06-10-189-116 50 0.000075 1024 0.660278 0.099005 116 0.364318
random-hp-2018-07-08-20-06-10-189-62 50 0.000040 1024 0.658936 0.114799 62 0.237339
random-hp-2018-07-08-20-06-10-189-54 50 0.000137 1024 0.658569 0.117171 54 0.121602
random-hp-2018-07-08-20-06-10-189-52 50 0.000580 1024 0.658447 0.074876 52 0.431588
random-hp-2018-07-08-20-06-10-189-69 50 0.000094 1024 0.656128 0.122641 69 0.118055
random-hp-2018-07-08-20-06-10-189-59 50 0.000820 1024 0.653442 0.116648 59 0.006298
random-hp-2018-07-08-20-06-10-189-48 50 0.000049 1024 0.652466 0.048692 48 0.516305
random-hp-2018-07-08-20-06-10-189-100 50 0.000995 1024 0.641968 0.473007 100 0.063985
random-hp-2018-07-08-20-06-10-189-64 50 0.000209 1024 0.640137 0.465702 64 0.928012
random-hp-2018-07-08-20-06-10-189-86 50 0.000641 1024 0.638550 0.077085 86 0.130483
random-hp-2018-07-08-20-06-10-189-82 50 0.000761 1024 0.638428 0.073119 82 0.193865
random-hp-2018-07-08-20-06-10-189-9 50 0.000321 1024 0.629272 0.049321 9 0.335158
random-hp-2018-07-08-20-06-10-189-58 50 0.000147 1024 0.617188 0.027019 58 0.584695
random-hp-2018-07-08-20-06-10-189-12 50 0.000417 1024 0.615845 0.045045 12 0.261383
random-hp-2018-07-08-20-06-10-189-7 50 0.000642 1024 0.606812 0.033248 7 0.465235
random-hp-2018-07-08-20-06-10-189-93 50 0.000527 1024 0.605591 0.043833 93 0.199631
random-hp-2018-07-08-20-06-10-189-95 50 0.000005 1024 0.568848 0.380469 95 0.948769
random-hp-2018-07-08-20-06-10-189-60 50 0.000769 1024 0.557129 0.018272 60 0.184786
random-hp-2018-07-08-20-06-10-189-84 50 0.000446 1024 0.530518 0.005893 84 0.299194
random-hp-2018-07-08-20-06-10-189-68 50 0.000088 1024 0.511475 0.309750 68 0.968020
random-hp-2018-07-08-20-06-10-189-83 50 0.000377 1024 0.233398 0.466878 83 0.152383
random-hp-2018-07-08-20-06-10-189-99 50 0.000059 1024 NaN 0.310720 99 0.817482

120 rows × 7 columns

As we can see, there’s a huge variation in percent accuracy. Had we initially (unknowingly) set our learning rate near 0.5, momentum at 0.15, and weight decay to 0.0004, we would have an accuracy just over 20% (this is particularly bad considering random guessing would produce 10% accuracy).

But we also found many successful hyperparameter value combinations, reaching a peak validation accuracy of nearly 74%. The peak occurred midway through the run, but could have occurred anywhere within the 120 jobs and will change across multiple runs. Either way, with hyperparameter tuning our accuracy is well above the 53% baseline from the default values.

To get a rough understanding of how the hyperparameter values relate to one another and the objective metric, let’s quickly plot them.

[32]:
pd.plotting.scatter_matrix(
    random_metrics[["objective", "learning_rate", "momentum", "wd", "job_number"]], figsize=(12, 12)
)
plt.show()
[Output: scatter matrix of objective, learning_rate, momentum, wd, and job_number for the random search jobs]

The hyperparameters’ correlations with one another and over time are essentially non-existent (which makes sense, because we selected their values randomly). However, in general, we notice:

  • Very low learning_rate values tend to do worse, although values that are too high seem to add variability.

  • momentum seems to have less impact, with potentially a non-linear sweet spot near 0.8.

  • wd has a less consistent impact on accuracy than the other two hyperparameters.


Tune: Automatic Model Tuning

Now, let’s try using Amazon SageMaker’s Automatic Model Tuning. Rather than selecting hyperparameter values randomly, SageMaker builds a second machine learning model which, based on previous hyperparameter and objective metric values, predicts new values that might yield an improvement. This should allow us to train better models, faster and cheaper.

We’ll use the tuner functionality built into the SageMaker Python SDK. Let’s start by defining a new MXNet estimator.

[ ]:
mt = MXNet(
    "cifar10.py",
    role=role,
    train_instance_count=1,
    train_instance_type="ml.p3.8xlarge",
    framework_version="1.4.1",
    py_version="py3",
    hyperparameters={"batch_size": 1024, "epochs": 50},
)

Now we can define our ranges (these take the same arguments as the classes from random_tuner).

[ ]:
hyperparameter_ranges = {
    "learning_rate": ContinuousParameter(0.001, 0.5),
    "momentum": ContinuousParameter(0.0, 0.99),
    "wd": ContinuousParameter(0.0, 0.001),
}

Now, we’ll define our objective metric and provide the regex needed to scrape it from our training jobs’ CloudWatch logs.

[ ]:
objective_metric_name = "Validation-accuracy"
metric_definitions = [{"Name": "Validation-accuracy", "Regex": "validation: accuracy=([0-9\\.]+)"}]
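
As a quick sanity check, the regex simply captures the number from log lines emitted by the training script. Assuming a line of roughly the following form (the exact prefix printed by cifar10.py may differ), it behaves like this:

[ ]:
# Quick check of the metric regex against a sample log line.
# The "[Epoch 49]" prefix is illustrative; the script's exact log format may differ.
import re

sample_line = "[Epoch 49] validation: accuracy=0.7399"
match = re.search("validation: accuracy=([0-9\\.]+)", sample_line)
print(float(match.group(1)))  # -> 0.7399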

Now we can create a HyperparameterTuner object and fit it by pointing to our data in S3. This kicks our tuning job off in the background.

Notice that we specify a much smaller number of total jobs and a smaller number of parallel jobs than we used for random search. Because the tuner uses previous training job results to predict which values to test next, it generally produces better results (although the tuning job takes longer) when the number of parallel jobs is kept small.

[ ]:
tuner = HyperparameterTuner(
    mt,
    objective_metric_name,
    hyperparameter_ranges,
    metric_definitions,
    max_jobs=5,
    max_parallel_jobs=2,
)
[ ]:
tuner.fit(inputs)

You will be unable to successfully run the following cells until the tuning job completes. This step may take up to 2 hours.
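
If you’d like to check on progress while you wait, one option is to query the tuning job’s status through the low-level client exposed by the SageMaker session, as in this simple sketch:

[ ]:
# Optional: poll the tuning job's status while it runs.
# Uses the low-level boto3 client exposed by the SageMaker session.
status = sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
    HyperParameterTuningJobName=tuner._current_job_name
)["HyperParameterTuningJobStatus"]
print(status)  # e.g. "InProgress" or "Completed"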

Once the tuning job finishes, we can bring in a table of metrics.

[13]:
bayes_metrics = sagemaker.HyperparameterTuningJobAnalytics(tuner._current_job_name).dataframe()
bayes_metrics.sort_values(["FinalObjectiveValue"], ascending=False)
[13]:
FinalObjectiveValue TrainingElapsedTimeSeconds TrainingEndTime TrainingJobName TrainingJobStatus TrainingStartTime learning_rate momentum wd
12 0.739868 449.0 2018-07-08 06:32:39+00:00 sagemaker-mxnet-180708-0457-018-0672e23c Completed 2018-07-08 06:25:10+00:00 0.218998 7.807712e-01 0.000463
8 0.735352 484.0 2018-07-08 06:53:18+00:00 sagemaker-mxnet-180708-0457-022-602fe997 Completed 2018-07-08 06:45:14+00:00 0.491950 2.074798e-02 0.000330
11 0.734741 527.0 2018-07-08 06:42:58+00:00 sagemaker-mxnet-180708-0457-019-510b97e9 Completed 2018-07-08 06:34:11+00:00 0.495178 1.530975e-02 0.000864
14 0.732422 484.0 2018-07-08 06:22:48+00:00 sagemaker-mxnet-180708-0457-016-3c3a8322 Completed 2018-07-08 06:14:44+00:00 0.394037 7.685146e-01 0.000663
7 0.731079 486.0 2018-07-08 07:03:57+00:00 sagemaker-mxnet-180708-0457-023-feac622d Completed 2018-07-08 06:55:51+00:00 0.443138 6.272739e-01 0.000891
0 0.730591 423.0 2018-07-08 07:35:16+00:00 sagemaker-mxnet-180708-0457-030-b96d5377 Completed 2018-07-08 07:28:13+00:00 0.095172 9.210969e-01 0.000408
1 0.729370 468.0 2018-07-08 07:35:32+00:00 sagemaker-mxnet-180708-0457-029-25c7583d Completed 2018-07-08 07:27:44+00:00 0.400276 5.187911e-02 0.000631
28 0.729248 485.0 2018-07-08 05:07:29+00:00 sagemaker-mxnet-180708-0457-002-65adfb1f Completed 2018-07-08 04:59:24+00:00 0.308497 8.221134e-01 0.000838
22 0.728027 471.0 2018-07-08 05:40:08+00:00 sagemaker-mxnet-180708-0457-008-a1ecd004 Completed 2018-07-08 05:32:17+00:00 0.202243 7.614848e-01 0.000937
15 0.727173 490.0 2018-07-08 06:19:35+00:00 sagemaker-mxnet-180708-0457-015-305bc5d4 Completed 2018-07-08 06:11:25+00:00 0.399027 7.784146e-01 0.000653
18 0.726196 427.0 2018-07-08 06:02:44+00:00 sagemaker-mxnet-180708-0457-012-7dad3bd5 Completed 2018-07-08 05:55:37+00:00 0.383464 1.314750e-02 0.000765
16 0.725098 436.0 2018-07-08 06:12:13+00:00 sagemaker-mxnet-180708-0457-014-031a27ff Completed 2018-07-08 06:04:57+00:00 0.499998 1.880845e-07 0.000024
26 0.723877 508.0 2018-07-08 05:18:42+00:00 sagemaker-mxnet-180708-0457-004-4cdcdee1 Completed 2018-07-08 05:10:14+00:00 0.485280 8.946790e-01 0.000822
24 0.720825 532.0 2018-07-08 05:29:59+00:00 sagemaker-mxnet-180708-0457-006-8a90d72d Completed 2018-07-08 05:21:07+00:00 0.453483 6.834462e-01 0.000155
10 0.719360 457.0 2018-07-08 06:42:14+00:00 sagemaker-mxnet-180708-0457-020-8b0ac13f Completed 2018-07-08 06:34:37+00:00 0.408935 1.748506e-01 0.000529
9 0.719238 426.0 2018-07-08 06:52:40+00:00 sagemaker-mxnet-180708-0457-021-9752b883 Completed 2018-07-08 06:45:34+00:00 0.323677 1.015553e-01 0.000299
13 0.717773 515.0 2018-07-08 06:30:54+00:00 sagemaker-mxnet-180708-0457-017-40ce6e41 Completed 2018-07-08 06:22:19+00:00 0.288515 8.637047e-01 0.000709
2 0.716675 497.0 2018-07-08 07:25:28+00:00 sagemaker-mxnet-180708-0457-028-0fd8fcda Completed 2018-07-08 07:17:11+00:00 0.210008 6.103743e-01 0.000493
3 0.715942 530.0 2018-07-08 07:25:56+00:00 sagemaker-mxnet-180708-0457-027-99e8ad93 Completed 2018-07-08 07:17:06+00:00 0.303875 8.004766e-01 0.000370
17 0.714722 440.0 2018-07-08 06:08:50+00:00 sagemaker-mxnet-180708-0457-013-9338d9e2 Completed 2018-07-08 06:01:30+00:00 0.449889 1.005885e-01 0.000002
23 0.712402 473.0 2018-07-08 05:39:19+00:00 sagemaker-mxnet-180708-0457-007-8c6d6369 Completed 2018-07-08 05:31:26+00:00 0.293129 0.000000e+00 0.000823
4 0.709473 425.0 2018-07-08 07:13:34+00:00 sagemaker-mxnet-180708-0457-026-d328bcb0 Completed 2018-07-08 07:06:29+00:00 0.347373 7.542925e-01 0.000085
25 0.708130 504.0 2018-07-08 05:29:14+00:00 sagemaker-mxnet-180708-0457-005-6644667e Completed 2018-07-08 05:20:50+00:00 0.472519 5.562326e-01 0.000195
20 0.704956 493.0 2018-07-08 05:50:56+00:00 sagemaker-mxnet-180708-0457-010-b43d406a Completed 2018-07-08 05:42:43+00:00 0.138131 6.934573e-01 0.000000
6 0.695679 468.0 2018-07-08 07:03:36+00:00 sagemaker-mxnet-180708-0457-024-a1a9261a Completed 2018-07-08 06:55:48+00:00 0.460894 6.460432e-01 0.000909
5 0.690430 516.0 2018-07-08 07:15:03+00:00 sagemaker-mxnet-180708-0457-025-bd0557c9 Completed 2018-07-08 07:06:27+00:00 0.392321 8.076493e-01 0.000025
29 0.687378 522.0 2018-07-08 05:08:11+00:00 sagemaker-mxnet-180708-0457-001-cb6d6b71 Completed 2018-07-08 04:59:29+00:00 0.094162 5.240358e-01 0.000551
21 0.658325 615.0 2018-07-08 05:52:55+00:00 sagemaker-mxnet-180708-0457-009-f12240b7 Completed 2018-07-08 05:42:40+00:00 0.285951 9.703842e-01 0.000000
27 0.544189 478.0 2018-07-08 05:17:47+00:00 sagemaker-mxnet-180708-0457-003-47aaa565 Completed 2018-07-08 05:09:49+00:00 0.007274 3.568501e-01 0.000816
19 0.417725 374.0 2018-07-08 05:59:31+00:00 sagemaker-mxnet-180708-0457-011-492642e3 Completed 2018-07-08 05:53:17+00:00 0.460787 9.900000e-01 0.000985

Looking at our results, we can see that, with one quarter of the total training jobs, SageMaker’s Automatic Model Tuning produced a model with better accuracy (74%) than our random search. In addition, there’s no guarantee that random search would be as effective on subsequent runs.

Let’s compare the hyperparameters’ relationships to each other and to the objective metric.

[15]:
pd.plotting.scatter_matrix(
    pd.concat(
        [
            bayes_metrics[["FinalObjectiveValue", "learning_rate", "momentum", "wd"]],
            bayes_metrics["TrainingStartTime"].rank(),
        ],
        axis=1,
    ),
    figsize=(12, 12),
)
plt.show()
[Output: scatter matrix of FinalObjectiveValue, learning_rate, momentum, wd, and training start order for the Automatic Model Tuning jobs]

We can see that:

  • There’s a range of reasonably good values for learning_rate, momentum, and wd, but there seem to be some very bad performers at both ends of the spectrum.

  • Later training jobs performed consistently better than early ones (SageMaker’s Automatic Model Tuning was learning and effectively exploring the space).

  • There appears to be somewhat less of a random relationship between the hyperparameter values our meta-model approach tested. This aligns with the knowledge that these hyperparameters are connected and that changing one can be offset by changing another.


Wrap-up

In this notebook, we saw the importance of hyperparameter tuning and discovered how much more effective Amazon SageMaker Automatic Model Tuning can be than random search. We could extend this example by testing another brute force method like grid search, tuning additional hyperparameters, using this first round of hyperparameter tuning to inform a secondary round of hyperparameter tuning where ranges have been narrowed down further, or applying this same comparison to your own problem.
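
As one example of such an extension, a simple grid search could reuse the fit_random() function from the random search section. The sketch below is illustrative only; launching all of these jobs would be expensive and subject to the same instance limits discussed earlier.

[ ]:
# Rough sketch of a grid search reusing fit_random() from the random-search section.
# This would launch 36 training jobs; illustrative only, mind the cost and instance limits.
import itertools

grid = {
    "learning_rate": [0.05, 0.1, 0.2, 0.4],
    "momentum": [0.0, 0.5, 0.9],
    "wd": [0.0, 0.0005, 0.001],
}

for i, (lr, mom, wd) in enumerate(itertools.product(*grid.values())):
    fit_random(
        job_name="grid-hp-{}".format(i),
        hyperparameters={
            "batch_size": 1024,
            "epochs": 50,
            "learning_rate": lr,
            "momentum": mom,
            "wd": wd,
        },
    )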

For more information on using SageMaker’s Automatic Model Tuning, see our other example notebooks and documentation.