Amazon SageMaker LightGBM Distributed training using Dask


1. Set Up

[ ]:
!pip install sagemaker ipywidgets --upgrade
[ ]:
import sagemaker, boto3, json
from sagemaker import get_execution_role

aws_role = get_execution_role()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

bucket = sess.default_bucket()
prefix = "sagemaker/DEMO-churn-dt"
[ ]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import io
import os
import sys
import time
import json
from IPython.display import display
from time import strftime, gmtime
from sagemaker.inputs import TrainingInput
from sagemaker.serializers import CSVSerializer
from sklearn import preprocessing

2. Data Preparation and Visualization

Mobile operators have historical records on which customers ultimately ended up churning and which continued using the service. We can use this historical information to construct an ML model of one mobile operator’s churn using a process called training. After training the model, we can pass the profile information of an arbitrary customer (the same profile information that we used to train the model) to the model, and have the model predict whether this customer is going to churn. Of course, we expect the model to make mistakes. After all, predicting the future is tricky business! But we’ll learn how to deal with prediction errors.

The dataset we use is publicly available and was mentioned in the book Discovering Knowledge in Data by Daniel T. Larose. It is attributed by the author to the University of California Irvine Repository of Machine Learning Datasets. Let’s download and read that dataset in now:

[ ]:
s3 = boto3.client("s3")
[ ]:
churn = pd.read_csv("./churn.txt")
pd.set_option("display.max_columns", 500)

By modern standards, it’s a relatively small dataset, with only 5,000 records, where each record uses 21 attributes to describe the profile of a customer of an unknown US mobile operator. The attributes are:

State: the US state in which the customer resides, indicated by a two-letter abbreviation; for example, OH or NJ

Account Length: the number of days that this account has been active

Area Code: the three-digit area code of the corresponding customer’s phone number

Phone: the remaining seven-digit phone number

Int’l Plan: whether the customer has an international calling plan: yes/no

VMail Plan: whether the customer has a voice mail feature: yes/no

VMail Message: the average number of voice mail messages per month

Day Mins: the total number of calling minutes used during the day

Day Calls: the total number of calls placed during the day

Day Charge: the billed cost of daytime calls

Eve Mins, Eve Calls, Eve Charge: the minutes used, number of calls, and billed cost for calls placed during the evening

Night Mins, Night Calls, Night Charge: the minutes used, number of calls, and billed cost for calls placed during nighttime

Intl Mins, Intl Calls, Intl Charge: the minutes used, number of calls, and billed cost for international calls

CustServ Calls: the number of calls placed to Customer Service

Churn?: whether the customer left the service: true/false

The last attribute, Churn?, is known as the target attribute: the attribute that we want the ML model to predict. Because the target attribute is binary, our model will be performing binary prediction, also known as binary classification.

Let’s begin exploring the data:

[ ]:
# Histograms for each numeric feature
%matplotlib inline
hist = churn.hist(bins=30, sharey=True, figsize=(10, 10))

We can see immediately that:

  • State appears to be quite evenly distributed.

  • Phone takes on too many unique values to be of any practical use. It’s possible that parsing out the prefix could have some value, but without more context on how these are allocated, we should avoid using it.

  • Most of the numeric features are surprisingly nicely distributed, with many showing bell-like gaussianity. VMail Message is a notable exception.

[ ]:
churn = churn.drop("Phone", axis=1)
churn["Area Code"] = churn["Area Code"].astype(object)

Next let’s look at the relationship between each of the features and our target variable.

[ ]:
for column in churn.select_dtypes(include=["object"]).columns:
    if column != "Churn?":
        display(pd.crosstab(index=churn[column], columns=churn["Churn?"], normalize="columns"))

for column in churn.select_dtypes(exclude=["object"]).columns:
    hist = churn[[column, "Churn?"]].hist(by="Churn?", bins=30)

We convert the target attribute to a binary value and move it to the first column of the dataset to meet the requirements of the SageMaker built-in tabular algorithms (for an example, see the SageMaker LightGBM documentation).

[ ]:
churn["target"] = churn["Churn?"].map({"True.": 1, "False.": 0})
churn.drop(["Churn?"], axis=1, inplace=True)
[ ]:
churn = churn[["target"] + churn.columns.tolist()[:-1]]

We identify the column indexes of the categorical attributes, which are required by the LightGBM, CatBoost, and TabTransformer algorithms (AutoGluon-Tabular has built-in feature engineering that identifies categorical attributes automatically, and thus does not require such input).

[ ]:
cat_columns = [
    "Account Length",
    "Area Code",
    "Int'l Plan",
    "VMail Plan",
    "VMail Message",
    "Day Calls",
    "Eve Calls",
    "Night Calls",
    "Intl Calls",
    "CustServ Calls",
]

# Record the position of every categorical column (index 0 is the target)
cat_idx = []
for idx, col_name in enumerate(churn.columns.tolist()):
    if col_name in cat_columns:
        cat_idx.append(idx)
[ ]:
with open("cat_idx.json", "w") as outfile:
    json.dump({"cat_idx": cat_idx}, outfile)
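
To make the index convention concrete, here is a minimal sketch on a toy column list. The column subset and the choice of categorical columns are hypothetical and only illustrate that, with the target in position 0, the first feature gets index 1.

```python
import json

# Hypothetical column layout: the target sits at index 0, features follow.
columns = ["target", "State", "Account Length", "Day Mins", "Int'l Plan"]
# Hypothetical subset of columns treated as categorical.
categorical = {"Account Length", "Int'l Plan"}

# enumerate() counts from 0, so the first feature column gets index 1.
cat_idx = [idx for idx, name in enumerate(columns) if name in categorical]

# Same dictionary shape that the notebook writes to cat_idx.json.
payload = json.dumps({"cat_idx": cat_idx})
print(payload)  # {"cat_idx": [2, 4]}
```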

The official LightGBM documentation requires all categorical features to be encoded as non-negative integers.

[ ]:
for idx, col_name in enumerate(churn.columns.tolist()):
    if col_name in cat_columns:
        le = preprocessing.LabelEncoder()
        churn[col_name] = le.fit_transform(churn[col_name])
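
As a rough illustration of what this encoding step produces, the sketch below mimics the behavior in plain Python: each distinct value is mapped to its rank in sorted order. The helper `label_encode` is a hypothetical stand-in written for this example, not part of scikit-learn.

```python
def label_encode(values):
    """Map each distinct value to its rank in sorted order (0, 1, 2, ...)."""
    classes = sorted(set(values))
    lookup = {value: code for code, value in enumerate(classes)}
    return [lookup[value] for value in values], classes


codes, classes = label_encode(["yes", "no", "no", "yes"])
print(codes)    # [1, 0, 0, 1]
print(classes)  # ['no', 'yes']
```

Because the codes start at 0 and increase by one per distinct value, the result satisfies the non-negative integer requirement.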

We split the churn dataset into train, validation, and test sets using stratified sampling. The validation set is used for early stopping and Automatic Model Tuning (AMT). The test set is used for the final performance evaluation. Next, we upload them to an S3 path for training.

The S3 path for training should be structured as shown below.

  • The supported input data format for training is csv. You can place more than one data file under both the train and validation channels. A data file can have any name as long as it ends with .csv.

  • The first column corresponds to the target and the remaining columns correspond to features. This follows the convention of the SageMaker XGBoost algorithm.

  • The cat_idx.json file lists the categorical column indexes. It contains a dictionary with a single key-value pair. The key can be any string. The value is the list of column indexes of the categorical features. Indexing starts at 1 because index 0 corresponds to the target variable. See the example above for how to format cat_idx.json.

  • For the validation data, we encourage you to include a single data file under its channel so that all of the validation data points are assigned to one machine. The validation score then covers all of the validation data points and can be easily parsed by AMT for hyperparameter optimization.

  • Current distributed training only supports CPU.

train
    – data_1.csv
    – data_2.csv
    – data_3.csv
    – cat_idx.json

validation
    – data.csv
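
The layout above can be sketched as a list of object keys per channel. This is only an illustration: `channel_keys` is a hypothetical helper, and the file names simply mirror the tree shown above.

```python
import posixpath


def channel_keys(prefix, num_train_files):
    """Build the object keys for the train and validation channels."""
    train = [
        posixpath.join(prefix, "train", f"data_{i}.csv")
        for i in range(1, num_train_files + 1)
    ]
    # The categorical index file lives next to the training data.
    train.append(posixpath.join(prefix, "train", "cat_idx.json"))
    # A single validation file keeps all validation rows on one machine.
    validation = [posixpath.join(prefix, "validation", "data.csv")]
    return {"train": train, "validation": validation}


keys = channel_keys("sagemaker/DEMO-churn-dt", 3)
for channel, paths in keys.items():
    print(channel, paths)
```

`posixpath.join` is used so that the separator is always a forward slash, which is what S3 keys expect regardless of the local operating system.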

[ ]:
from sklearn.model_selection import train_test_split

train, val_n_test = train_test_split(
    churn, test_size=0.3, random_state=42, stratify=churn["target"]
)
[ ]:
val, test = train_test_split(
    val_n_test, test_size=0.3, random_state=42, stratify=val_n_test["target"]
)
[ ]:
train.to_csv("train.csv", header=False, index=False)
val.to_csv("validation.csv", header=False, index=False)
test.to_csv("test.csv", header=False, index=False)
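
The two chained splits leave roughly 70% of the rows for training, 21% for validation, and 9% for testing. The arithmetic can be sketched as follows, assuming the 5,000 records mentioned earlier and assuming the held-out count of each split is rounded up; `split_sizes` is a hypothetical helper for this example.

```python
import math
from fractions import Fraction


def split_sizes(n_rows, holdout=Fraction(3, 10), test_share=Fraction(3, 10)):
    """Return (train, validation, test) row counts for two chained splits."""
    n_holdout = math.ceil(n_rows * holdout)     # rows set aside by the first split
    n_test = math.ceil(n_holdout * test_share)  # rows the second split sends to test
    return n_rows - n_holdout, n_holdout - n_test, n_test


print(split_sizes(5000))  # (3500, 1050, 450)
```

Exact fractions are used instead of floats so that the rounding step is not affected by floating-point error.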

For demonstration purposes, to show multiple files under the training channel, we simply duplicate the training data multiple times as shown below.

[ ]:
from tqdm import tqdm

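# Upload the same training file 200 times to simulate a multi-file training channel
for i in tqdm(range(200)):
    s3.upload_file("train.csv", bucket, os.path.join(prefix, f"train/data_{i}.csv"))
[ ]:
s3.upload_file("validation.csv", bucket, os.path.join(prefix, "validation/data.csv"))
[ ]:
s3.upload_file("test.csv", bucket, os.path.join(prefix, "test/data.csv"))
[ ]:
s3.upload_file("cat_idx.json", bucket, os.path.join(prefix, "train/cat_idx.json"))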

3. Distributed Training of a SageMaker LightGBM Model with AMT

3.1. Retrieve Training Artifacts

Here, we retrieve the training docker container, the training algorithm source, and the tabular algorithm. Note that model_version="*" fetches the latest model.

For the training algorithm, this demonstration uses LightGBM. For a classification task, specify train_model_id as lightgbm-classification-model in the cell below.

For a regression task, the train_model_id is lightgbm-regression-model.

[ ]:
from sagemaker import image_uris, model_uris, script_uris

train_model_id, train_model_version, train_scope = "lightgbm-classification-model", "*", "training"
training_instance_type = "ml.m5.4xlarge"

# Retrieve the docker image
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    model_id=train_model_id,
    model_version=train_model_version,
    image_scope=train_scope,
    instance_type=training_instance_type,
)
# Retrieve the training script
train_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope=train_scope
)
# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = model_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, model_scope=train_scope
)

3.2. Set Training Parameters

Now that we are done with all the setup that is needed, we are ready to train our tabular algorithm. To begin, let us create a sagemaker.estimator.Estimator object. This estimator will launch the training job.

There are two kinds of parameters that need to be set for training. The first kind are the parameters for the training job. These include: (i) Training data path: the S3 folder in which the input data is stored. (ii) Output path: the S3 folder in which the training output is stored. (iii) Training instance type: the type of machine on which to run the training.

The second kind are the algorithm-specific training hyperparameters.

[ ]:
training_dataset_s3_path = f"s3://{bucket}/{prefix}/train"
validation_dataset_s3_path = f"s3://{bucket}/{prefix}/validation"

output_prefix = "jumpstart-example-tabular-training"
s3_output_location = f"s3://{bucket}/{output_prefix}/output_lgb"
[ ]:
from sagemaker import hyperparameters

# Retrieve the default hyper-parameters for fine-tuning the model
hyperparameters = hyperparameters.retrieve_default(
    model_id=train_model_id, model_version=train_model_version
)

# [Optional] Override default hyperparameters with custom values
hyperparameters["num_boost_round"] = "200"

hyperparameters["metric"] = "auc"
hyperparameters["tree_learner"] = "voting"  # use AllReduce method for distributed training

# Current distributed training with early stopping has some issues,
# thus it is disabled for distributed training.
del hyperparameters["early_stopping_rounds"]

3.3. Train with Automatic Model Tuning

Amazon SageMaker automatic model tuning, also known as hyperparameter tuning, finds the best version of a model by running many training jobs on your dataset using the algorithm and ranges of hyperparameters that you specify. It then chooses the hyperparameter values that result in a model that performs the best, as measured by a metric that you choose. We will use a HyperparameterTuner object to interact with Amazon SageMaker hyperparameter tuning APIs.
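
The metric that the tuner optimizes is read from the training logs with a regular expression. A minimal sketch of that extraction follows, using the same pattern that this notebook passes to the tuner as its metric definition. The log line is a hypothetical example, and the real training log format may differ.

```python
import re

# Same pattern as the metric definition used for the tuner in this notebook.
metric_regex = "auc: ([0-9\\.]+)"

# Hypothetical log line; the actual output of the training container may differ.
log_line = "[50]\tvalid_0's auc: 0.9123"

# The first capture group holds the numeric value of the metric.
match = re.search(metric_regex, log_line)
auc = float(match.group(1)) if match else None
print(auc)  # 0.9123
```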

  • Note: In this notebook, we set the AMT budget (the total number of tuning jobs) to 10 for each tabular algorithm except AutoGluon-Tabular, which instead succeeds by ensembling multiple models and stacking them in multiple layers.
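
The learning-rate range in the next cell uses logarithmic scaling. The sketch below only illustrates what searching a range on a log scale means, by sampling log-uniformly between the same bounds. It is not how AMT selects candidates, since its search strategy is more involved, and `sample_log_uniform` is a hypothetical helper.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniform between log(low) and log(high)."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_log_uniform(1e-4, 1.0, rng) for _ in range(10_000)]

# Each decade (1e-4..1e-3, 1e-3..1e-2, ...) receives roughly the same share,
# so about a quarter of the draws fall below 1e-3.
share_below_1e3 = sum(s < 1e-3 for s in samples) / len(samples)
print(round(share_below_1e3, 2))
```

With linear scaling, by contrast, only about 0.1% of the draws would fall below 1e-3, so small learning rates would rarely be tried.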

[ ]:
from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner

use_amt = True

hyperparameter_ranges_lgb = {
    "learning_rate": ContinuousParameter(1e-4, 1, scaling_type="Logarithmic"),
    "num_boost_round": IntegerParameter(2, 30),
    "num_leaves": IntegerParameter(10, 50),
    "feature_fraction": ContinuousParameter(0.1, 1),
    "bagging_fraction": ContinuousParameter(0.1, 1),
    "bagging_freq": IntegerParameter(1, 10),
    "max_depth": IntegerParameter(5, 30),
    "min_data_in_leaf": IntegerParameter(5, 50),
}

3.4. Start Training

[ ]:
from sagemaker.estimator import Estimator
import random

training_job_name = "jumpstart-distr-lgb-g" + str(int(time.time()))

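# Create SageMaker Estimator instance
tabular_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",
    instance_count=4,  ### select the instance count you would like to use for distributed training
    instance_type=training_instance_type,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    volume_size=30,  ### volume_size (int or PipelineVariable): Size in GB of the storage volume to use for storing input and output data during training (default: 30).
)

if use_amt:
    tuner = HyperparameterTuner(
        tabular_estimator,
        "auc",
        hyperparameter_ranges_lgb,
        [{"Name": "auc", "Regex": "auc: ([0-9\\.]+)"}],
        max_jobs=10,
        objective_type="Maximize",
        base_tuning_job_name=training_job_name,
    )
    tuner.fit(
        {
            "train": training_dataset_s3_path,
            "validation": validation_dataset_s3_path,
        }
    )
else:
    # Launch a SageMaker Training job by passing s3 path of the training data
    tabular_estimator.fit(
        {
            "train": training_dataset_s3_path,
            "validation": validation_dataset_s3_path,
        },
        job_name=training_job_name,
    )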

3.5. Deploy and Run Inference on the Trained Tabular Model

In this section, you learn how to query an existing endpoint and make predictions for the examples you input. For each example, the model outputs a probability for each class. The predicted class label is then the class with the highest probability.

We start by retrieving the inference artifacts and deploying the tabular_estimator that we trained.

[ ]:
inference_instance_type = "ml.m5.4xlarge"

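# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    image_scope="inference",
    model_id=train_model_id,
    model_version=train_model_version,
    instance_type=inference_instance_type,
)
# Retrieve the inference script uri
deploy_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope="inference"
)

endpoint_name = "jumpstart-distr-lgb-g" + str(int(time.time()))

# Use the estimator from the previous step to deploy to a SageMaker endpoint
predictor = (tuner if use_amt else tabular_estimator).deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    entry_point="inference.py",
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    endpoint_name=endpoint_name,
)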
[ ]:
newline, bold, unbold = "\n", "\033[1m", "\033[0m"

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

# read the data
test_data_file_name = "test.csv"
test_data = pd.read_csv(test_data_file_name, header=None)
test_data.columns = ["Target"] + [f"Feature_{i}" for i in range(1, test_data.shape[1])]

num_examples, num_columns = test_data.shape
print(
    f"{bold}The test dataset contains {num_examples} examples and {num_columns} columns.{unbold}\n"
)

# prepare the ground truth target and predicting features to send into the endpoint.
ground_truth_label, features = test_data.iloc[:, :1], test_data.iloc[:, 1:]

print(f"{bold}The first 5 observations of the data: {unbold} \n")
features.head(5)
[ ]:
content_type = "text/csv"

def query_endpoint(encoded_tabular_data, endpoint_name):
    # Send one CSV-encoded batch to the endpoint and return the raw response
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType=content_type,
        Body=encoded_tabular_data,
    )
    return response

def parse_response(query_response):
    model_predictions = json.loads(query_response["Body"].read())
    predicted_probabilities = model_predictions["probabilities"]
    return np.array(predicted_probabilities)

# split the test data into smaller size of batches to query the endpoint if test data has large size.
batch_size = 1500
predict_prob = []
for i in np.arange(0, num_examples, step=batch_size):
    query_response_batch = query_endpoint(
        features.iloc[i : (i + batch_size), :].to_csv(header=False, index=False).encode("utf-8"),
        endpoint_name,
    )
    predict_prob_batch = parse_response(query_response_batch)  # prediction probability per batch
    predict_prob.append(predict_prob_batch)

predict_prob = np.concatenate(predict_prob, axis=0)
predict_label = np.argmax(predict_prob, axis=1)
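
The parsing and label selection above can be tried without an endpoint. The sketch below uses a hypothetical response body with the same "probabilities" key and picks, for each row, the index of the largest probability.

```python
import json

# Hypothetical response body with one probability row per input example.
body = json.dumps({"probabilities": [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]})

probabilities = json.loads(body)["probabilities"]

# Index of the largest probability in each row; ties go to the first index.
labels = [max(range(len(row)), key=row.__getitem__) for row in probabilities]
print(labels)  # [0, 1, 0]
```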

3.6. Evaluate the Prediction Results Returned from the Endpoint

[ ]:
# Visualize the predictions results by plotting the confusion matrix.
conf_matrix = confusion_matrix(y_true=ground_truth_label.values, y_pred=predict_label)
fig, ax = plt.subplots(figsize=(7.5, 7.5))
ax.matshow(conf_matrix, cmap=plt.cm.Blues, alpha=0.3)
for i in range(conf_matrix.shape[0]):
    for j in range(conf_matrix.shape[1]):
        ax.text(x=j, y=i, s=conf_matrix[i, j], va="center", ha="center", size="xx-large")

plt.xlabel("Predictions", fontsize=18)
plt.ylabel("Actuals", fontsize=18)
plt.title("Confusion Matrix", fontsize=18)
[ ]:
# Measure the prediction results quantitatively.
eval_accuracy = accuracy_score(ground_truth_label.values, predict_label)
eval_f1 = f1_score(ground_truth_label.values, predict_label)
eval_auc = roc_auc_score(ground_truth_label.values, predict_prob[:, 1])

lgb_results = pd.DataFrame.from_dict(
    {
        "Accuracy": eval_accuracy,
        "F1": eval_f1,
        "AUC": eval_auc,
    },
    orient="index",
    columns=["LightGBM with AMT"],
)

lgb_results
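
To show how accuracy and F1 relate to the cells of the confusion matrix, here is a minimal sketch that computes both by hand for binary labels. The labels are made-up toy values and `binary_metrics` is a hypothetical helper, not a replacement for the scikit-learn functions used above.

```python
def binary_metrics(y_true, y_pred):
    """Compute accuracy and F1 from the four confusion-matrix counts."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in pairs)  # churners correctly flagged
    fp = sum(t == 0 and p == 1 for t, p in pairs)  # false alarms
    fn = sum(t == 1 and p == 0 for t, p in pairs)  # missed churners
    tn = sum(t == 0 and p == 0 for t, p in pairs)  # non-churners correctly passed
    accuracy = (tp + tn) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "f1": f1}


# Toy labels, not taken from the churn dataset.
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 0]
metrics = binary_metrics(y_true, y_pred)
print(metrics)  # {'accuracy': 0.75, 'f1': 0.75}
```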

[ ]:
# Delete the SageMaker endpoint and the attached resources
predictor.delete_model()
predictor.delete_endpoint()
