Text Classification with Amazon SageMaker HuggingFace and Hyperparameter Tuning


Automatic model tuning, also known as hyperparameter tuning, finds the best version of a model by running many training jobs that test a range of hyperparameters on your dataset. You choose the tunable hyperparameters, a range of values for each, and an objective metric from among the metrics that the algorithm computes. Automatic model tuning then searches the chosen ranges for the combination of values that produces the model that optimizes the objective metric.
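
The idea can be sketched without SageMaker at all. The following minimal example is an illustration only, not how SageMaker's tuner is implemented: it runs a plain random search, chosen here for simplicity, over a hypothetical search space. The names `search_space`, `sample`, and `toy_objective` are assumptions made for this sketch, and `toy_objective` stands in for a real training job that reports an objective metric.

```python
import random

# Hypothetical search space: a continuous, an integer, and a categorical hyperparameter.
search_space = {
    "learning_rate": ("continuous", 1e-5, 1e-3),
    "train_batch_size": ("integer", 8, 32),
    "optimizer": ("categorical", ["adamw", "adafactor"]),
}


def sample(space, rng):
    """Draw one combination of hyperparameter values from the search space."""
    values = {}
    for name, spec in space.items():
        kind = spec[0]
        if kind == "continuous":
            values[name] = rng.uniform(spec[1], spec[2])
        elif kind == "integer":
            values[name] = rng.randint(spec[1], spec[2])
        else:  # categorical
            values[name] = rng.choice(spec[1])
    return values


def toy_objective(params):
    """Stand-in for a training job: returns a made-up 'accuracy' in (0, 1]."""
    lr_penalty = abs(params["learning_rate"] - 5e-4) * 1000  # best near 5e-4
    batch_penalty = abs(params["train_batch_size"] - 16) / 100  # best near 16
    return 1.0 / (1.0 + lr_penalty + batch_penalty)


rng = random.Random(0)  # fixed seed so the search is repeatable
trials = [sample(search_space, rng) for _ in range(20)]
results = [(toy_objective(p), p) for p in trials]

# The "best job" is the one whose objective metric is highest.
best_score, best_params = max(results, key=lambda r: r[0])
print(f"best objective: {best_score:.3f} with {best_params}")
```

Each trial plays the role of one training job, and the final `max` plays the role of selecting the best job by objective metric.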


Text classification can be used for tasks such as sentiment analysis, spam detection, and hashtag prediction.

This notebook demonstrates the use of the HuggingFace Transformers library together with a custom Amazon sagemaker-sdk extension to fine-tune a pre-trained transformer on multi-class text classification. In particular, the pre-trained model will be fine-tuned using the 20 Newsgroups dataset. To get started, we need to set up the environment with a few prerequisite steps for permissions, configurations, and so on.

Install Python packages

[ ]:
import sys

!{sys.executable} -m pip install "scikit_learn==1.2.2" "sagemaker==2.48.0" "transformers==4.6.1" "datasets[s3]==1.6.2" "nltk==3.4.4"

If you run this notebook in SageMaker Studio, you need to make sure ipywidgets is installed and restart the kernel, so please uncomment the code in the next cell, and run it.

[ ]:
# %%capture
# import IPython
# import sys

!{sys.executable} -m pip install ipywidgets
# IPython.Application.instance().kernel.do_shutdown(True)  # has to restart kernel so changes are used


Let’s start by specifying:

  • The S3 bucket and prefix that you want to use for training and model data. This should be within the same region as the Notebook Instance, training, and hosting. If you don’t specify a bucket, SageMaker SDK will create a default bucket following a pre-defined naming convention in the same region.

  • The IAM role ARN used to give SageMaker access to your data. It can be fetched using the get_execution_role method from the SageMaker Python SDK.

[ ]:
import sagemaker
from sagemaker import get_execution_role
import json
import boto3
import pandas as pd
import re
import string
from sklearn.model_selection import train_test_split
import sagemaker.huggingface

sess = sagemaker.Session()
region = sess.boto_region_name

role = get_execution_role()  # This is the role that SageMaker would use to leverage AWS resources (S3, CloudWatch) on your behalf

bucket = sess.default_bucket()  # Replace with your own bucket name if needed
s3_prefix = "huggingface/20_newsgroups"  # Replace with the prefix under which you want to store the data if needed

Data Preparation

Now we’ll download the dataset on which we want to train the text classification model.

In this example, we use the 20 Newsgroups dataset, which consists of about 20,000 messages taken from 20 Usenet newsgroups.

[ ]:
import os
import shutil

data_dir = "20_newsgroups_bulk"
if os.path.exists(data_dir):  # cleanup existing data folder
    shutil.rmtree(data_dir)
[ ]:
s3 = boto3.client("s3")
[ ]:
!tar xzf 20_newsgroups_bulk.tar.gz --no-same-owner
!ls 20_newsgroups_bulk
[ ]:
file_list = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]
print("Number of files:", len(file_list))
[ ]:
documents_count = 0
for file in file_list:
    df = pd.read_csv(file, header=None, names=["text"])
    documents_count = documents_count + df.shape[0]
print("Number of documents:", documents_count)

Let’s inspect the dataset files and analyze the categories.

[ ]:
categories_list = [f.split("/")[1] for f in file_list]
[ ]:

We can see that the dataset consists of 20 topics, each in a different file.

Let us inspect the dataset to understand how the data and the labels are provided.

[ ]:
df = pd.read_csv("./20_newsgroups_bulk/rec.motorcycles", header=None, names=["text"])
[ ]:
[ ]:
df = pd.read_csv("./20_newsgroups_bulk/comp.sys.mac.hardware", header=None, names=["text"])
[ ]:

As we can see from the above, there is a single file for each class in the dataset. Each record is a plain-text message with a header, body, footer, and quotes. We will need to process the records into a suitable data format.

Data Preprocessing

We need to preprocess the dataset to remove the header, footer, quotes, leading/trailing whitespace, extra spaces, tabs, and HTML tags/markups.

Download the nltk tokenizer and other libraries

[ ]:
import nltk
from nltk.tokenize import word_tokenize

nltk.download("punkt")  # tokenizer models required by word_tokenize

[ ]:
from sklearn.datasets._twenty_newsgroups import (
    strip_newsgroup_header,
    strip_newsgroup_quoting,
    strip_newsgroup_footer,
)

The following function removes the header, footer, and quotes (of earlier messages in each text).

[ ]:
def strip_newsgroup_item(item):
    item = strip_newsgroup_header(item)
    item = strip_newsgroup_quoting(item)
    item = strip_newsgroup_footer(item)
    return item

The following function will take care of removing leading/trailing whitespace, extra spaces, tabs, and HTML tags/markups.

[ ]:
def process_text(texts):
    final_text_list = []
    for text in texts:
        # Check if the sentence is a missing value
        if isinstance(text, str) == False:
            text = ""

        filtered_sentence = []

        # Lowercase
        text = text.lower()

        # Remove leading/trailing whitespace, extra space, tabs, and HTML tags/markups
        text = text.strip()
        text = re.sub(r"\[.*?\]", "", text)  # bracketed text
        text = re.sub(r"https?://\S+|www\.\S+", "", text)  # URLs
        text = re.sub(r"<.*?>+", "", text)  # HTML tags
        text = re.sub("[%s]" % re.escape(string.punctuation), "", text)  # punctuation
        text = re.sub(r"\n", "", text)  # newlines
        text = re.sub(r"\w*\d\w*", "", text)  # words containing digits

        for w in word_tokenize(text):
            # We are applying some custom filtering here, feel free to try different things
            # Check if it is not numeric
            if not w.isnumeric():
                filtered_sentence.append(w)
        final_string = " ".join(filtered_sentence)  # final string of cleaned words

        final_text_list.append(final_string)

    return final_text_list
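
To see what the regular-expression steps do, here is a minimal sketch that applies the same substitutions to a single made-up string. It leaves out the NLTK tokenization step so that it runs with the standard library alone; `clean_text` and `sample` are names assumed for this illustration.

```python
import re
import string


def clean_text(text):
    """Apply the same regex steps as process_text, minus NLTK tokenization."""
    text = text.lower().strip()
    text = re.sub(r"\[.*?\]", "", text)                 # bracketed text
    text = re.sub(r"https?://\S+|www\.\S+", "", text)   # URLs
    text = re.sub(r"<.*?>+", "", text)                  # HTML tags
    text = re.sub("[%s]" % re.escape(string.punctuation), "", text)  # punctuation
    text = re.sub(r"\n", "", text)                      # newlines
    text = re.sub(r"\w*\d\w*", "", text)                # words containing digits
    return " ".join(text.split())                       # collapse repeated spaces


sample = "  Visit <b>our</b> site http://example.com [ad] for 3 GPUs, now!\n"
cleaned = clean_text(sample)
print(cleaned)  # visit our site for gpus now
```

Note that newlines are deleted rather than replaced with a space, so the last word of one line and the first word of the next can end up joined together.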

Now we will read each of the 20_newsgroups dataset files, call strip_newsgroup_item and process_text functions we defined earlier, and then aggregate all data into one dataframe.

[ ]:
all_categories_df = pd.DataFrame()

for file in file_list:
    print(f"Processing {file}")
    label = file.split("/")[1]
    df = pd.read_csv(file, header=None, names=["text"])
    df["text"] = df["text"].apply(strip_newsgroup_item)
    df["text"] = process_text(df["text"].tolist())
    df["label"] = label
    # DataFrame.append was removed in pandas 2.0; pd.concat works across versions
    all_categories_df = pd.concat([all_categories_df, df], ignore_index=True)

Let’s inspect how many categories there are in our dataset.

[ ]:

Our dataset has 20 categories, which is too many, so we will combine the sub-categories.

[ ]:
# replace to politics
all_categories_df["label"] = all_categories_df["label"].replace(
    {
        "talk.politics.misc": "politics",
        "talk.politics.guns": "politics",
        "talk.politics.mideast": "politics",
    }
)

# replace to recreational
all_categories_df["label"] = all_categories_df["label"].replace(
    {
        "rec.sport.hockey": "recreational",
        "rec.sport.baseball": "recreational",
        "rec.autos": "recreational",
        "rec.motorcycles": "recreational",
    }
)

# replace to religion
all_categories_df["label"] = all_categories_df["label"].replace(
    {
        "soc.religion.christian": "religion",
        "talk.religion.misc": "religion",
        "alt.atheism": "religion",
    }
)

# replace to computer
all_categories_df["label"] = all_categories_df["label"].replace(
    {
        "comp.windows.x": "computer",
        "comp.sys.ibm.pc.hardware": "computer",
        "comp.os.ms-windows.misc": "computer",
        "comp.graphics": "computer",
        "comp.sys.mac.hardware": "computer",
    }
)

# replace to sales
all_categories_df["label"] = all_categories_df["label"].replace({"misc.forsale": "sales"})

# replace to science
all_categories_df["label"] = all_categories_df["label"].replace(
    {
        "sci.crypt": "science",
        "sci.electronics": "science",
        "sci.med": "science",
        "sci.space": "science",
    }
)

Now we are left with 6 categories, which is much better.

[ ]:

Let’s calculate the number of words in each row.

[ ]:
all_categories_df["word_count"] = all_categories_df["text"].apply(lambda x: len(str(x).split()))

Let’s get basic statistics about the dataset.

[ ]:

We can see that the mean value is around 159 words. However, there are outliers, such as a text with 11351 words. Outliers like this can hurt training, so we will drop those rows.

Let’s drop empty rows first.

[ ]:
no_text = all_categories_df[all_categories_df["word_count"] == 0]

# drop these rows
all_categories_df.drop(no_text.index, inplace=True)

Let’s drop the rows that are longer than 256 words, a cutoff of the same order as the mean word count. This removes the long outliers and makes training easier.

[ ]:
long_text = all_categories_df[all_categories_df["word_count"] > 256]

# drop these rows
all_categories_df.drop(long_text.index, inplace=True)
[ ]:

Let’s get basic statistics about the dataset after removing the outliers.

[ ]:

This looks much more balanced.

Now we drop the word_count column, as we will not need it anymore.

[ ]:
all_categories_df.drop(columns="word_count", inplace=True)
[ ]:

Let’s convert the categorical labels to integers in order to prepare the dataset for training.

[ ]:
categories = all_categories_df["label"].unique().tolist()
[ ]:
[ ]:
all_categories_df["label"] = all_categories_df["label"].apply(lambda x: categories.index(x))
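
The same encoding can be sketched in plain Python. This is an illustration with made-up labels, not the notebook's data; `label_to_id` is a hypothetical helper name. Building a dictionary once avoids the repeated list scan that `categories.index(x)` performs for every row, and the `categories` list itself serves as the reverse lookup at inference time.

```python
# Hypothetical label column, standing in for all_categories_df["label"].
labels = ["politics", "science", "politics", "sales", "science"]

# Unique labels in order of first appearance, like Series.unique().tolist().
categories = list(dict.fromkeys(labels))

# Build the lookup once instead of calling categories.index(x) for every row.
label_to_id = {name: i for i, name in enumerate(categories)}
encoded = [label_to_id[name] for name in labels]

# The reverse lookup is plain list indexing, which is what inference needs later.
decoded = [categories[i] for i in encoded]

print(categories)  # ['politics', 'science', 'sales']
print(encoded)     # [0, 1, 0, 2, 1]
```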
[ ]:

We partition the dataset into an 80% training set and a 20% test set and save both to CSV files.

[ ]:
train_df, test_df = train_test_split(all_categories_df, test_size=0.2)
[ ]:
train_df.to_csv("train.csv", index=None)
[ ]:
test_df.to_csv("test.csv", index=None)

Let’s inspect the label distribution in the training dataset

[ ]:

Let’s inspect the label distribution in the test dataset

[ ]:
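
One way to compare the two distributions is to look at the share of each label rather than raw counts. The sketch below uses only the standard library and made-up integer labels; `label_shares`, `train_labels`, and `test_labels` are names assumed for this illustration.

```python
from collections import Counter


def label_shares(labels):
    """Return {label: fraction of rows} so two splits can be compared."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: count / total for label, count in sorted(counts.items())}


# Hypothetical integer labels standing in for train_df["label"] and test_df["label"].
train_labels = [0, 0, 0, 1, 1, 2, 2, 2]
test_labels = [0, 1, 2, 2]

print(label_shares(train_labels))  # {0: 0.375, 1: 0.25, 2: 0.375}
print(label_shares(test_labels))   # {0: 0.25, 1: 0.25, 2: 0.5}
```

If the shares differ noticeably between the splits, scikit-learn's train_test_split accepts a stratify argument that keeps the label proportions the same in both, and a random_state argument that makes the split repeatable.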


A tokenizer is in charge of preparing the inputs for a model. The HuggingFace Transformers library contains tokenizers for all of its models. Most of the tokenizers are available in two flavors: a full Python implementation and a “Fast” implementation based on the Rust library tokenizers. The “Fast” implementations allow:

  • A significant speed-up in particular when doing batched tokenization.

  • Additional methods to map between the original string (character and words) and the token space (e.g. getting the index of the token comprising a given character or the span of characters corresponding to a given token).
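
The second point is easier to see with a toy example. The sketch below is not the HuggingFace API: it uses a hypothetical whitespace tokenizer (`tokenize_with_offsets`) that records the character span of each token, and a hypothetical helper (`char_to_token`) that maps a character position back to the token containing it.

```python
import re


def tokenize_with_offsets(text):
    """Toy whitespace tokenizer that records each token's (start, end) character span."""
    return [(m.group(), m.start(), m.end()) for m in re.finditer(r"\S+", text)]


def char_to_token(tokens, char_index):
    """Return the index of the token covering char_index, or None for whitespace."""
    for i, (_, start, end) in enumerate(tokens):
        if start <= char_index < end:
            return i
    return None


text = "fine-tune a transformer"
tokens = tokenize_with_offsets(text)
print(tokens)                     # [('fine-tune', 0, 9), ('a', 10, 11), ('transformer', 12, 23)]
print(char_to_token(tokens, 14))  # 2, because character 14 is inside 'transformer'
```

In the Transformers library, fast tokenizers provide comparable information through the return_offsets_mapping argument and the char_to_token method of the returned encoding.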

[ ]:
from datasets import load_dataset
from transformers import AutoTokenizer

# tokenizer used in preprocessing
tokenizer_name = "distilbert-base-uncased"
[ ]:
# download tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

Load train and test datasets

Let’s create a Dataset from our local csv files for training and test we saved earlier.

[ ]:
dataset = load_dataset("csv", data_files={"train": "train.csv", "test": "test.csv"})
[ ]:
[ ]:
[ ]:
[ ]:
[ ]:
[ ]:
# tokenizer helper function
def tokenize(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True)
[ ]:
train_dataset = dataset["train"]
test_dataset = dataset["test"]
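
The effect of padding="max_length" together with truncation=True can be illustrated with a toy function. This is a simplified sketch, not the tokenizer's implementation: the token ids are made up, `pad_or_truncate` is a hypothetical name, and a real tokenizer also keeps the closing special token when it truncates, which this sketch does not.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Mimic padding='max_length' with truncation=True for one sequence."""
    ids = token_ids[:max_length]                 # truncation
    attention_mask = [1] * len(ids)              # 1 marks real tokens
    padding = max_length - len(ids)
    return {
        "input_ids": ids + [pad_id] * padding,             # pad up to max_length
        "attention_mask": attention_mask + [0] * padding,  # 0 marks padding
    }


short_example = pad_or_truncate([101, 7592, 102], max_length=5)
long_example = pad_or_truncate([101, 1, 2, 3, 4, 5, 102], max_length=5)
print(short_example)  # {'input_ids': [101, 7592, 102, 0, 0], 'attention_mask': [1, 1, 1, 0, 0]}
print(long_example)   # {'input_ids': [101, 1, 2, 3, 4], 'attention_mask': [1, 1, 1, 1, 1]}
```

Every sequence comes out with the same length, which is what lets the examples be stacked into fixed-size tensors for training.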

Tokenize train and test datasets

Let’s tokenize the train dataset

[ ]:
train_dataset = train_dataset.map(tokenize, batched=True)

Let’s tokenize the test dataset

[ ]:
test_dataset = test_dataset.map(tokenize, batched=True)

Set format for PyTorch

[ ]:
train_dataset = train_dataset.rename_column("label", "labels")
train_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
test_dataset = test_dataset.rename_column("label", "labels")
test_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

Uploading data to sagemaker_session_bucket

After we have processed the datasets, we upload them to S3.

[ ]:
import botocore
from datasets.filesystems import S3FileSystem

s3 = S3FileSystem()

# save train_dataset to s3
training_input_path = f"s3://{sess.default_bucket()}/{s3_prefix}/train"
train_dataset.save_to_disk(training_input_path, fs=s3)

# save test_dataset to s3
test_input_path = f"s3://{sess.default_bucket()}/{s3_prefix}/test"
test_dataset.save_to_disk(test_input_path, fs=s3)


Set up hyperparameter tuning job

Now that we are done with all the setup that is needed, we are ready to train our HuggingFace model. To begin, let us create a HuggingFace estimator object. This estimator will launch the training job.

Training the HuggingFace model for supervised text classification

In order to create a SageMaker training job we need a HuggingFace Estimator. The Estimator handles end-to-end Amazon SageMaker training and deployment tasks. In an Estimator we define which fine-tuning script should be used as entry_point, which instance_type should be used, which hyperparameters are passed in, and so on.

huggingface_estimator = HuggingFace(entry_point='train.py',
                            # ... other estimator arguments elided
                            hyperparameters = {'epochs': 1,
                                               'model_name': 'distilbert-base-uncased',
                                               'num_labels': 6
                                               })

When we create a SageMaker training job, SageMaker takes care of starting and managing all the required EC2 instances for us with the huggingface container, uploads the provided fine-tuning script train.py, and downloads the data from our sagemaker_session_bucket into the container at /opt/ml/input/data. Then, it starts the training job by running:

/opt/conda/bin/python train.py --epochs 1 --model_name distilbert-base-uncased --num_labels 6

The hyperparameters you define in the HuggingFace estimator are passed in as named arguments.

SageMaker provides useful properties about the training environment through various environment variables, including the following:

  • SM_MODEL_DIR: A string that represents the path where the training job writes the model artifacts to. After training, artifacts in this directory are uploaded to S3 for model hosting.

  • SM_NUM_GPUS: An integer representing the number of GPUs available to the host.

  • SM_CHANNEL_XXXX: A string that represents the path to the directory that contains the input data for the specified channel. For example, if you specify two input channels in the HuggingFace estimator’s fit call, named train and test, the environment variables SM_CHANNEL_TRAIN and SM_CHANNEL_TEST are set.
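
The train.py script itself is not shown in this notebook, so the following is only a sketch of how a training script could read the hyperparameters and the environment variables listed above. The argument names `--model_dir`, `--n_gpus`, `--training_dir`, and `--test_dir`, and the local fallback paths, are assumptions made for this illustration.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and SageMaker-provided paths for a training script."""
    parser = argparse.ArgumentParser()

    # Hyperparameters arrive as named command-line arguments.
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--model_name", type=str, default="distilbert-base-uncased")
    parser.add_argument("--num_labels", type=int, default=6)

    # Environment variables set by SageMaker; the fallbacks are only for local runs.
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR", "./model"))
    parser.add_argument("--n_gpus", type=int, default=int(os.environ.get("SM_NUM_GPUS", "0")))
    parser.add_argument("--training_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAIN", "./data/train"))
    parser.add_argument("--test_dir", type=str, default=os.environ.get("SM_CHANNEL_TEST", "./data/test"))

    # parse_known_args ignores any extra arguments that may be appended.
    args, _ = parser.parse_known_args(argv)
    return args


args = parse_args(["--epochs", "1", "--model_name", "distilbert-base-uncased", "--num_labels", "6"])
print(args.epochs, args.model_name, args.num_labels, args.training_dir)
```

The list passed to `parse_args` mirrors the command line shown earlier, so the same function works both inside the container and in a local test.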

To run your training job locally you can define instance_type='local' or instance_type='local-gpu' for GPU usage. Note: this does not work within SageMaker Studio.

We create a metric_definitions list that contains regex-based definitions, which are used to parse the job logs and extract metrics.

[ ]:
metric_definitions = [
    {"Name": "loss", "Regex": r"'loss': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "learning_rate", "Regex": r"'learning_rate': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_loss", "Regex": r"'eval_loss': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_accuracy", "Regex": r"'eval_accuracy': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_f1", "Regex": r"'eval_f1': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_precision", "Regex": r"'eval_precision': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_recall", "Regex": r"'eval_recall': ([0-9]+(.|e\-)[0-9]+),?"},
    {"Name": "eval_runtime", "Regex": r"'eval_runtime': ([0-9]+(.|e\-)[0-9]+),?"},
    {
        "Name": "eval_samples_per_second",
        "Regex": r"'eval_samples_per_second': ([0-9]+(.|e\-)[0-9]+),?",
    },
    {"Name": "epoch", "Regex": r"'epoch': ([0-9]+(.|e\-)[0-9]+),?"},
]
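
To check what such a pattern captures, it can be tried locally against a sample log line. The log lines below are made up for this sketch, assuming the dict-like format the patterns expect, and Python's re.search stands in for the log parsing that SageMaker performs.

```python
import re

# Same pattern shape as the metric definitions in this notebook.
eval_accuracy_regex = r"'eval_accuracy': ([0-9]+(.|e\-)[0-9]+),?"
learning_rate_regex = r"'learning_rate': ([0-9]+(.|e\-)[0-9]+),?"

# Hypothetical log lines in the dict-like format the patterns expect.
eval_line = "{'eval_loss': 0.4321, 'eval_accuracy': 0.8765, 'epoch': 1.0}"
train_line = "{'loss': 1.2345, 'learning_rate': 5e-05, 'epoch': 0.5}"

accuracy = re.search(eval_accuracy_regex, eval_line).group(1)
learning_rate = re.search(learning_rate_regex, train_line).group(1)
print(accuracy)       # 0.8765
print(learning_rate)  # 5e-05

# Caveat: a mantissa with a decimal point is cut off before the exponent.
truncated = re.search(learning_rate_regex, "'learning_rate': 2.5e-05,").group(1)
print(truncated)      # 2.5
```

The last case suggests that values logged in the form 2.5e-05 would be recorded as 2.5 by this pattern, which matters mainly for the learning_rate metric.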
[ ]:
from sagemaker.huggingface import HuggingFace

# hyperparameters, which are passed into the training job
hyperparameters = {"epochs": 1, "model_name": "distilbert-base-uncased", "num_labels": 6}

Now, let’s define the SageMaker HuggingFace estimator with resource configurations and hyperparameters to train Text Classification on 20 Newsgroups dataset, running on a p3.2xlarge instance.

[ ]:
huggingface_estimator = HuggingFace(

Once we’ve defined our estimator we can specify the hyperparameters we’d like to tune and their possible values. We have three different types of hyperparameters:

  • Categorical parameters need to take one value from a discrete set. We define this by passing the list of possible values to CategoricalParameter(list)

  • Continuous parameters can take any real number value between the minimum and maximum value, defined by ContinuousParameter(min, max)

  • Integer parameters can take any integer value between the minimum and maximum value, defined by IntegerParameter(min, max)

Note, if possible, it’s almost always best to specify a value as the least restrictive type. For example, tuning learning rate as a continuous value between 0.01 and 0.2 is likely to yield a better result than tuning as a categorical parameter with values 0.01, 0.1, 0.15, or 0.2.

[ ]:
from sagemaker.tuner import (
    IntegerParameter,
    CategoricalParameter,
    ContinuousParameter,
    HyperparameterTuner,
)

hyperparameter_ranges = {
    "train_batch_size": IntegerParameter(8, 32),
}

Next we’ll specify the objective metric that we’d like to tune and its definition, which includes the regular expression (Regex) needed to extract that metric from the CloudWatch logs of the training job. If you bring your own algorithm, your algorithm emits metrics by itself. In that case, you’ll need to add a MetricDefinition object here to define the format of those metrics through regex, so that SageMaker knows how to extract those metrics from your CloudWatch logs.

In this case, we elected to monitor eval_accuracy as you can see below.

[ ]:
objective_metric_name = "eval_accuracy"
objective_type = "Maximize"
hpo_metric_definitions = [
    {"Name": "eval_accuracy", "Regex": r"'eval_accuracy': ([0-9]+(.|e\-)[0-9]+),?"}
]

Now, we’ll create a HyperparameterTuner object, to which we pass:

  • The HuggingFace estimator we created above

  • Our hyperparameter ranges

  • Objective metric name and definition

  • Tuning resource configurations, such as the number of training jobs to run in total and how many training jobs can be run in parallel

[ ]:
tuner = HyperparameterTuner(

Launch hyperparameter tuning job

Now we can launch a hyperparameter tuning job by calling the fit() function. After the hyperparameter tuning job is created, we can go to the SageMaker console to track its progress until it is completed.

This should take around 28 minutes to complete.

[ ]:

tuner.fit({"train": training_input_path, "test": test_input_path}, logs=True)

Analyze Results of a Hyperparameter Tuning job

Once you have completed a tuning job (or even while the job is still running), you can use the code below to analyze the results and understand how each hyperparameter affects the quality of the model.

[ ]:
sm_client = boto3.Session().client("sagemaker")

tuning_job_name = tuner.latest_tuning_job.name

Track hyperparameter tuning job progress

After you launch a tuning job, you can see its progress by calling the describe_hyper_parameter_tuning_job API. The response is a JSON object (a Python dictionary in boto3) that contains information about the current state of the tuning job. You can call list_training_jobs_for_hyper_parameter_tuning_job to see a detailed list of the training jobs that the tuning job launched.

[ ]:
tuning_job_result = sm_client.describe_hyper_parameter_tuning_job(
    HyperParameterTuningJobName=tuning_job_name
)

status = tuning_job_result["HyperParameterTuningJobStatus"]
if status != "Completed":
    print("Reminder: the tuning job has not been completed.")

job_count = tuning_job_result["TrainingJobStatusCounters"]["Completed"]
print("%d training jobs have completed" % job_count)

# True when the objective is minimized; used later to sort results best-first
objective = tuning_job_result["HyperParameterTuningJobConfig"]["HyperParameterTuningJobObjective"]
is_minimize = objective["Type"] != "Maximize"
objective_name = objective["MetricName"]
[ ]:
from pprint import pprint

if tuning_job_result.get("BestTrainingJob", None):
    print("Best model found so far:")
    pprint(tuning_job_result["BestTrainingJob"])
else:
    print("No training jobs have reported results yet.")

Fetch all results as DataFrame

We can list hyperparameters and objective metrics of all training jobs and pick up the training job with the best objective metric.

[ ]:
import pandas as pd

tuner_analytics = sagemaker.HyperparameterTuningJobAnalytics(tuning_job_name)

full_df = tuner_analytics.dataframe()

if len(full_df) > 0:
    df = full_df[full_df["FinalObjectiveValue"] > -float("inf")]
    if len(df) > 0:
        df = df.sort_values("FinalObjectiveValue", ascending=is_minimize)
        print("Number of training jobs with valid objective: %d" % len(df))
        print({"lowest": min(df["FinalObjectiveValue"]), "highest": max(df["FinalObjectiveValue"])})
        pd.set_option("display.max_colwidth", None)  # Don't truncate TrainingJobName
    else:
        print("No training jobs have reported valid results yet.")
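
The role of the is_minimize flag in the sort can be shown with a few made-up results. This is a plain-Python sketch of the ordering logic only; the job names, values, and the `rank_jobs` helper are hypothetical.

```python
# Hypothetical (job name, final objective value) pairs for finished training jobs.
jobs = [
    ("job-a", 0.81),
    ("job-b", 0.88),
    ("job-c", 0.79),
]


def rank_jobs(jobs, is_minimize):
    """Order jobs best-first: ascending when minimizing, descending when maximizing."""
    return sorted(jobs, key=lambda job: job[1], reverse=not is_minimize)


best_for_accuracy = rank_jobs(jobs, is_minimize=False)[0]
best_for_loss = rank_jobs(jobs, is_minimize=True)[0]
print(best_for_accuracy)  # ('job-b', 0.88)
print(best_for_loss)      # ('job-c', 0.79)
```

With an objective such as eval_accuracy, which is maximized, the best job is the one with the highest value, so the results are sorted in descending order.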


Deploy the best trained model

Once the training is done, we can deploy the best trained model as an Amazon SageMaker real-time hosted endpoint, which lets us request predictions (inference) from the model. We don’t have to host on the same type of instance that we used for training. Inference usually needs less compute power than training, and an endpoint may be up and running for a long time, so it’s advisable to choose a cheaper instance type for inference.

  • ml.p3.2xlarge (used for training) - part of the P3 family, which delivers high-performance compute in the cloud with up to 8 NVIDIA® V100 Tensor Core GPUs and up to 100 Gbps of networking throughput for machine learning and HPC applications. The ml.p3.2xlarge size itself has a single V100 GPU.

  • ml.g4dn.xlarge (used for hosting) - part of the G4dn family, which AWS describes as cost-effective and versatile GPU instances for deploying machine learning models such as image classification, object detection, and speech recognition, and for graphics-intensive applications such as remote graphics workstations, game streaming, and graphics rendering.

[ ]:
predictor = tuner.deploy(1, "ml.g4dn.xlarge")

Then, we use the returned predictor object to call the endpoint.

[ ]:
def predict_sentence(sentence):
    result = predictor.predict({"inputs": sentence})
    index = int(result[0]["label"].split("LABEL_")[1])
    print(categories[index])  # map the predicted label index back to its category name
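
The label parsing can be tried without an endpoint. The sketch below assumes a response shaped the way predict_sentence indexes it, a list whose first item carries a label such as LABEL_3; the response, the category order, and the `label_to_category` helper are made up for this illustration.

```python
# Hypothetical category order; in the notebook it comes from the `categories` list.
categories = ["politics", "recreational", "religion", "computer", "sales", "science"]


def label_to_category(result, categories):
    """Turn a response like [{'label': 'LABEL_3', 'score': 0.97}] into a category name."""
    label = result[0]["label"]             # e.g. 'LABEL_3'
    index = int(label.split("LABEL_")[1])  # text after the prefix, as an integer
    return categories[index]


# Hypothetical endpoint response, shaped the way predict_sentence indexes it.
fake_result = [{"label": "LABEL_3", "score": 0.97}]
predicted = label_to_category(fake_result, categories)
print(predicted)  # computer
```

The index only maps back to the right name if `categories` is in the same order that was used to encode the labels before training.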
[ ]:
sentences = [
    "The modem is an internal AT/(E)ISA 8-bit card (just a little longer than a half-card).",
    "In the cage I usually wave to bikers.  They usually don't wave back.  My wife thinks it's strange but I don't care.",
    "Voyager has the unusual luck to be on a stable trajectory out of the solar system.",
]

# using the same processing logic that we used during data preparation for training
processed_sentences = process_text(sentences)

for sentence in processed_sentences:
    predict_sentence(sentence)

Clean up

Endpoints should be deleted when no longer in use, since (per the SageMaker pricing page) they’re billed by time deployed.

[ ]:
predictor.delete_endpoint()
