SageMaker JumpStart Foundation Models - HuggingFace Text2Text Instruction Fine-Tuning


Welcome to Amazon SageMaker JumpStart! You can use SageMaker JumpStart to solve many machine learning tasks with one click in SageMaker Studio, or through the SageMaker Python SDK.

In this demo notebook, we use the SageMaker Python SDK to fine-tune a Text2Text model. Such a model takes prompting text as input and generates text as output. The prompt can include a task description in natural language. Accordingly, the model can be used for a variety of NLP tasks (e.g., text summarization, question answering, etc.).

We will fine-tune a pre-trained FLAN T5 model from Hugging Face. While pre-trained FLAN T5 models can be used “as is” for many tasks, fine-tuning can improve model performance on a particular task or language domain. As an example, we will fine-tune the model for a task that was not used for pre-training. We will deploy two inference endpoints, one with the pre-trained model and one with the fine-tuned model, so that the same inference query can be run against both endpoints and the results compared.

This notebook was tested on ml.t3.medium Amazon SageMaker Notebook instance with conda_python3 kernel.

In this notebook:

  1. Setting up

  2. Fine-tuning a model

  3. Deploying inference endpoints

  4. Running inference queries

  5. Cleaning up resources

1. Setting up

We begin by installing and upgrading necessary packages. Restart the kernel after executing the cell below for the first time.

[34]:
!pip install sagemaker==2.166.0 --quiet

We will use the following variables throughout the notebook: the current AWS region, the execution role associated with the current notebook instance, and the default S3 bucket. We then select the FLAN T5 model size. Training and inference instance types are not set explicitly, so the JumpStart defaults for the selected model apply.

[ ]:
import boto3
import sagemaker

# Get current region, role, and default bucket
aws_region = boto3.Session().region_name
aws_role = sagemaker.session.Session().get_caller_identity_arn()
output_bucket = sagemaker.Session().default_bucket()

# This will be useful for printing
newline, bold, unbold = "\n", "\033[1m", "\033[0m"

print(f"{bold}aws_region:{unbold} {aws_region}")
print(f"{bold}aws_role:{unbold} {aws_role}")
print(f"{bold}output_bucket:{unbold} {output_bucket}")
[ ]:
from sagemaker.jumpstart.model import JumpStartModel

# Replace with larger model if needed
model_id, model_version = "huggingface-text2text-flan-t5-base", "1.*"
pretrained_model = JumpStartModel(model_id=model_id)
pretrained_predictor = pretrained_model.deploy()
[ ]:
print(pretrained_model.endpoint_name)
pre_trained_name = pretrained_model.endpoint_name
[78]:
text = """Summarize this content - Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases.
You can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Comprehend APIs. You can run real-time analysis for small workloads or you can start asynchronous analysis jobs for large document sets. You can use the pre-trained models that Amazon Comprehend provides, or you can train your own custom models for classification and entity recognition. """
query_response = pretrained_predictor.predict(text)
[41]:
print(query_response["generated_text"])
Understand the benefits of Amazon Comprehend. Understand the benefits of Amazon Comprehend
[88]:
from sagemaker import serializers
from sagemaker import content_types

# List the serializers and content types the model supports (for reference only)
serializer_options = serializers.retrieve_options(model_id=model_id, model_version=model_version)
content_type_options = content_types.retrieve_options(
    model_id=model_id, model_version=model_version
)

# Send JSON payloads so that generation parameters can accompany the input text
pretrained_predictor.serializer = serializers.JSONSerializer()
pretrained_predictor.content_type = "application/json"
[ ]:
from sagemaker import serializers

input_text = """Summarize this content - Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases.
You can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Comprehend APIs. You can run real-time analysis for small workloads or you can start asynchronous analysis jobs for large document sets. You can use the pre-trained models that Amazon Comprehend provides, or you can train your own custom models for classification and entity recognition. """

parameters = {
    "max_length": 600,
    "num_return_sequences": 1,
    "top_p": 0.01,
    "do_sample": False,
}

payload = {"text_inputs": input_text, **parameters}  # JSON Input format
query_response = pretrained_predictor.predict(payload)
print(query_response["generated_texts"][0])

2. Fine-tuning a model

FLAN T5 models were pre-trained on a variety of tasks. In this demo, we fine-tune a model for a new task: given a piece of text, the model is asked to generate questions that are relevant to the text but cannot be answered from the information provided. Examples are given in the inference section of this notebook.

2.1. Preparing training data

We will use a subset of SQuAD2.0 for supervised fine-tuning. This dataset contains questions posed by human annotators on a set of Wikipedia articles. In addition to questions with answers, SQuAD2.0 contains about 50k unanswerable questions. Such questions are plausible, but cannot be directly answered from the articles’ content. We only use unanswerable questions for our task.

Citation: @article{rajpurkar2018know, title={Know what you don’t know: Unanswerable questions for SQuAD}, author={Rajpurkar, Pranav and Jia, Robin and Liang, Percy}, journal={arXiv preprint arXiv:1806.03822}, year={2018} }

License: Creative Commons Attribution-ShareAlike License (CC BY-SA 4.0)

[42]:
from sagemaker.s3 import S3Downloader

# We will use the train split of SQuAD2.0
original_data_file = "train-v2.0.json"

# The data was mirrored in the following bucket
original_data_location = (
    f"s3://sagemaker-example-files-prod-{aws_region}/datasets/text/squad2.0/{original_data_file}"
)
S3Downloader.download(original_data_location, ".")

The Text2Text generation model can be fine-tuned on any text data, provided that the data is in the expected format. The data must include a training part and, optionally, a validation part. The best model is selected according to the validation loss, calculated at the end of each epoch. If a validation set is not given, an (adjustable) percentage of the training data is automatically split off and used for validation.

The training data must be formatted in JSON lines (.jsonl) format, where each line is a dictionary representing a single data sample. All training data must be in a single folder, but it can be saved in multiple .jsonl files. The .jsonl file extension is mandatory. The training folder can also contain a template.json file describing the input and output formats.
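As a minimal sketch of the layout described above, the following cell writes a handful of records to several .jsonl files in one folder and reads them back. The helper names, the `part-NNN.jsonl` file names, and the sample records are hypothetical and chosen only for illustration; any file name with the .jsonl extension should work.

```python
import json
import tempfile
from pathlib import Path


def write_jsonl_shards(records, folder, shard_size):
    """Write records to numbered .jsonl files, one JSON object per line."""
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    paths = []
    for start in range(0, len(records), shard_size):
        # Hypothetical naming scheme; only the .jsonl extension matters
        path = folder / f"part-{start // shard_size:03d}.jsonl"
        with open(path, "w") as f:
            for record in records[start : start + shard_size]:
                f.write(json.dumps(record) + "\n")
        paths.append(path)
    return paths


def read_jsonl_folder(folder):
    """Read every .jsonl file in the folder back into a list of dictionaries."""
    records = []
    for path in sorted(Path(folder).glob("*.jsonl")):
        with open(path) as f:
            records.extend(json.loads(line) for line in f if line.strip())
    return records


# Made-up sample records, not taken from any dataset
records = [{"prompt": f"question {i}", "completion": f"answer {i}"} for i in range(5)]

with tempfile.TemporaryDirectory() as tmp:
    shard_paths = write_jsonl_shards(records, tmp, shard_size=2)
    shard_names = [p.name for p in shard_paths]
    round_trip = read_jsonl_folder(tmp)

print(shard_names)
```

Each line in every file is a complete JSON object, so the files can be concatenated or split freely without changing the data.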

If no template file is given, the following default template will be used:

{
    "prompt": "{prompt}",
    "completion": "{completion}"
}

In this case, the data in the JSON lines entries must include prompt and completion fields.
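Before uploading data, it can help to check that every JSON lines record contains the fields that the template refers to. The cell below is a minimal sketch of such a check. The helper names and the sample records are hypothetical, and we assume placeholders are written as `{field}`, as in the default template; how the training script itself performs the substitution is not shown in this notebook.

```python
import json
from string import Formatter


def template_fields(template):
    """Collect the placeholder names used in a template dictionary."""
    fields = set()
    for value in template.values():
        for _, field_name, _, _ in Formatter().parse(value):
            if field_name:
                fields.add(field_name)
    return fields


def missing_fields(jsonl_text, template):
    """Return (line number, missing field names) for every incomplete record."""
    required = template_fields(template)
    problems = []
    for line_number, line in enumerate(jsonl_text.splitlines(), start=1):
        if not line.strip():
            continue
        record = json.loads(line)
        missing = required - set(record)
        if missing:
            problems.append((line_number, sorted(missing)))
    return problems


default_template = {"prompt": "{prompt}", "completion": "{completion}"}

# Made-up sample: the second record lacks the "completion" field
sample = '{"prompt": "a", "completion": "b"}\n{"prompt": "c"}\n'
problems = missing_fields(sample, default_template)
print(problems)
```

The same check applies to a custom template: the required fields are whatever placeholders appear in its prompt and completion strings.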

In this demo, we are going to use a custom template (see below).

[43]:
import json

local_data_file = "task-data.jsonl"  # any name with .jsonl extension

with open(original_data_file) as f:
    data = json.load(f)

with open(local_data_file, "w") as f:
    for article in data["data"]:
        for paragraph in article["paragraphs"]:
            # iterate over questions for a given paragraph
            for qas in paragraph["qas"]:
                if qas["is_impossible"]:
                    # the question is relevant, but cannot be answered
                    example = {"context": paragraph["context"], "question": qas["question"]}
                    json.dump(example, f)
                    f.write("\n")

template = {
    "prompt": "Ask a question which is related to the following text, but cannot be answered based on the text. Text: {context}",
    "completion": "{question}",
}
with open("template.json", "w") as f:
    json.dump(template, f)
[44]:
from sagemaker.s3 import S3Uploader

train_data_location = f"s3://{output_bucket}/train_data"
S3Uploader.upload(local_data_file, train_data_location)
S3Uploader.upload("template.json", train_data_location)
print(f"{bold}training data:{unbold} {train_data_location}")

We will use the simplified JumpStart SDK, which supplies defaults for various input parameters.

2.2. Starting training

We are now ready to start the training job. This can take a while to complete, from 20 minutes to several hours, depending on the model size, amount of data, and so on (e.g., it can take a few hours for the xl model, 40k examples and 3 epochs).

[ ]:
from sagemaker.jumpstart.estimator import JumpStartEstimator

estimator = JumpStartEstimator(
    model_id=model_id,
)

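# Hyperparameter names are model-specific. The supported names and their defaults can be
# listed with sagemaker.hyperparameters.retrieve_default(model_id=model_id, model_version=model_version);
# check in particular how the number of epochs is named for the selected model.
estimator.set_hyperparameters(instruction_tuned="True", epoch="3", max_input_length="1024")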
estimator.fit({"training": train_data_location})

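Performance metrics such as training and validation loss can be accessed through CloudWatch during training. The most recent snapshot of metrics can also be fetched as a pandas DataFrame with the SDK's TrainingJobAnalytics class, by passing estimator.latest_training_job.job_name as the training job name and calling dataframe().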

3. Deploying inference endpoints

The remainder of the notebook should be executed once the training job has completed successfully. We deploy an inference endpoint for the fine-tuned model (the pre-trained model was already deployed in Section 1), so that the same request can be run against the two endpoints and the results compared. Note that each endpoint deployment can take a few minutes.

[ ]:
finetuned_predictor = estimator.deploy()
[69]:
fine_tuned_name = finetuned_predictor.endpoint_name
print(fine_tuned_name)

4. Running inference queries

As the name suggests, a Text2Text model such as FLAN T5 receives a piece of text as input and generates text as output. The input text contains the description of the task. In this demo, our task is to generate questions given a piece of text. The questions must be relevant to the text, but the text should contain no answer. Such a task could arise when automating the gathering of additional information, or when identifying gaps in technical documentation.

[ ]:
parameters = {
    "max_length": 40,  # restrict the length of the generated text
    "num_return_sequences": 5,  # we will inspect several model outputs
    "num_beams": 10,  # use beam search
}

prompt = "Ask a question which is related to the following text, but cannot be answered based on the text. Text: {context}"
context = "Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases. You can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Comprehend APIs. You can run real-time analysis for small workloads or you can start asynchronous analysis jobs for large document sets. You can use the pre-trained models that Amazon Comprehend provides, or you can train your own custom models for classification and entity recognition. All of the Amazon Comprehend features accept UTF-8 text documents as the input. In addition, custom classification and custom entity recognition accept image files, PDF files, and Word files as input. Amazon Comprehend can examine and analyze documents in a variety of languages, depending on the specific feature. For more information, see Languages supported in Amazon Comprehend. Amazon Comprehend's Dominant language capability can examine documents and determine the dominant language for a far wider selection of languages."
expanded_prompt = prompt.replace("{context}", context)

finetuned_predictor.serializer = serializers.JSONSerializer()
finetuned_predictor.content_type = "application/json"
payload = {"text_inputs": expanded_prompt, **parameters}  # JSON Input format

query_response = finetuned_predictor.predict(payload)
# Print every returned sequence (num_return_sequences is set to 5 above)
for generated_text in query_response["generated_texts"]:
    print(generated_text)
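To compare the pre-trained and fine-tuned models, the same payload can be sent to both endpoints. The cell below is a minimal sketch of such a comparison. The helper `compare_endpoints` and the stand-in class `EchoPredictor` are hypothetical; the stand-in only mimics the `predict` call and the `generated_texts` response field used above, so that the sketch runs without live endpoints.

```python
def compare_endpoints(predictors, payload):
    """Send the same payload to each predictor and collect the generated texts."""
    results = {}
    for name, predictor in predictors.items():
        response = predictor.predict(payload)
        results[name] = response["generated_texts"]
    return results


class EchoPredictor:
    """Hypothetical stand-in for a deployed predictor; it only echoes the input."""

    def __init__(self, label):
        self.label = label

    def predict(self, payload):
        return {"generated_texts": [f"[{self.label}] {payload['text_inputs']}"]}


demo_predictors = {
    "pre-trained": EchoPredictor("pre-trained"),
    "fine-tuned": EchoPredictor("fine-tuned"),
}
results = compare_endpoints(demo_predictors, {"text_inputs": "hello"})

for name, texts in results.items():
    print(name)
    for text in texts:
        print("  ", text)
```

With live endpoints, the dictionary would instead map the two labels to pretrained_predictor and finetuned_predictor, both configured with the JSON serializer as shown earlier, and the payload would be the one built from expanded_prompt.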

5. Cleaning up resources

[ ]:
# Delete resources
pretrained_predictor.delete_model()
pretrained_predictor.delete_endpoint()
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()
