Customer Churn Prediction with XGBoost

*Using Gradient Boosted Trees to Predict Mobile Customer Departure*



Contents

  1. Background

  2. Setup

  3. Data

  4. Train

  5. Host

  6. Evaluate

  7. Relative cost of errors

  8. Extensions


Background

This notebook has been adapted from an `AWS blog post <https://aws.amazon.com/blogs/ai/predicting-customer-churn-with-amazon-machine-learning/>`__

Losing customers is costly for any business. Identifying unhappy customers early on gives you a chance to offer them incentives to stay. This notebook describes using machine learning (ML) for the automated identification of unhappy customers, also known as customer churn prediction. ML models rarely give perfect predictions though, so this notebook is also about how to incorporate the relative costs of prediction mistakes when determining the financial outcome of using ML.

We use an example of churn that is familiar to all of us–leaving a mobile phone operator. Seems like I can always find fault with my provider du jour! And if my provider knows that I’m thinking of leaving, it can offer timely incentives–I can always use a phone upgrade or perhaps have a new feature activated–and I might just stick around. Incentives are often much more cost effective than losing and reacquiring a customer.


Setup

This notebook was created and tested on an ml.m4.xlarge notebook instance.

Let’s start by specifying:

  • The S3 bucket and prefix that you want to use for training and model data. This should be within the same region as the Notebook Instance, training, and hosting.

  • The IAM role ARN used to give training and hosting access to your data. See the documentation for how to create these. Note: if more than one role is required for notebook instances, training, and/or hosting, please replace the get_execution_role() call with the appropriate full IAM role ARN string(s).

[2]:
import sagemaker

sess = sagemaker.Session()
bucket = sess.default_bucket()
prefix = "sagemaker/DEMO-xgboost-churn"

# Define IAM role
import boto3
import re
from sagemaker import get_execution_role

role = get_execution_role()

Next, we’ll import the Python libraries we’ll need for the remainder of the exercise.

[3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import io
import os
import sys
import time
import json
from IPython.display import display
from time import strftime, gmtime
from sagemaker.inputs import TrainingInput
from sagemaker.serializers import CSVSerializer

Data

Mobile operators have historical records on which customers ultimately ended up churning and which continued using the service. We can use this historical information to construct an ML model of one mobile operator’s churn using a process called training. After training the model, we can pass the profile information of an arbitrary customer (the same profile information that we used to train the model) to the model, and have the model predict whether this customer is going to churn. Of course, we expect the model to make mistakes–after all, predicting the future is tricky business! But I’ll also show how to deal with prediction errors.

The dataset we use is publicly available and was mentioned in the book Discovering Knowledge in Data by Daniel T. Larose. It is attributed by the author to the University of California Irvine Repository of Machine Learning Datasets. Let’s download and read that dataset in now:

[4]:
!aws s3 cp s3://sagemaker-sample-files/datasets/tabular/synthetic/churn.txt ./
download: s3://sagemaker-sample-files/datasets/tabular/synthetic/churn.txt to ./churn.txt
[5]:
churn = pd.read_csv("./churn.txt")
pd.set_option("display.max_columns", 500)
churn
[5]:
State Account Length Area Code Phone Int'l Plan VMail Plan VMail Message Day Mins Day Calls Day Charge Eve Mins Eve Calls Eve Charge Night Mins Night Calls Night Charge Intl Mins Intl Calls Intl Charge CustServ Calls Churn?
0 PA 163 806 403-2562 no yes 300 8.162204 3 7.579174 3.933035 4 6.508639 4.065759 100 5.111624 4.928160 6 5.673203 3 True.
1 SC 15 836 158-8416 yes no 0 10.018993 4 4.226289 2.325005 0 9.972592 7.141040 200 6.436188 3.221748 6 2.559749 8 False.
2 MO 131 777 896-6253 no yes 300 4.708490 3 4.768160 4.537466 3 4.566715 5.363235 100 5.142451 7.139023 2 6.254157 4 False.
3 WY 75 878 817-5729 yes yes 700 1.268734 3 2.567642 2.528748 5 2.333624 3.773586 450 3.814413 2.245779 6 1.080692 6 False.
4 WY 146 878 450-4942 yes no 0 2.696177 3 5.908916 6.015337 3 3.670408 3.751673 250 2.796812 6.905545 4 7.134343 6 True.
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
4995 NH 4 787 151-3162 yes yes 800 10.862632 5 7.250969 6.936164 1 8.026482 4.921314 350 6.748489 4.872570 8 2.122530 9 False.
4996 SD 140 836 351-5993 no no 0 1.581127 8 3.758307 7.377591 7 1.328827 0.939932 300 4.522661 6.938571 2 4.600473 4 False.
4997 SC 32 836 370-3127 no yes 700 0.163836 5 4.243980 5.841852 3 2.340554 0.939469 450 5.157898 4.388328 7 1.060340 6 False.
4998 MA 142 776 604-2108 yes yes 600 2.034454 5 3.014859 4.140554 3 3.470372 6.076043 150 4.362780 7.173376 3 4.871900 7 True.
4999 AL 141 657 294-2849 yes yes 500 1.803907 0 5.125716 8.357508 0 2.109823 2.624299 400 3.713631 5.798783 6 5.485345 7 False.

5000 rows × 21 columns


By modern standards, it’s a relatively small dataset, with only 5,000 records, where each record uses 21 attributes to describe the profile of a customer of an unknown US mobile operator. The attributes are:

  • State: the US state in which the customer resides, indicated by a two-letter abbreviation; for example, OH or NJ

  • Account Length: the number of days that this account has been active

  • Area Code: the three-digit area code of the corresponding customer’s phone number

  • Phone: the remaining seven-digit phone number

  • Int’l Plan: whether the customer has an international calling plan: yes/no

  • VMail Plan: whether the customer has a voice mail feature: yes/no

  • VMail Message: presumably the average number of voice mail messages per month

  • Day Mins: the total number of calling minutes used during the day

  • Day Calls: the total number of calls placed during the day

  • Day Charge: the billed cost of daytime calls

  • Eve Mins, Eve Calls, Eve Charge: the minutes, number of calls, and billed cost for calls placed during the evening

  • Night Mins, Night Calls, Night Charge: the minutes, number of calls, and billed cost for calls placed during the night

  • Intl Mins, Intl Calls, Intl Charge: the minutes, number of calls, and billed cost for international calls

  • CustServ Calls: the number of calls placed to Customer Service

  • Churn?: whether the customer left the service: true/false

The last attribute, Churn?, is known as the target attribute–the attribute that we want the ML model to predict. Because the target attribute is binary, our model will be performing binary prediction, also known as binary classification.
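
As a quick sanity check on the class balance of that target (our addition, assuming the churn DataFrame loaded above), we can count the two target values directly:

[ ]:
# Sketch: check the balance of the binary target before modeling.
display(churn["Churn?"].value_counts())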

Let’s begin exploring the data:

[6]:
# Frequency tables for each categorical feature
for column in churn.select_dtypes(include=["object"]).columns:
    display(pd.crosstab(index=churn[column], columns="% observations", normalize="columns"))

# Histograms for each numeric feature
display(churn.describe())
%matplotlib inline
hist = churn.hist(bins=30, sharey=True, figsize=(10, 10))
col_0 % observations
State
AK 0.0170
AL 0.0200
AR 0.0220
AZ 0.0180
CA 0.0208
CO 0.0182
CT 0.0178
DC 0.0224
DE 0.0182
FL 0.0178
GA 0.0166
HI 0.0190
IA 0.0206
ID 0.0222
IL 0.0198
IN 0.0190
KS 0.0158
KY 0.0182
LA 0.0202
MA 0.0208
MD 0.0226
ME 0.0148
MI 0.0202
MN 0.0220
MO 0.0212
MS 0.0212
MT 0.0180
NC 0.0190
ND 0.0160
NE 0.0218
NH 0.0188
NJ 0.0202
NM 0.0166
NV 0.0198
NY 0.0196
OH 0.0222
OK 0.0186
OR 0.0204
PA 0.0198
RI 0.0240
SC 0.0226
SD 0.0204
TN 0.0194
TX 0.0196
UT 0.0190
VA 0.0198
VT 0.0194
WA 0.0202
WI 0.0170
WV 0.0208
WY 0.0206
col_0 % observations
Phone
100-2030 0.0002
100-2118 0.0002
100-3505 0.0002
100-5224 0.0002
101-3371 0.0002
... ...
999-3178 0.0002
999-5498 0.0002
999-5816 0.0002
999-8494 0.0002
999-9817 0.0002

4999 rows × 1 columns

col_0 % observations
Int'l Plan
no 0.5014
yes 0.4986
col_0 % observations
VMail Plan
no 0.4976
yes 0.5024
col_0 % observations
Churn?
False. 0.5004
True. 0.4996
Account Length Area Code VMail Message Day Mins Day Calls Day Charge Eve Mins Eve Calls Eve Charge Night Mins Night Calls Night Charge Intl Mins Intl Calls Intl Charge CustServ Calls
count 5000.000000 5000.000000 5000.000000 5000.000000 5000.00000 5000.000000 5000.000000 5000.000000 5000.000000 5000.000000 5000.000000 5000.000000 5000.000000 5000.000000 5000.000000 5000.000000
mean 101.675800 773.791400 226.680000 5.518757 3.50460 5.018902 5.026199 3.140400 5.017557 4.000917 224.790000 5.023490 5.025876 5.475400 4.328242 5.525800
std 57.596762 63.470888 273.998527 3.433485 1.68812 2.195759 2.135487 2.525621 2.127857 1.631001 97.302875 1.748900 1.019302 1.877045 2.440311 2.041217
min 1.000000 657.000000 0.000000 0.000215 0.00000 0.004777 0.004659 0.000000 0.013573 0.008468 0.000000 0.054863 1.648514 0.000000 0.000769 0.000000
25% 52.000000 736.000000 0.000000 2.682384 2.00000 3.470151 3.588466 1.000000 3.529613 2.921998 150.000000 3.873157 4.349726 4.000000 2.468225 4.000000
50% 102.000000 778.000000 0.000000 5.336245 3.00000 4.988291 5.145656 3.000000 5.006860 3.962089 200.000000 5.169154 5.034905 5.000000 4.214058 6.000000
75% 151.000000 806.000000 400.000000 7.936151 5.00000 6.559750 6.552962 5.000000 6.491725 5.100128 300.000000 6.272015 5.716386 7.000000 5.960654 7.000000
max 200.000000 878.000000 1300.000000 16.897529 10.00000 12.731936 13.622097 14.000000 12.352871 10.183378 550.000000 10.407778 8.405644 12.000000 14.212261 13.000000
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_9_6.png

We can see immediately that:

  • State appears to be quite evenly distributed.

  • Phone takes on too many unique values to be of any practical use. It’s possible parsing out the prefix could have some value, but without more context on how these are allocated, we should avoid using it.

  • Most of the numeric features are surprisingly nicely distributed, with many showing bell-like gaussianity. VMail Message is a notable exception (and Area Code shows up as a numeric feature we should convert to non-numeric).

[7]:
# Drop the high-cardinality Phone feature and treat Area Code as categorical
churn = churn.drop("Phone", axis=1)
churn["Area Code"] = churn["Area Code"].astype(object)

Next let’s look at the relationship between each of the features and our target variable.

[8]:
for column in churn.select_dtypes(include=["object"]).columns:
    if column != "Churn?":
        display(pd.crosstab(index=churn[column], columns=churn["Churn?"], normalize="columns"))

for column in churn.select_dtypes(exclude=["object"]).columns:
    print(column)
    hist = churn[[column, "Churn?"]].hist(by="Churn?", bins=30)
    plt.show()
Churn? False. True.
State
AK 0.015588 0.018415
AL 0.021583 0.018415
AR 0.022782 0.021217
AZ 0.015588 0.020416
CA 0.020384 0.021217
CO 0.018785 0.017614
CT 0.015588 0.020016
DC 0.022382 0.022418
DE 0.018385 0.018014
FL 0.019984 0.015612
GA 0.017986 0.015212
HI 0.019185 0.018815
IA 0.018385 0.022818
ID 0.021583 0.022818
IL 0.021982 0.017614
IN 0.021583 0.016413
KS 0.014788 0.016813
KY 0.017186 0.019215
LA 0.020783 0.019616
MA 0.021183 0.020416
MD 0.019584 0.025620
ME 0.013589 0.016013
MI 0.018785 0.021617
MN 0.022782 0.021217
MO 0.020783 0.021617
MS 0.019584 0.022818
MT 0.017586 0.018415
NC 0.017186 0.020817
ND 0.017186 0.014812
NE 0.019185 0.024420
NH 0.019984 0.017614
NJ 0.022382 0.018014
NM 0.017186 0.016013
NV 0.023181 0.016413
NY 0.015188 0.024019
OH 0.019185 0.025220
OK 0.021183 0.016013
OR 0.019185 0.021617
PA 0.018785 0.020817
RI 0.024380 0.023619
SC 0.021583 0.023619
SD 0.021583 0.019215
TN 0.021982 0.016813
TX 0.019185 0.020016
UT 0.018385 0.019616
VA 0.021183 0.018415
VT 0.022382 0.016413
WA 0.021982 0.018415
WI 0.018785 0.015212
WV 0.019584 0.022018
WY 0.020783 0.020416
Churn? False. True.
Area Code
657 0.037170 0.036829
658 0.022782 0.021217
659 0.015588 0.020416
676 0.020384 0.021217
677 0.018785 0.017614
678 0.015588 0.020016
686 0.040767 0.040432
707 0.019984 0.015612
716 0.017986 0.015212
727 0.019185 0.018815
736 0.039968 0.045637
737 0.043565 0.034027
758 0.031974 0.036029
766 0.020783 0.019616
776 0.054357 0.062050
777 0.062350 0.064452
778 0.037170 0.041233
786 0.053557 0.060048
787 0.059552 0.051641
788 0.038369 0.040432
797 0.040368 0.041233
798 0.019185 0.021617
806 0.018785 0.020817
827 0.024380 0.023619
836 0.043165 0.042834
847 0.021982 0.016813
848 0.019185 0.020016
858 0.018385 0.019616
866 0.021183 0.018415
868 0.022382 0.016413
876 0.021982 0.018415
877 0.018785 0.015212
878 0.040368 0.042434
Churn? False. True.
Int'l Plan
no 0.5 0.502802
yes 0.5 0.497198
Churn? False. True.
VMail Plan
no 0.496403 0.498799
yes 0.503597 0.501201
Account Length
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_5.png
VMail Message
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_7.png
Day Mins
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_9.png
Day Calls
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_11.png
Day Charge
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_13.png
Eve Mins
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_15.png
Eve Calls
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_17.png
Eve Charge
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_19.png
Night Mins
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_21.png
Night Calls
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_23.png
Night Charge
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_25.png
Intl Mins
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_27.png
Intl Calls
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_29.png
Intl Charge
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_31.png
CustServ Calls
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_13_33.png
[9]:
display(churn.corr())
pd.plotting.scatter_matrix(churn, figsize=(12, 12))
plt.show()
Account Length VMail Message Day Mins Day Calls Day Charge Eve Mins Eve Calls Eve Charge Night Mins Night Calls Night Charge Intl Mins Intl Calls Intl Charge CustServ Calls
Account Length 1.000000 -0.009030 -0.015878 0.011659 -0.007468 0.000213 0.026515 -0.012795 0.016400 -0.002383 -0.034925 0.017277 -0.003735 0.028285 -0.036721
VMail Message -0.009030 1.000000 -0.143272 0.002762 -0.182712 -0.104667 -0.101240 -0.029212 0.061370 0.135042 -0.155475 -0.015162 0.131964 0.010120 0.068657
Day Mins -0.015878 -0.143272 1.000000 -0.087598 0.667941 0.482641 -0.184939 0.766489 0.188190 -0.445212 0.570508 0.001988 0.236131 0.239331 -0.195322
Day Calls 0.011659 0.002762 -0.087598 1.000000 -0.222556 0.033903 0.185881 -0.052051 -0.085222 -0.083050 0.046641 -0.022548 -0.045671 -0.120064 -0.065518
Day Charge -0.007468 -0.182712 0.667941 -0.222556 1.000000 0.574697 0.236626 0.371580 0.150700 -0.130722 0.374861 0.010294 0.119584 0.251748 -0.260945
Eve Mins 0.000213 -0.104667 0.482641 0.033903 0.574697 1.000000 -0.067123 0.269980 -0.090515 0.067315 0.317481 -0.015678 0.070456 0.448910 -0.167347
Eve Calls 0.026515 -0.101240 -0.184939 0.185881 0.236626 -0.067123 1.000000 -0.467814 0.221439 0.218149 -0.324936 -0.001593 -0.112062 0.017036 -0.433467
Eve Charge -0.012795 -0.029212 0.766489 -0.052051 0.371580 0.269980 -0.467814 1.000000 0.184230 -0.454649 0.546137 -0.003569 0.164104 0.243936 -0.011019
Night Mins 0.016400 0.061370 0.188190 -0.085222 0.150700 -0.090515 0.221439 0.184230 1.000000 -0.223023 -0.140482 -0.012781 0.038831 0.271179 -0.332802
Night Calls -0.002383 0.135042 -0.445212 -0.083050 -0.130722 0.067315 0.218149 -0.454649 -0.223023 1.000000 -0.390333 -0.009821 0.181237 -0.155736 0.110211
Night Charge -0.034925 -0.155475 0.570508 0.046641 0.374861 0.317481 -0.324936 0.546137 -0.140482 -0.390333 1.000000 0.012585 -0.009720 -0.330772 0.439805
Intl Mins 0.017277 -0.015162 0.001988 -0.022548 0.010294 -0.015678 -0.001593 -0.003569 -0.012781 -0.009821 0.012585 1.000000 -0.007220 -0.010907 -0.008672
Intl Calls -0.003735 0.131964 0.236131 -0.045671 0.119584 0.070456 -0.112062 0.164104 0.038831 0.181237 -0.009720 -0.007220 1.000000 -0.233809 -0.012260
Intl Charge 0.028285 0.010120 0.239331 -0.120064 0.251748 0.448910 0.017036 0.243936 0.271179 -0.155736 -0.330772 -0.010907 -0.233809 1.000000 -0.661833
CustServ Calls -0.036721 0.068657 -0.195322 -0.065518 -0.260945 -0.167347 -0.433467 -0.011019 -0.332802 0.110211 0.439805 -0.008672 -0.012260 -0.661833 1.000000
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_14_1.png

We see several feature pairs that are highly correlated with one another. Including these feature pairs in some machine learning algorithms can create catastrophic problems, while in others it will only introduce minor redundancy and bias. Let’s remove one feature from each of the highly correlated pairs: Day Charge from the pair with Day Mins, Eve Charge from the pair with Eve Mins, Night Charge from the pair with Night Mins, and Intl Charge from the pair with Intl Mins:
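
Before dropping anything, it can also help to rank the most correlated pairs programmatically rather than scanning the matrix by eye. The following is a small sketch of one way to do that (the 0.5 threshold is an arbitrary illustrative choice, not part of the original analysis):

[ ]:
# Sketch: rank numeric feature pairs by absolute correlation.
corr = churn.corr().abs()
upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))  # keep each pair once
pairs = upper.stack().sort_values(ascending=False)
display(pairs[pairs > 0.5])  # 0.5 is an arbitrary illustrative threshold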

[10]:
churn = churn.drop(["Day Charge", "Eve Charge", "Night Charge", "Intl Charge"], axis=1)

Now that we’ve cleaned up our dataset, let’s determine which algorithm to use. As mentioned above, there appear to be some variables where both high and low (but not intermediate) values are predictive of churn. In order to accommodate this in an algorithm like linear regression, we’d need to generate polynomial (or bucketed) terms. Instead, let’s attempt to model this problem using gradient boosted trees. Amazon SageMaker provides an XGBoost container that we can use to train in a managed, distributed setting, and then host as a real-time prediction endpoint. XGBoost uses gradient boosted trees which naturally account for non-linear relationships between features and the target variable, as well as accommodating complex interactions between features.
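
To make the “bucketed terms” alternative above a bit more concrete, here is a small hypothetical sketch (the choice of Day Mins and five quantile bins is ours, purely for illustration): a linear model would need engineered inputs like these bins to capture a relationship where both high and low values predict churn.

[ ]:
# Sketch: bucket one feature into quantile bins and inspect churn rate per bin.
day_mins_bucket = pd.qcut(churn["Day Mins"], q=5, labels=False)
display(pd.crosstab(day_mins_bucket, churn["Churn?"], normalize="index"))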

Amazon SageMaker XGBoost can train on data in either a CSV or LibSVM format. For this example, we’ll stick with CSV. It should:

  • Have the target variable in the first column

  • Not have a header row

But first, let’s convert our categorical features into numeric features.

[11]:
# One-hot encode the categorical features, then move the target column to the front
model_data = pd.get_dummies(churn)
model_data = pd.concat(
    [model_data["Churn?_True."], model_data.drop(["Churn?_False.", "Churn?_True."], axis=1)], axis=1
)

And now let’s split the data into training, validation, and test sets. This will help prevent us from overfitting the model, and allow us to test the model’s accuracy on data it hasn’t already seen.

[26]:
# Shuffle, then split into 70% train, 20% validation, 10% test
train_data, validation_data, test_data = np.split(
    model_data.sample(frac=1, random_state=1729),
    [int(0.7 * len(model_data)), int(0.9 * len(model_data))],
)
train_data.to_csv("train.csv", header=False, index=False)
validation_data.to_csv("validation.csv", header=False, index=False)
[29]:
len(train_data.columns)
[29]:
100
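
As a quick, optional check (our addition) that the files match the format described earlier — target in the first column, no header row — we can peek at the first line of train.csv:

[ ]:
# Sketch: the first field of each row should be the 0/1 label, with no header line.
with open("train.csv") as f:
    print(f.readline().strip().split(",")[:3])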

Now we’ll upload these files to S3.

[13]:
boto3.Session().resource("s3").Bucket(bucket).Object(
    os.path.join(prefix, "train/train.csv")
).upload_file("train.csv")
boto3.Session().resource("s3").Bucket(bucket).Object(
    os.path.join(prefix, "validation/validation.csv")
).upload_file("validation.csv")

Train

Moving on to training, first we’ll need to specify the locations of the XGBoost algorithm containers.

[15]:
container = sagemaker.image_uris.retrieve("xgboost", boto3.Session().region_name, "latest")
display(container)
'433757028032.dkr.ecr.us-west-2.amazonaws.com/xgboost:latest'

Then, because we’re training with the CSV file format, we’ll create TrainingInputs that our training function can use as a pointer to the files in S3.

[16]:
s3_input_train = TrainingInput(
    s3_data="s3://{}/{}/train".format(bucket, prefix), content_type="csv"
)
s3_input_validation = TrainingInput(
    s3_data="s3://{}/{}/validation/".format(bucket, prefix), content_type="csv"
)

Now, we can specify a few parameters, like what type of training instances we’d like to use and how many, as well as our XGBoost hyperparameters. A few key hyperparameters are:

  • max_depth controls how deep each tree within the algorithm can be built. Deeper trees can lead to better fit, but are more computationally expensive and can lead to overfitting. There is typically some trade-off in model performance that needs to be explored between a large number of shallow trees and a smaller number of deeper trees.

  • subsample controls sampling of the training data. This technique can help reduce overfitting, but setting it too low can also starve the model of data.

  • num_round controls the number of boosting rounds. Each additional round trains a new model on the residuals of the previous iterations. More rounds should produce a better fit on the training data, but can be computationally expensive or lead to overfitting.

  • eta controls how aggressive each round of boosting is. Smaller values lead to more conservative boosting.

  • gamma controls how aggressively trees are grown. Larger values lead to more conservative models.

More detail on XGBoost’s hyperparameters can be found on their GitHub page.

[17]:
sess = sagemaker.Session()

xgb = sagemaker.estimator.Estimator(
    container,
    role,
    instance_count=1,
    instance_type="ml.m4.xlarge",
    output_path="s3://{}/{}/output".format(bucket, prefix),
    sagemaker_session=sess,
)
xgb.set_hyperparameters(
    max_depth=5,
    eta=0.2,
    gamma=4,
    min_child_weight=6,
    subsample=0.8,
    silent=0,
    objective="binary:logistic",
    num_round=100,
)

xgb.fit({"train": s3_input_train, "validation": s3_input_validation})
2021-06-07 20:21:39 Starting - Starting the training job...
2021-06-07 20:21:41 Starting - Launching requested ML instancesProfilerReport-1623097299: InProgress
......
2021-06-07 20:23:07 Starting - Preparing the instances for training.........
2021-06-07 20:24:27 Downloading - Downloading input data...
2021-06-07 20:25:07 Training - Training image download completed. Training in progress..Arguments: train
[2021-06-07:20:25:07:INFO] Running standalone xgboost training.
[2021-06-07:20:25:07:INFO] File size need to be processed in the node: 1.16mb. Available memory size in the node: 8420.01mb
[2021-06-07:20:25:07:INFO] Determined delimiter of CSV input is ','
[20:25:07] S3DistributionType set as FullyReplicated
[20:25:07] 3500x99 matrix with 346500 entries loaded from /opt/ml/input/data/train?format=csv&label_column=0&delimiter=,
[2021-06-07:20:25:07:INFO] Determined delimiter of CSV input is ','
[20:25:07] S3DistributionType set as FullyReplicated
[20:25:07] 1000x99 matrix with 99000 entries loaded from /opt/ml/input/data/validation?format=csv&label_column=0&delimiter=,
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 42 extra nodes, 6 pruned nodes, max_depth=5
[0]#011train-error:0.116857#011validation-error:0.114
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 38 extra nodes, 12 pruned nodes, max_depth=5
[1]#011train-error:0.102857#011validation-error:0.1
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 40 extra nodes, 8 pruned nodes, max_depth=5
[2]#011train-error:0.097714#011validation-error:0.095
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 36 extra nodes, 8 pruned nodes, max_depth=5
[3]#011train-error:0.091714#011validation-error:0.092
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 32 extra nodes, 4 pruned nodes, max_depth=5
[4]#011train-error:0.083143#011validation-error:0.083
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 36 extra nodes, 4 pruned nodes, max_depth=5
[5]#011train-error:0.079714#011validation-error:0.084
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 38 extra nodes, 4 pruned nodes, max_depth=5
[6]#011train-error:0.077429#011validation-error:0.086
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 36 extra nodes, 6 pruned nodes, max_depth=5
[7]#011train-error:0.076286#011validation-error:0.08
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 28 extra nodes, 10 pruned nodes, max_depth=5
[8]#011train-error:0.073143#011validation-error:0.077
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 28 extra nodes, 6 pruned nodes, max_depth=5
[9]#011train-error:0.072286#011validation-error:0.074
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 34 extra nodes, 10 pruned nodes, max_depth=5
[10]#011train-error:0.068286#011validation-error:0.076
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 20 extra nodes, 4 pruned nodes, max_depth=5
[11]#011train-error:0.067143#011validation-error:0.076
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 28 extra nodes, 16 pruned nodes, max_depth=5
[12]#011train-error:0.066286#011validation-error:0.079
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 28 extra nodes, 12 pruned nodes, max_depth=5
[13]#011train-error:0.064286#011validation-error:0.079
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 30 extra nodes, 18 pruned nodes, max_depth=5
[14]#011train-error:0.064571#011validation-error:0.081
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 20 extra nodes, 4 pruned nodes, max_depth=5
[15]#011train-error:0.063143#011validation-error:0.078
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 26 extra nodes, 8 pruned nodes, max_depth=5
[16]#011train-error:0.062286#011validation-error:0.073
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 24 extra nodes, 4 pruned nodes, max_depth=5
[17]#011train-error:0.059714#011validation-error:0.072
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 2 pruned nodes, max_depth=5
[18]#011train-error:0.059429#011validation-error:0.071
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 22 extra nodes, 6 pruned nodes, max_depth=5
[19]#011train-error:0.058286#011validation-error:0.071
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 24 extra nodes, 10 pruned nodes, max_depth=5
[20]#011train-error:0.059714#011validation-error:0.07
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 8 pruned nodes, max_depth=5
[21]#011train-error:0.058857#011validation-error:0.071
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 14 pruned nodes, max_depth=5
[22]#011train-error:0.058571#011validation-error:0.072
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 4 pruned nodes, max_depth=5
[23]#011train-error:0.058571#011validation-error:0.069
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 8 pruned nodes, max_depth=5
[24]#011train-error:0.057143#011validation-error:0.067
[20:25:07] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 24 extra nodes, 6 pruned nodes, max_depth=5
[25]#011train-error:0.051714#011validation-error:0.068
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 10 pruned nodes, max_depth=5
[26]#011train-error:0.052#011validation-error:0.066
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 4 pruned nodes, max_depth=5
[27]#011train-error:0.052286#011validation-error:0.065
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 14 pruned nodes, max_depth=4
[28]#011train-error:0.052#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 20 pruned nodes, max_depth=4
[29]#011train-error:0.052286#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 14 pruned nodes, max_depth=0
[30]#011train-error:0.052571#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 10 pruned nodes, max_depth=5
[31]#011train-error:0.050571#011validation-error:0.065
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 8 pruned nodes, max_depth=5
[32]#011train-error:0.050286#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[33]#011train-error:0.050286#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 26 pruned nodes, max_depth=0
[34]#011train-error:0.050286#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 12 pruned nodes, max_depth=4
[35]#011train-error:0.051143#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 10 pruned nodes, max_depth=5
[36]#011train-error:0.050286#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 16 pruned nodes, max_depth=4
[37]#011train-error:0.050857#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 12 pruned nodes, max_depth=5
[38]#011train-error:0.048857#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 14 pruned nodes, max_depth=3
[39]#011train-error:0.049143#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 10 pruned nodes, max_depth=5
[40]#011train-error:0.048286#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 24 extra nodes, 16 pruned nodes, max_depth=5
[41]#011train-error:0.047714#011validation-error:0.064
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 18 extra nodes, 16 pruned nodes, max_depth=5
[42]#011train-error:0.047429#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 24 pruned nodes, max_depth=5
[43]#011train-error:0.047714#011validation-error:0.066
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 4 pruned nodes, max_depth=5
[44]#011train-error:0.046571#011validation-error:0.065
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 6 pruned nodes, max_depth=5
[45]#011train-error:0.046#011validation-error:0.066
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 12 pruned nodes, max_depth=0
[46]#011train-error:0.046#011validation-error:0.066
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 16 pruned nodes, max_depth=4
[47]#011train-error:0.046#011validation-error:0.064
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 10 pruned nodes, max_depth=4
[48]#011train-error:0.045143#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 4 extra nodes, 26 pruned nodes, max_depth=2
[49]#011train-error:0.045429#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 12 pruned nodes, max_depth=5
[50]#011train-error:0.044#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[51]#011train-error:0.044571#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 22 pruned nodes, max_depth=0
[52]#011train-error:0.044571#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 14 pruned nodes, max_depth=3
[53]#011train-error:0.043714#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 18 pruned nodes, max_depth=0
[54]#011train-error:0.044#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 16 extra nodes, 6 pruned nodes, max_depth=5
[55]#011train-error:0.044#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 2 pruned nodes, max_depth=5
[56]#011train-error:0.043143#011validation-error:0.058
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 12 pruned nodes, max_depth=5
[57]#011train-error:0.042571#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 12 pruned nodes, max_depth=0
[58]#011train-error:0.042571#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 22 pruned nodes, max_depth=4
[59]#011train-error:0.041714#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 26 pruned nodes, max_depth=0
[60]#011train-error:0.041714#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 14 pruned nodes, max_depth=0
[61]#011train-error:0.041714#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 12 pruned nodes, max_depth=3
[62]#011train-error:0.042286#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 16 pruned nodes, max_depth=4
[63]#011train-error:0.042857#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[64]#011train-error:0.042571#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 12 pruned nodes, max_depth=0
[65]#011train-error:0.042857#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 10 pruned nodes, max_depth=4
[66]#011train-error:0.041714#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 8 pruned nodes, max_depth=4
[67]#011train-error:0.041143#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 28 pruned nodes, max_depth=0
[68]#011train-error:0.041143#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 16 pruned nodes, max_depth=3
[69]#011train-error:0.040286#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 14 pruned nodes, max_depth=0
[70]#011train-error:0.040286#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 28 pruned nodes, max_depth=3
[71]#011train-error:0.04#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 20 pruned nodes, max_depth=0
[72]#011train-error:0.04#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 22 pruned nodes, max_depth=4
[73]#011train-error:0.039429#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[74]#011train-error:0.04#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[75]#011train-error:0.039714#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 24 pruned nodes, max_depth=3
[76]#011train-error:0.039429#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 20 pruned nodes, max_depth=0
[77]#011train-error:0.039143#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 24 pruned nodes, max_depth=0
[78]#011train-error:0.039429#011validation-error:0.06
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 2 pruned nodes, max_depth=5
[79]#011train-error:0.040286#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 22 pruned nodes, max_depth=4
[80]#011train-error:0.039429#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 14 pruned nodes, max_depth=3
[81]#011train-error:0.039143#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 30 pruned nodes, max_depth=0
[82]#011train-error:0.039143#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 12 extra nodes, 14 pruned nodes, max_depth=5
[83]#011train-error:0.038857#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 18 pruned nodes, max_depth=0
[84]#011train-error:0.038286#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 18 pruned nodes, max_depth=0
[85]#011train-error:0.038286#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 12 pruned nodes, max_depth=4
[86]#011train-error:0.037714#011validation-error:0.064
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 12 pruned nodes, max_depth=4
[87]#011train-error:0.037429#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 8 extra nodes, 18 pruned nodes, max_depth=3
[88]#011train-error:0.037714#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[89]#011train-error:0.037714#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 16 pruned nodes, max_depth=0
[90]#011train-error:0.037714#011validation-error:0.062
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 12 pruned nodes, max_depth=5
[91]#011train-error:0.037429#011validation-error:0.064
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 18 pruned nodes, max_depth=0
[92]#011train-error:0.037429#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 6 extra nodes, 14 pruned nodes, max_depth=3
[93]#011train-error:0.038286#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 20 pruned nodes, max_depth=3
[94]#011train-error:0.036286#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 26 pruned nodes, max_depth=0
[95]#011train-error:0.036286#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 26 pruned nodes, max_depth=0
[96]#011train-error:0.036571#011validation-error:0.061
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 14 extra nodes, 8 pruned nodes, max_depth=5
[97]#011train-error:0.036#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 0 extra nodes, 20 pruned nodes, max_depth=0
[98]#011train-error:0.036286#011validation-error:0.063
[20:25:08] src/tree/updater_prune.cc:74: tree pruning end, 1 roots, 10 extra nodes, 16 pruned nodes, max_depth=5
[99]#011train-error:0.036286#011validation-error:0.064

2021-06-07 20:25:27 Uploading - Uploading generated training model
2021-06-07 20:25:27 Completed - Training job completed
Training seconds: 62
Billable seconds: 62

Host

Now that we’ve trained the algorithm, let’s create a model and deploy it to a hosted endpoint.

[32]:
xgb_predictor = xgb.deploy(
    initial_instance_count=1, instance_type="ml.m4.xlarge", serializer=CSVSerializer()
)
------!

Evaluate

Now that we have a hosted endpoint running, we can make real-time predictions from our model very easily, simply by making an HTTP POST request. But first, we’ll need to set up serializers and deserializers for passing our test_data NumPy arrays to the model behind the endpoint.
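
Before batching up the whole test set, here is a minimal single-record sketch (assuming the xgb_predictor and test_data objects created above): the CSVSerializer attached at deploy time converts the NumPy row into the CSV payload of the POST request, and the endpoint returns the predicted churn probability as bytes.

[ ]:
# Sketch: score one customer record against the hosted endpoint.
single_row = test_data.to_numpy()[0, 1:]          # drop the label in the first column
raw_response = xgb_predictor.predict(single_row)  # bytes, e.g. b'0.135...'
print(float(raw_response.decode("utf-8")))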

Now, we’ll use a simple function to:

  1. Loop over our test dataset

  2. Split it into mini-batches of rows

  3. Convert those mini-batches to CSV string payloads

  4. Retrieve mini-batch predictions by invoking the XGBoost endpoint

  5. Collect predictions and convert from the CSV output our model provides into a NumPy array

[34]:
def predict(data, rows=500):
    # Split the dataset into mini-batches of at most `rows` rows each
    split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1))
    predictions = ""
    for array in split_array:
        # Invoke the endpoint on each mini-batch and append the CSV response
        predictions = ",".join([predictions, xgb_predictor.predict(array).decode("utf-8")])

    # Drop the leading comma and parse the concatenated CSV into a NumPy array
    return np.fromstring(predictions[1:], sep=",")


predictions = predict(test_data.to_numpy()[:, 1:])
[35]:
print(predictions)
[1.35331511e-01 9.82077539e-01 1.02474750e-03 4.05021757e-03
 3.92824411e-01 9.72177207e-01 9.91238415e-01 7.27955222e-01
 8.88106227e-01 9.90995705e-01 9.52816725e-01 1.19172679e-02
 3.58281261e-03 9.42603052e-01 9.89126205e-01 9.90786612e-01
 9.89062250e-01 6.96188286e-02 9.03073490e-01 9.86119509e-01
 9.62791502e-01 1.74614764e-03 5.25158504e-03 9.58006740e-01
 9.41856503e-01 1.67768657e-01 9.90760803e-01 2.84090172e-03
 9.95878816e-01 1.79207008e-02 5.64881638e-02 3.02827973e-02
 9.39479709e-01 5.42096049e-03 4.67779767e-03 1.76512660e-03
 5.74744403e-01 4.61890012e-01 9.82370138e-01 9.46011901e-01
 7.24184453e-01 9.79002118e-01 4.35622185e-01 9.47760344e-01
 1.98116549e-03 9.55070615e-01 1.80402584e-03 2.31066555e-01
 9.84024167e-01 2.06405832e-03 4.47656447e-03 9.76536214e-01
 9.76774156e-01 2.38812685e-01 8.62769634e-02 9.92498100e-01
 5.51647916e-02 1.81831885e-03 1.40864691e-02 6.51788637e-02
 3.24071288e-01 9.69274342e-01 8.70716691e-01 2.77617481e-02
 3.38171422e-03 9.86501753e-01 9.96587396e-01 2.11618710e-02
 9.75571871e-01 9.65472043e-01 9.96151865e-01 9.69069004e-01
 9.89180505e-01 2.63313879e-03 3.21035296e-01 1.08806137e-02
 2.64573544e-02 9.37089503e-01 6.73267897e-03 2.58733481e-01
 3.24417472e-01 9.28655446e-01 3.31248250e-03 8.90821934e-01
 9.61377740e-01 2.71464825e-01 8.27060938e-02 4.50845510e-02
 1.04574696e-03 2.53715012e-02 7.90762436e-03 5.23460889e-03
 2.56178831e-03 1.91231240e-02 1.13130454e-02 7.38239754e-03
 5.82699895e-01 9.73278880e-01 9.22638085e-03 8.47170115e-01
 9.90330100e-01 3.55595469e-01 3.20115268e-01 1.39923133e-02
 1.42177224e-01 6.04430676e-01 7.18148589e-01 5.39106596e-03
 9.80735362e-01 7.37583041e-01 5.03236465e-02 1.64943077e-02
 9.04652297e-01 9.81644273e-01 7.90716231e-01 9.85898018e-01
 7.62424946e-01 9.84377503e-01 2.63425987e-03 4.38998500e-03
 2.94039130e-01 4.28284109e-02 9.84231472e-01 5.86435676e-01
 3.08436388e-03 9.39401984e-01 1.30047444e-02 1.97171748e-01
 2.52563730e-02 8.10037076e-01 6.91544730e-03 9.95759904e-01
 5.84239900e-01 6.48281515e-01 2.54428573e-03 4.74239998e-02
 1.03579066e-03 1.68707268e-03 7.82937527e-01 9.83555257e-01
 2.19112285e-03 1.72078947e-03 9.97695029e-01 9.94287312e-01
 3.99078667e-01 5.59094250e-01 9.48836029e-01 8.70352332e-03
 3.88872786e-03 1.67273183e-03 3.71541269e-02 9.79852796e-01
 1.06622964e-01 9.77587759e-01 1.03806876e-01 9.88253295e-01
 9.95530903e-01 8.86603259e-03 3.61969098e-02 9.98754025e-01
 9.02955353e-01 1.37167017e-03 9.93803740e-01 2.26997677e-02
 9.87512410e-01 1.05077840e-01 2.04039039e-03 9.87286270e-01
 5.44086145e-03 9.65804696e-01 1.20876392e-03 4.99187075e-02
 9.29978013e-01 2.68173255e-02 1.61253046e-02 9.07470047e-01
 4.03423654e-03 3.44088417e-03 2.38039996e-02 9.94472027e-01
 9.74373102e-01 5.20768529e-03 9.79227364e-01 5.08057186e-03
 9.95805264e-01 6.65849864e-01 1.05473526e-01 9.93587852e-01
 5.58617711e-01 2.41559744e-02 5.02773046e-01 3.22041893e-03
 1.92995416e-03 9.00983155e-01 5.58514297e-01 9.27726090e-01
 5.46179354e-01 9.89919186e-01 5.56463718e-01 2.64735380e-03
 9.73722637e-01 1.13814147e-02 9.58212733e-01 1.30312620e-02
 9.89001274e-01 1.78968444e-01 8.57075810e-01 9.76885080e-01
 5.23749273e-03 3.24952036e-01 1.79608092e-02 2.49289963e-02
 8.60855356e-03 8.52289140e-01 3.09771717e-01 8.76666784e-01
 4.08283435e-02 9.35659528e-01 2.94428766e-01 9.79122996e-01
 9.62184668e-01 6.78600371e-03 9.05742466e-01 9.97201800e-01
 1.14960328e-01 9.08459008e-01 1.19753112e-03 5.32368302e-01
 9.58187222e-01 9.89705026e-01 3.39890003e-01 7.22918008e-03
 9.54267502e-01 2.33254582e-01 5.88130299e-03 6.44563744e-03
 8.89785029e-03 9.94546056e-01 3.72455828e-03 9.71175611e-01
 9.74044800e-01 2.12178566e-03 1.07673155e-02 4.76329684e-01
 9.96783376e-01 1.76805198e-01 9.86611724e-01 5.33065759e-03
 7.30689906e-04 9.94548619e-01 6.52407229e-01 9.34181631e-01
 9.91729438e-01 9.93662238e-01 9.83221948e-01 3.98660809e-01
 6.20443141e-03 7.74233043e-01 9.95967031e-01 6.36944234e-01
 1.17769570e-03 9.50848639e-01 8.04703474e-01 3.06982826e-03
 8.86053592e-03 6.01315558e-01 8.15589039e-04 5.72970212e-01
 6.05024816e-03 3.00737917e-01 4.88646049e-03 2.77840160e-03
 8.07540655e-01 9.92803991e-01 8.98140490e-01 9.72906947e-01
 6.92576945e-01 9.93186355e-01 1.08809747e-01 9.92834687e-01
 9.94885147e-01 3.36409220e-03 8.95250499e-01 5.41887462e-01
 3.20547223e-02 3.00191017e-03 5.35610989e-02 9.97282743e-01
 9.83131170e-01 9.19797793e-02 8.23473513e-01 9.59514558e-01
 9.32907164e-01 7.84317493e-01 9.17944372e-01 3.09938975e-02
 3.54256898e-01 9.92791533e-01 8.52789164e-01 2.45015230e-03
 9.89008367e-01 9.91619706e-01 9.90877688e-01 9.49390650e-01
 1.30883142e-01 9.88085508e-01 9.94880915e-01 3.64250340e-03
 1.00924395e-01 5.84139954e-04 9.90493596e-01 9.98544216e-01
 9.86899495e-01 2.92626163e-03 2.48177201e-01 2.34482929e-01
 8.31019282e-01 9.92612481e-01 8.60626280e-01 7.52302110e-01
 3.23741376e-01 7.65029073e-01 5.58793452e-03 9.93409157e-01
 9.38858867e-01 7.71634397e-04 1.67317316e-02 9.37093079e-01
 7.06784369e-04 9.98139381e-01 8.05393100e-01 5.16605470e-03
 9.78015721e-01 2.05255556e-03 9.12059784e-01 9.25935686e-01
 1.13790333e-02 9.89597142e-01 1.27113104e-01 5.78118026e-01
 9.96469259e-01 9.72206831e-01 9.76976454e-01 9.41750944e-01
 5.77228487e-01 2.19249935e-03 7.88132489e-01 9.81530845e-01
 4.33299989e-02 9.69189465e-01 9.94810939e-01 8.52230430e-01
 3.99795204e-01 6.53925359e-01 2.95181237e-02 3.39665934e-02
 7.94862688e-01 2.04434637e-02 1.28047034e-01 4.97525007e-01
 6.99097216e-01 2.78656930e-03 9.65144753e-01 9.88097727e-01
 9.90343571e-01 1.70749594e-02 9.78468716e-01 6.33870484e-04
 6.30041352e-03 1.74134392e-02 5.17328441e-01 6.64301068e-02
 9.65705335e-01 3.48356972e-03 9.93482172e-01 8.85248363e-01
 7.88051486e-01 5.09877980e-04 4.03372422e-02 9.36156988e-01
 9.61743534e-01 2.84105301e-01 3.35437730e-02 5.53360462e-01
 9.97811019e-01 5.75639447e-03 9.68811095e-01 6.68024877e-03
 1.88899226e-02 9.71860945e-01 9.77428854e-01 3.99328908e-03
 5.62287215e-03 7.96073318e-01 1.99078992e-02 9.77693915e-01
 3.26212645e-01 1.39212295e-01 9.93900180e-01 2.80859950e-03
 4.18449134e-01 9.85604823e-01 7.15814114e-01 9.74938989e-01
 3.95792350e-03 9.89029527e-01 1.45013705e-01 2.24798415e-02
 7.97536910e-01 1.10140465e-01 9.85174477e-01 9.57152069e-01
 8.49451423e-02 9.73486364e-01 7.53005641e-03 7.70264864e-01
 7.48203993e-01 3.52388084e-01 5.94211340e-01 9.95380521e-01
 8.92800987e-01 4.59158495e-02 1.18001075e-02 5.04979268e-02
 8.68584573e-01 8.19599867e-01 9.75251377e-01 1.37978035e-03
 2.46118824e-03 4.94047582e-01 9.51911092e-01 4.96846557e-01
 9.89559054e-01 6.39659760e-04 7.06940293e-02 9.64530110e-01
 9.38779950e-01 8.38272832e-03 9.19750333e-01 4.18015569e-02
 5.70869744e-01 1.06459521e-01 3.62761080e-01 9.55066621e-01
 8.14702332e-01 7.48456717e-02 9.53042647e-04 7.32033730e-01
 5.06092459e-02 8.94822478e-01 2.64998395e-02 1.04296561e-02
 9.36205387e-01 3.36960889e-03 9.75249887e-01 9.76234794e-01
 1.14756696e-01 1.49416341e-03 5.38111746e-01 9.96687353e-01
 1.12269195e-02 3.39865126e-02 9.89515066e-01 9.88028944e-01
 9.10718925e-03 1.09984174e-01 3.62294354e-03 9.03799057e-01
 3.25372512e-03 9.96353388e-01 6.18791813e-03 6.41223550e-01
 4.31292877e-03 3.69082741e-03 2.70273089e-02 1.50334686e-02
 1.34704104e-02 6.39035776e-02 4.14570980e-02 8.61872971e-01
 6.39807759e-03 5.89951640e-03 8.43912712e-04 9.54841495e-01
 9.53769982e-01 9.91130114e-01 9.01047885e-01 9.51251447e-01
 9.05901670e-01 4.49176356e-02 9.81808484e-01 3.27834845e-01
 1.14604257e-01 9.86102939e-01 5.43728583e-02 9.02637482e-01
 9.97059226e-01 9.33740377e-01 9.48212445e-01 9.88600552e-01]

There are many ways to compare the performance of a machine learning model, but let’s start simply by comparing actual to predicted values. In this case, we’re simply predicting whether the customer churned (1) or not (0), which produces a simple confusion matrix.

[36]:
pd.crosstab(
    index=test_data.iloc[:, 0],
    columns=np.round(predictions),
    rownames=["actual"],
    colnames=["predictions"],
)
[36]:
predictions 0.0 1.0
actual
0 235 18
1 9 238

Note, due to randomized elements of the algorithm, your results may differ slightly.

Of the 247 churners in the test set, we’ve correctly predicted 238 of them (true positives). We also incorrectly predicted that 18 customers would churn who then ended up not doing so (false positives). And there are 9 customers who ended up churning that we predicted would not (false negatives).

An important point here is that because of the np.round() function above we are using a simple threshold (or cutoff) of 0.5. Our predictions from xgboost come out as continuous values between 0 and 1 and we force them into the binary classes that we began with. However, because a customer that churns is expected to cost the company more than proactively trying to retain a customer who we think might churn, we should consider adjusting this cutoff. That will almost certainly increase the number of false positives, but it can also be expected to increase the number of true positives and reduce the number of false negatives.

To get a rough intuition here, let’s look at the continuous values of our predictions.

[37]:
plt.hist(predictions)
plt.show()
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_39_0.png

The continuous valued predictions coming from our model tend to skew toward 0 or 1, but there is sufficient mass between 0.1 and 0.9 that adjusting the cutoff should indeed shift a number of customers’ predictions. For example…

[38]:
pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > 0.3, 1, 0))
[38]:
col_0 0 1
Churn?_True.
0 215 38
1 3 244

We can see that changing the cutoff from 0.5 to 0.3 results in 6 more true positives, 20 more false positives, and 6 fewer false negatives. The numbers are small overall here, but that’s roughly 5% of the test customers whose predictions shift because of the change to the cutoff. Was this the right decision? We may end up retaining 6 extra customers, but we also unnecessarily incentivized 20 more customers who would have stayed. Determining optimal cutoffs is a key step in properly applying machine learning in a real-world setting. Let’s discuss this more broadly and then apply a specific, hypothetical solution for our current problem.
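
A small helper like the sketch below (our addition, not part of the original notebook) makes it easy to eyeball this trade-off at any cutoff before we formalize it with costs in the next section:

[ ]:
# Sketch: confusion matrix at an arbitrary cutoff, for quick what-if comparisons.
def confusion_at(cutoff):
    return pd.crosstab(
        index=test_data.iloc[:, 0],
        columns=np.where(predictions > cutoff, 1, 0),
        rownames=["actual"],
        colnames=["predictions"],
    )

display(confusion_at(0.3))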

Relative cost of errors

Any practical binary classification problem is likely to produce a similarly sensitive cutoff. That by itself isn’t a problem. After all, if the scores for two classes are really easy to separate, the problem probably isn’t very hard to begin with and might even be solvable with simple rules instead of ML.

More important, if I put an ML model into production, there are costs associated with the model erroneously assigning false positives and false negatives. I also need to look at similar costs associated with correct predictions of true positives and true negatives. Because the choice of the cutoff affects all four of these statistics, I need to consider the relative costs to the business for each of these four outcomes for each prediction.

Assigning costs

What are the costs for our problem of mobile operator churn? The costs, of course, depend on the specific actions that the business takes. Let’s make some assumptions here.

First, assign the true negatives the cost of $0. Our model essentially correctly identified a happy customer in this case, and we don’t need to do anything.

False negatives are the most problematic, because they incorrectly predict that a churning customer will stay. We lose the customer and will have to pay all the costs of acquiring a replacement customer, including foregone revenue, advertising costs, administrative costs, point of sale costs, and likely a phone hardware subsidy. A quick search on the Internet reveals that such costs typically run in the hundreds of dollars so, for the purposes of this example, let’s assume $500. This is the cost of false negatives.

Finally, for customers that our model identifies as churning, let’s assume a retention incentive in the amount of $100. If my provider offered me such a concession, I’d certainly think twice before leaving. This is the cost of both true positive and false positive outcomes. In the case of false positives (the customer is happy, but the model mistakenly predicted churn), we will “waste” the $100 concession. We probably could have spent that $100 more effectively, but it’s possible we increased the loyalty of an already loyal customer, so that’s not so bad.

Finding the optimal cutoff

It’s clear that false negatives are substantially more costly than false positives. Instead of optimizing for error based on the number of customers, we should be minimizing a cost function that looks like this:

$500 * FN(C) + $0 * TN(C) + $100 * FP(C) + $100 * TP(C)

FN(C) is the number of false negatives as a function of the cutoff, C, and similarly for TN, FP, and TP. We need to find the cutoff, C, where the result of the expression is smallest.
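
For intuition, plugging the 0.5-cutoff confusion matrix shown earlier (TN=235, FP=18, FN=9, TP=238) into this expression gives the total cost at the default cutoff:

[ ]:
# Worked example: cost at the default 0.5 cutoff, using the counts from the confusion matrix above.
print(500 * 9 + 0 * 235 + 100 * 18 + 100 * 238)  # 30100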

A straightforward way to do this is to simply run a simulation over a large number of possible cutoffs. We test 99 possible values (0.01 through 0.99) in the for loop below.

[39]:
cutoffs = np.arange(0.01, 1, 0.01)
costs = []
for c in cutoffs:
    costs.append(
        np.sum(
            np.sum(
                # Cost matrix aligned with the crosstab below:
                # rows = actual (0, 1), columns = predicted (0, 1)
                # i.e. [[TN $0, FP $100], [FN $500, TP $100]]
                np.array([[0, 100], [500, 100]])
                * pd.crosstab(index=test_data.iloc[:, 0], columns=np.where(predictions > c, 1, 0))
            )
        )
    )

costs = np.array(costs)
plt.plot(cutoffs, costs)
plt.show()
../../_images/introduction_to_applying_machine_learning_xgboost_customer_churn_xgboost_customer_churn_44_0.png
[40]:
print(
    "Cost is minimized near a cutoff of:",
    cutoffs[np.argmin(costs)],
    "for a cost of:",
    np.min(costs),
)
Cost is minimized near a cutoff of: 0.37 for a cost of: 29200

The above chart shows how picking a threshold too low results in costs skyrocketing as all customers are given a retention incentive. Meanwhile, setting the threshold too high results in too many lost customers, which ultimately grows to be nearly as costly. The overall cost can be minimized at $29,200 by setting the cutoff to 0.37, which is substantially better than the roughly $123,500 (247 churners × $500) we would expect to lose by not taking any action.
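
As a rough sanity check on that do-nothing baseline (our addition; it relies on the target being the first column of test_data, as constructed above), the cost of taking no action is $500 for every actual churner in the test set:

[ ]:
# Sketch: cost of doing nothing = $500 per actual churner in the test set.
print(500 * int(test_data.iloc[:, 0].sum()))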


Extensions

This notebook showcased how to build a model that predicts whether a customer is likely to churn, and then how to optimally set a threshold that accounts for the cost of true positives, false positives, and false negatives. There are several ways to extend it, including:

  • Some customers who receive retention incentives will still churn. Including a probability of churning despite receiving an incentive in our cost function would provide a better ROI on our retention programs.

  • Customers who switch to a lower-priced plan or who deactivate a paid feature represent different kinds of churn that could be modeled separately.

  • Modeling the evolution of customer behavior. If usage is dropping and the number of calls placed to Customer Service is increasing, you are more likely to experience churn than if the trend is the opposite. A customer profile should incorporate behavior trends.

  • Actual training data and monetary cost assignments could be more complex.

  • Multiple models for each type of churn could be needed.

Regardless of the additional complexity, the principles described in this notebook are likely to apply.

(Optional) Clean-up

If you’re ready to be done with this notebook, please run the cell below. This will remove the hosted endpoint you created and avoid any charges from a stray instance being left on.

[41]:
xgb_predictor.delete_endpoint()