Introduction to SageMaker Built-In Algorithms - Text Generation



  1. Set Up

  2. Select a pre-trained model

  3. Retrieve Artifacts & Deploy an Endpoint

  4. Query endpoint and parse response

  5. Advanced features

  6. Clean up the endpoint

Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and in an Amazon SageMaker Notebook instance with the conda_pytorch_p39 kernel.

1. Set Up

[ ]:
!pip install sagemaker ipywidgets --upgrade --quiet

Permissions and environment variables

[ ]:
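import sagemaker, boto3, json
from sagemaker.session import Session

# Create one SageMaker session and reuse it below.
sagemaker_session = Session()
# ARN of the IAM role; passed as `role` when the model is created.
aws_role = sagemaker_session.get_caller_identity_arn()
# Region of the default boto3 session.
aws_region = boto3.Session().region_name
sess = sagemaker_session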

2. Select a pre-trained model


You can continue with the default model or choose a different model from the dropdown generated by running the cells below. A complete list of SageMaker pre-trained models is also available at JumpStart pre-trained Models.

[ ]:
model_id, model_version = (
    "huggingface-textgeneration-gpt2",
    "*",
)

[Optional] Select a different SageMaker pre-trained model. Here, we download the model_manifest file from the Built-In Algorithms S3 bucket, keep only the Text Generation models, and select a model for inference.

[ ]:
from ipywidgets import Dropdown
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.jumpstart.filters import And

# Retrieve all Text Generation models provided by SageMaker Built-In Algorithms.
filter_value = And("task == textgeneration", "framework == huggingface")
text_generation_models = list_jumpstart_models(filter=filter_value)

# Display the model IDs in a dropdown to select a model for inference.
model_dropdown = Dropdown(
    options=text_generation_models,
    value=model_id,
    description="Select a model",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)

Choose a model for inference

[ ]:
display(model_dropdown)

Using Models not Present in the Dropdown


If you want to use a model that is not listed in the dropdown but is available at HuggingFace Text Generation, choose huggingface-textgeneration-models in the dropdown and pass the model ID in the HF_MODEL_ID variable. Inference on the models listed in the dropdown can be run in network isolation, in which case no inbound or outbound network calls can be made to or from the model container. These models can also be deployed with custom VPC settings, which give the model container a network connection within your VPC that is not connected to the internet. Refer to the AWS documentation for more details.

However, when running inference on a model specified through HF_MODEL_ID, the model container downloads the model artifact from HuggingFace, so it cannot run in network isolation. Furthermore, if you want to use custom VPC settings, you must provide access to the HuggingFace portal from your VPC.

[ ]:
# model_version="*" fetches the latest version of the model
model_id, model_version = model_dropdown.value, "*"

hub = {}
HF_MODEL_ID = "xlnet-base-cased"  # Pass any other HF_MODEL_ID from https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads
if model_id == "huggingface-textgeneration-models":
    hub["HF_MODEL_ID"] = HF_MODEL_ID
    hub["HF_TASK"] = "text-generation"
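
The network-isolation rule described above can be sketched as a small helper that assembles optional keyword arguments for sagemaker.model.Model. This is a minimal illustration under assumptions, not part of the original notebook: the helper name build_model_kwargs is hypothetical, and it only builds a dictionary (it does not call SageMaker). The keys enable_network_isolation and vpc_config correspond to parameters of the Model class.

```python
def build_model_kwargs(model_id, hub_env, vpc_config=None):
    """Assemble optional keyword arguments for sagemaker.model.Model (hypothetical helper)."""
    # A model specified through HF_MODEL_ID is downloaded from HuggingFace when the
    # container starts, so it needs outbound access and cannot run in network isolation.
    needs_hub_download = model_id == "huggingface-textgeneration-models"
    kwargs = {
        "env": dict(hub_env),
        "enable_network_isolation": not needs_hub_download,
    }
    if vpc_config is not None:
        # Expected shape: {"Subnets": [...], "SecurityGroupIds": [...]}
        kwargs["vpc_config"] = vpc_config
    return kwargs


isolated = build_model_kwargs("huggingface-textgeneration-gpt2", {})
hub_backed = build_model_kwargs(
    "huggingface-textgeneration-models",
    {"HF_MODEL_ID": "xlnet-base-cased", "HF_TASK": "text-generation"},
)
print(isolated["enable_network_isolation"])    # True
print(hub_backed["enable_network_isolation"])  # False
```

The resulting dictionary could be unpacked into the Model constructor alongside the image URI, model data, and role used in the next section.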

3. Retrieve Artifacts & Deploy an Endpoint


Using SageMaker, we can perform inference on the pre-trained model without first fine-tuning it on a new dataset. We start by retrieving the deploy_image_uri and model_uri for the pre-trained model. To host the pre-trained model, we create an instance of sagemaker.model.Model (https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take a few minutes.


[ ]:
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base


endpoint_name = name_from_base(f"jumpstart-example-{model_id}")

inference_instance_type = "ml.p3.2xlarge"

# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)


# Retrieve the model uri. This includes the pre-trained model and its parameters.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)


# Create the SageMaker model instance
model = Model(
    image_uri=deploy_image_uri,
    model_data=model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
    env=hub,
)

# Deploy the model. We pass the Predictor class when deploying through the Model class
# so that we can run inference through the SageMaker API.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
)

4. Query endpoint and parse response

[ ]:
def query(model_predictor, text):
    """Query the model predictor."""

    encoded_text = text.encode("utf-8")

    query_response = model_predictor.predict(
        encoded_text,
        {
            "ContentType": "application/x-text",
            "Accept": "application/json",
        },
    )
    return query_response


def parse_response(query_response):
    """Parse response and return the generated text."""

    model_predictions = json.loads(query_response)
    generated_text = model_predictions["generated_text"]
    return generated_text

[ ]:
newline, bold, unbold = "\n", "\033[1m", "\033[0m"

text1 = "As far as I am concerned, I will"
text2 = "The movie is"

for text in [text1, text2]:
    query_response = query(model_predictor, text)
    generated_text = parse_response(query_response)
    print(f"Input text: {text}{newline}" f"Generated text: {bold}{generated_text}{unbold}{newline}")

5. Advanced features


This model also supports many advanced parameters while performing inference. They include:

  • max_length: Model generates text until the output length (which includes the input context length) reaches max_length. If specified, it must be a positive integer.

  • num_return_sequences: Number of output sequences returned. If specified, it must be a positive integer.

  • num_beams: Number of beams used in the beam search. If specified, it must be an integer greater than or equal to num_return_sequences.

  • no_repeat_ngram_size: Model ensures that a sequence of words of no_repeat_ngram_size is not repeated in the output sequence. If specified, it must be a positive integer greater than 1.

  • temperature: Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If temperature -> 0, it results in greedy decoding. If specified, it must be a positive float.

  • early_stopping: If True, text generation is finished when all beam hypotheses reach the end-of-sentence token. If specified, it must be boolean.

  • do_sample: If True, sample the next word according to its likelihood. If specified, it must be boolean.

  • top_k: In each step of text generation, sample from only the top_k most likely words. If specified, it must be a positive integer.

  • top_p: In each step of text generation, sample from the smallest possible set of words with cumulative probability top_p. If specified, it must be a float between 0 and 1.

  • seed: Fix the randomized state for reproducibility. If specified, it must be an integer.

  • return_full_text: If True, input text will be part of the output generated text. If specified, it must be boolean. The default value for it is False.
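
The constraints listed above can be checked on the client before a request is sent. The following is a minimal sketch, not part of the endpoint or the SageMaker SDK: the function name validate_generation_params is hypothetical, it accepts integers as well as floats for temperature and top_p, and it treats the bounds on top_p as inclusive, which is an assumption.

```python
def validate_generation_params(params):
    """Return a list of problems found in a generation payload (empty if none)."""

    def is_int(value):
        # bool is a subclass of int in Python, so rule it out explicitly.
        return isinstance(value, int) and not isinstance(value, bool)

    def is_number(value):
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    problems = []

    for key in ("max_length", "num_return_sequences", "top_k"):
        if key in params and not (is_int(params[key]) and params[key] > 0):
            problems.append(f"{key} must be a positive integer")

    if "num_beams" in params:
        wanted = params.get("num_return_sequences", 1)
        wanted = wanted if is_int(wanted) else 1
        if not (is_int(params["num_beams"]) and params["num_beams"] >= wanted):
            problems.append("num_beams must be an integer >= num_return_sequences")

    if "no_repeat_ngram_size" in params:
        size = params["no_repeat_ngram_size"]
        if not (is_int(size) and size > 1):
            problems.append("no_repeat_ngram_size must be an integer greater than 1")

    if "temperature" in params:
        temp = params["temperature"]
        if not (is_number(temp) and temp > 0):
            problems.append("temperature must be a positive number")

    if "top_p" in params:
        top_p = params["top_p"]
        # Assumption: the bounds 0 and 1 are treated as inclusive.
        if not (is_number(top_p) and 0 <= top_p <= 1):
            problems.append("top_p must be a number between 0 and 1")

    if "seed" in params and not is_int(params["seed"]):
        problems.append("seed must be an integer")

    for key in ("early_stopping", "do_sample", "return_full_text"):
        if key in params and not isinstance(params[key], bool):
            problems.append(f"{key} must be a boolean")

    return problems


example_payload = {
    "text_inputs": "My name is Lewis and I like to",
    "max_length": 50,
    "num_return_sequences": 3,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
}
print(validate_generation_params(example_payload))  # []
print(validate_generation_params({"num_beams": 2, "num_return_sequences": 3}))
```

The second call reports a problem because two beams cannot yield three returned sequences.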

We may specify any subset of the parameters mentioned above while invoking an endpoint. Next, we show an example of how to invoke the endpoint with these arguments.


[ ]:
import json

payload = {
    "text_inputs": "My name is Lewis and I like to",
    "max_length": 50,
    "num_return_sequences": 3,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
}


def query_endpoint_with_json_payload(model_predictor, payload):
    """Query the model predictor with json payload."""

    encoded_payload = json.dumps(payload).encode("utf-8")

    query_response = model_predictor.predict(
        encoded_payload,
        {
            "ContentType": "application/json",
            "Accept": "application/json",
        },
    )
    return query_response


def parse_response_multiple_texts(query_response):
    """Parse response and return the generated texts."""

    model_predictions = json.loads(query_response)
    generated_texts = model_predictions["generated_texts"]
    return generated_texts


query_response = query_endpoint_with_json_payload(model_predictor, payload)
generated_texts = parse_response_multiple_texts(query_response)
print(
    f"Input text: {payload['text_inputs']}{newline}"
    f"Generated texts: {bold}{generated_texts}{unbold}{newline}"
)
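
To make the sampling parameters in the payload above more concrete, the following is a minimal, self-contained sketch of how temperature, top_k, and top_p can narrow a next-word distribution before a word is sampled. It is an illustration under simplifying assumptions, not the code that runs inside the model container: the function name filter_next_word_probs and the toy logits are hypothetical, and temperature must be positive.

```python
import math
import random


def filter_next_word_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {word: probability} after temperature scaling and top-k / top-p filtering."""
    # Temperature-scaled softmax (subtract the maximum for numerical stability).
    scaled = {word: logit / temperature for word, logit in logits.items()}
    peak = max(scaled.values())
    weights = {word: math.exp(value - peak) for word, value in scaled.items()}
    total = sum(weights.values())
    ranked = sorted(((w / total, word) for word, w in weights.items()), reverse=True)

    # top_k: keep only the k most likely words, then renormalise.
    if top_k > 0:
        ranked = ranked[:top_k]
        kept_mass = sum(prob for prob, _ in ranked)
        ranked = [(prob / kept_mass, word) for prob, word in ranked]

    # top_p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for prob, word in ranked:
        kept.append((prob, word))
        cumulative += prob
        if cumulative >= top_p:
            break

    # Renormalise so the surviving probabilities sum to 1.
    norm = sum(prob for prob, _ in kept)
    return {word: prob / norm for prob, word in kept}


toy_logits = {"the": 2.0, "a": 1.0, "cat": 0.0, "dog": -1.0}  # made-up scores
probs = filter_next_word_probs(toy_logits, temperature=1.0, top_k=3, top_p=0.9)
print(sorted(probs))  # ['a', 'the']

# Sampling (do_sample=True) draws from the filtered distribution;
# a fixed seed makes the draw reproducible.
rng = random.Random(0)
next_word = rng.choices(list(probs), weights=list(probs.values()))[0]
print(next_word)
```

With a temperature close to 0, almost all probability mass moves to the highest-scoring word, so the filtered set shrinks to that single word and the result matches greedy decoding.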

6. Clean up the endpoint

[ ]:
# Delete the SageMaker model and endpoint
model_predictor.delete_model()
model_predictor.delete_endpoint()
