SageMaker JumpStart Foundation Models - GPT-J, GPT-Neo Few-shot learning


This notebook’s CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.



  1. Set Up

  2. Select a model

  3. Retrieve Artifacts & Deploy an Endpoint

  4. Query endpoint and parse response

  5. Advanced features: How to use various parameters to control the generated text

  6. Advanced features: How to use prompt engineering to solve different tasks

  7. Clean up the endpoint

Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel, and on an Amazon SageMaker Notebook instance with the conda_python3 kernel.

1. Set Up

[ ]:
!pip install ipywidgets==7.0.0 --quiet
!pip install --upgrade sagemaker --quiet

Permissions and environment variables

[ ]:
import sagemaker, boto3, json
from sagemaker.session import Session

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

2. Select a pre-trained model


You can continue with the default model or choose a different model from the dropdown generated when you run the next cell. A complete list of SageMaker pre-trained models is also available at SageMaker pre-trained Models.

[ ]:
model_id, model_version = (
    "huggingface-textgeneration1-gpt-j-6b",
    "1.*",
)

[Optional] Select a different SageMaker pre-trained model. Here, we retrieve the model manifest from the Built-In Algorithms S3 bucket, keep only the Text Generation models, and select a model for inference.

The following sections use GPT-J-6B as an example. You are welcome to try Bloom 7b1, Bloom 3b, GPT-NEO-2b7, and many other models yourself. To do so, change model_id by selecting a different entry in the dropdown list below.

[ ]:
from ipywidgets import Dropdown
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

# Retrieve all Text Generation models available through SageMaker Built-In Algorithms.
filter_value = "task == textgeneration1"
text_generation_models = list_jumpstart_models(filter=filter_value)

# display the model-ids in a dropdown to select a model for inference.
model_dropdown = Dropdown(
    options=text_generation_models,
    value=model_id,
    description="Select a model",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)

Choose a model for Inference

[ ]:
display(model_dropdown)
[ ]:
# model_version="1.*" fetches the latest 1.x version of the selected model
model_id, model_version = model_dropdown.value, "1.*"

3. Retrieve Artifacts & Deploy an Endpoint


Using SageMaker, we can perform inference on the pre-trained model without first fine-tuning it on a new dataset. We start by retrieving the deploy_image_uri and model_uri for the pre-trained model. To host the pre-trained model, we create an instance of sagemaker.model.Model (https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take a few minutes.


[ ]:
model_id
[ ]:
from sagemaker import image_uris, model_uris
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base


endpoint_name = name_from_base(f"jumpstart-example-{model_id}")

inference_instance_type = "ml.g5.12xlarge"

# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.

deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)

# Retrieve the model uri.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)


model = Model(
    image_uri=deploy_image_uri,
    model_data=model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)

# Deploy the model. Because we deploy through the generic Model class, we pass the Predictor class
# so that we can run inference on the endpoint through the SageMaker Python SDK.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
)

4. Query endpoint and parse response

[ ]:
newline, bold, unbold = "\n", "\033[1m", "\033[0m"


def query_endpoint(encoded_text, endpoint_name):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/x-text", Body=encoded_text
    )
    return response


def parse_response(query_response):
    model_predictions = json.loads(query_response["Body"].read())
    generated_text = model_predictions[0]["generated_text"]
    return generated_text
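parse_response assumes the endpoint returns a JSON list whose first element carries a "generated_text" field. The sketch below exercises that logic locally, without an endpoint, by substituting io.BytesIO for the botocore StreamingBody in the invoke_endpoint response. The fake response content is hypothetical, not real model output. The sketch also shows what json.dumps(text).encode("utf-8") actually sends with the application/x-text content type.

```python
import io
import json


def parse_response(query_response):
    # Same logic as the notebook's parse_response: read the body stream,
    # decode the JSON, and return the first prediction's generated text.
    model_predictions = json.loads(query_response["Body"].read())
    return model_predictions[0]["generated_text"]


# The request body sent by query_endpoint is the JSON-encoded string, quotes included.
encoded_text = json.dumps("I will").encode("utf-8")
print(encoded_text)  # b'"I will"'

# Hypothetical stand-in for an invoke_endpoint response. io.BytesIO provides the
# same .read() method as the real StreamingBody; the text below is made up.
fake_body = json.dumps([{"generated_text": "I will see you tomorrow."}]).encode("utf-8")
fake_response = {"Body": io.BytesIO(fake_body)}
print(parse_response(fake_response))  # I will see you tomorrow.
```

Like the real StreamingBody, the stand-in can be read only once, so parse each response a single time.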
[ ]:
newline, bold, unbold = "\n", "\033[1m", "\033[0m"

text1 = "I will"
text2 = "The movie is so funny"


for text in [text1, text2]:
    query_response = query_endpoint(json.dumps(text).encode("utf-8"), endpoint_name=endpoint_name)
    generated_text = parse_response(query_response)
    print(
        f"Inference:{newline}"
        f"input text: {text}{newline}"
        f"generated text: {bold}{generated_text}{unbold}{newline}"
    )

5. Advanced features: How to use various advanced parameters to control the generated text


This model also supports many advanced parameters while performing inference. They include:

  • max_length: The model generates text until the output length (which includes the input context length) reaches max_length. If specified, it must be a positive integer.

  • num_return_sequences: Number of output sequences returned. If specified, it must be a positive integer.

  • num_beams: Number of beams used in beam search. If specified, it must be an integer greater than or equal to num_return_sequences.

  • no_repeat_ngram_size: The model ensures that no sequence of no_repeat_ngram_size words is repeated in the output sequence. If specified, it must be a positive integer greater than 1.

  • temperature: Controls the randomness of the output. A higher temperature makes low-probability words more likely to appear in the output, and a lower temperature favors high-probability words. As temperature approaches 0, decoding approaches greedy decoding. If specified, it must be a positive float.

  • early_stopping: If True, text generation finishes when all beam hypotheses reach the end-of-sequence token. If specified, it must be a boolean.

  • do_sample: If True, sample the next word according to its likelihood. If specified, it must be a boolean.

  • top_k: In each step of text generation, sample from only the top_k most likely words. If specified, it must be a positive integer.

  • top_p: In each step of text generation, sample from the smallest possible set of words whose cumulative probability is at least top_p. If specified, it must be a float between 0 and 1.

  • seed: Fix the randomized state for reproducibility. If specified, it must be an integer.
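To build intuition for how temperature, top_k, top_p, and seed interact, here is a small pure-Python next-token sampler over a toy distribution. It is an illustrative sketch, not the endpoint's implementation: the token scores are made up, and the filtering order (temperature, then top_k, then top_p) is an assumption modeled on common Hugging Face-style decoding.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=None, top_p=None, seed=None):
    """Pick one token from a {token: logit} dict (illustrative sketch only)."""
    if temperature <= 0:
        raise ValueError("temperature must be a positive float")
    # Temperature: divide the logits, then apply a numerically stable softmax.
    scaled = {tok: logit / temperature for tok, logit in logits.items()}
    max_logit = max(scaled.values())
    exps = {tok: math.exp(v - max_logit) for tok, v in scaled.items()}
    total = sum(exps.values())
    # Sort tokens from most to least likely.
    ranked = sorted(((tok, e / total) for tok, e in exps.items()),
                    key=lambda pair: pair[1], reverse=True)
    # top_k: keep only the k most likely tokens.
    if top_k is not None:
        ranked = ranked[:top_k]
    # top_p: keep the smallest prefix whose cumulative probability reaches top_p.
    if top_p is not None:
        kept, cumulative = [], 0.0
        for tok, prob in ranked:
            kept.append((tok, prob))
            cumulative += prob
            if cumulative >= top_p:
                break
        ranked = kept
    # seed: a dedicated RNG makes the draw reproducible.
    rng = random.Random(seed)
    tokens = [tok for tok, _ in ranked]
    weights = [prob for _, prob in ranked]
    # random.choices treats weights as relative, so no renormalization is needed.
    return rng.choices(tokens, weights=weights, k=1)[0]


toy_logits = {"city": 2.0, "town": 1.0, "village": 0.5, "banana": -3.0}
print(sample_next_token(toy_logits, top_k=1))           # always city
print(sample_next_token(toy_logits, top_p=0.01))        # also city
print(sample_next_token(toy_logits, top_k=2, seed=42))  # city or town
```

A very small top_p (such as 0.01) leaves only the single most likely token, which behaves almost like greedy decoding. In Hugging Face-style generation, top_k and top_p are typically ignored when do_sample is False.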

We can specify any subset of these parameters when invoking the endpoint. Next, we show an example of how to invoke the endpoint with some of these arguments.


[ ]:
# The payload must be JSON: text_inputs holds the input text, and the remaining keys are generation parameters.
payload = {
    "text_inputs": ["I like living in New York"],
    "max_length": 50,
    "num_return_sequences": 1,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
    "num_beams": 3,
}


def query_endpoint_with_json_payload(encoded_json, endpoint_name):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json
    )
    return response


def parse_response_multiple_texts(query_response):
    # Assumes the response JSON is a list whose first element is a list of
    # {"generated_text": ...} dicts, one per returned sequence.
    model_predictions = json.loads(query_response["Body"].read())
    return [x["generated_text"] for x in model_predictions[0]]


query_response = query_endpoint_with_json_payload(
    json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
)

generated_texts = parse_response_multiple_texts(query_response)
print(generated_texts)
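Before sending a payload like the one above, you may want to check it locally against the constraints listed earlier, so that mistakes surface before an endpoint call. The helper below is hypothetical; it encodes only the documented rules, treats the top_p bounds as inclusive (an assumption), and the endpoint may validate differently.

```python
def _is_int(value):
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(value, int) and not isinstance(value, bool)


def _is_number(value):
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def validate_parameters(params):
    """Return a list of problems with a payload's generation parameters (hypothetical helper)."""
    errors = []
    for name in ("max_length", "num_return_sequences", "top_k"):
        if name in params and not (_is_int(params[name]) and params[name] > 0):
            errors.append(f"{name} must be a positive integer")
    if "no_repeat_ngram_size" in params:
        value = params["no_repeat_ngram_size"]
        if not (_is_int(value) and value > 1):
            errors.append("no_repeat_ngram_size must be an integer greater than 1")
    if "num_beams" in params:
        num_return = params.get("num_return_sequences", 1)
        min_beams = num_return if _is_int(num_return) and num_return > 0 else 1
        if not (_is_int(params["num_beams"]) and params["num_beams"] >= min_beams):
            errors.append("num_beams must be an integer >= num_return_sequences")
    if "temperature" in params:
        if not (_is_number(params["temperature"]) and params["temperature"] > 0):
            errors.append("temperature must be a positive float")
    if "top_p" in params:
        if not (_is_number(params["top_p"]) and 0 <= params["top_p"] <= 1):
            errors.append("top_p must be a float between 0 and 1")
    for name in ("early_stopping", "do_sample"):
        if name in params and not isinstance(params[name], bool):
            errors.append(f"{name} must be a boolean")
    if "seed" in params and not _is_int(params["seed"]):
        errors.append("seed must be an integer")
    return errors


example = {"text_inputs": ["I like living in New York"], "max_length": 50,
           "num_return_sequences": 1, "top_k": 50, "top_p": 0.95,
           "do_sample": True, "num_beams": 3}
print(validate_parameters(example))  # []
print(validate_parameters({"num_return_sequences": 3, "num_beams": 2}))
```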

6. Advanced features: How to use prompt engineering to solve different few-shot NLP tasks

Note: The following sections use GPT-J-6B as an example. You are welcome to try Bloom 7b1, Bloom 3b, GPT-NEO-2b7, and many other models yourself by changing model_id in Section 2.
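The prompts in this section share one few-shot pattern: a few solved examples separated by ###, followed by the new input and an empty output label for the model to complete. A minimal sketch of a helper that assembles such prompts; the function name and signature are hypothetical, and unlike the triple-quoted prompts in the cells that follow, it does not indent the lines.

```python
def build_few_shot_prompt(examples, query, input_label, output_label,
                          separator="###", header=""):
    """Assemble a few-shot prompt from (input, output) example pairs."""
    blocks = [f"{input_label}: {inp}\n{output_label}: {out}" for inp, out in examples]
    # The last block leaves the output label empty so the model fills it in.
    blocks.append(f"{input_label}: {query}\n{output_label}:")
    body = f"\n{separator}\n".join(blocks)
    # An optional instruction line goes before the examples.
    return f"{header}\n\n{body}" if header else body


prompt = build_few_shot_prompt(
    examples=[("shoes, women, $59", "Beautiful shoes for women at the price of $59."),
              ("trousers, men, $69", "Modern trousers for men, for $69 only.")],
    query="t-shirt, men, $39",
    input_label="Keywords",
    output_label="Sentence",
    header="Generate a product description out of keywords.",
)
print(prompt)
```

The resulting string can be sent as text_inputs in the same JSON payload used throughout this section.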

6.1. Summarization

Define the text article you want to summarize.

[ ]:
text = """Amazon Comprehend uses natural language processing (NLP) to extract insights about the content of documents. It develops insights by recognizing the entities, key phrases, language, sentiments, and other common elements in a document. Use Amazon Comprehend to create new products based on understanding the structure of documents. For example, using Amazon Comprehend you can search social networking feeds for mentions of products or scan an entire document repository for key phrases.
You can access Amazon Comprehend document analysis capabilities using the Amazon Comprehend console or using the Amazon Comprehend APIs. You can run real-time analysis for small workloads or you can start asynchronous analysis jobs for large document sets. You can use the pre-trained models that Amazon Comprehend provides, or you can train your own custom models for classification and entity recognition. """
[ ]:
prompts = [
    """[Original]: Amazon scientists, in collaboration with researchers from the University of Sheffield, are making a large-scale fact extraction and verification dataset publicly available for the first time. The dataset, comprising more than 185,000 evidence-backed claims, is being made available to hopefully catalyze research and development that addresses the problems of fact extraction and verification in software applications or cloud-based services that perform automatic information extraction.
    [Summary]: Amazon and University researchers make fact extraction and verification dataset publicly available.
    ###
    [Original]: Prime members in the U.S. can get even more delivered to their door with a Prime membership. Members can now enjoy one year of Grubhub+ valued at $9.99 per month for free—at no added cost to their Prime membership. To activate this deal, visit amazon.com/grubhub. This new offer includes unlimited, $0 food delivery fees on orders over $12 as well as exclusive perks for Grubhub+ members and rewards like free food and order discounts. Plus, diners can “eat good while doing good” by opting into Grubhub’s Donate the Change program, a donation-matching initiative that raised more than $25 million in 2021 alone, benefiting more than 20 charitable organizations across the country.
    [Summary]: Prime members in the U.S. can enjoy one year of Grubhub+ for free, with no food-delivery fees on eligible orders.
    ###
    [Original]: {text}
    [Summary]:"""
]

num_return_sequences = 1
parameters = {
    "max_length": 600,
    "num_return_sequences": num_return_sequences,
    "top_p": 0.01,
    "do_sample": False,
}

print(f"{bold}Number of return sequences is set to {num_return_sequences}{unbold}{newline}")
for each_prompt in prompts:
    payload = {"text_inputs": each_prompt.replace("{text}", text), **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} For prompt: '{each_prompt}'{unbold}{newline}")
    print(f"{bold} The {num_return_sequences} summarized results are{unbold}:{newline}")
    for idx, each_generated_text in enumerate(generated_texts):
        print(f"{bold}Result {idx}{unbold}: {each_generated_text}{newline}")
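Few-shot completions often run past the answer: after finishing the summary, the model may start inventing another ### example. A minimal post-processing sketch with a hypothetical helper. Whether generated_text echoes the prompt depends on the model container, so the prompt-stripping step is a guarded assumption.

```python
def extract_completion(generated_text, prompt, separator="###"):
    """Keep only the model's answer to the final few-shot example."""
    # Assumption: some containers return the prompt followed by the continuation.
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]
    # Drop anything after the first separator, i.e. an invented next example.
    return generated_text.split(separator, 1)[0].strip()


demo_prompt = "[Original]: Some article.\n[Summary]:"
echoed = demo_prompt + " A short summary.\n###\n[Original]: invented text"
print(extract_completion(echoed, demo_prompt))  # A short summary.
```

Inside the loop above, this would be called as extract_completion(each_generated_text, each_prompt.replace("{text}", text)). Stripping the prompt first matters, because the few-shot prompt itself contains ### separators.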

6.2. Code generation

[ ]:
description = "a Headline saying Welcome to AI"
[ ]:
prompts = [
    """description: an orange button that says stop
    code: <button style="color:white; background-color:orange;">Stop</button>
    ###
    description: a blue box that contains yellow circles with red borders
    code: <div style="background-color: blue; padding: 20px;"><div style="background-color: yellow; border: 5px solid red; border-radius: 50%; padding: 20px; width: 100px; height: 100px;"></div></div>
    ###
    description: {description}
    code:"""
]


parameters = {
    "max_length": 200,
    "num_return_sequences": 1,
    "top_k": 50,
    "top_p": 0.3,
    "do_sample": False,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{description}", description)
    print(f"{bold} For prompt{unbold}: '{each_prompt}'{newline}")

    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")
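Generated HTML is easy to sanity-check before rendering it. The sketch below uses Python's standard html.parser to report tags left open or closed out of order. The class and function names are hypothetical, and the check covers tag balance only, not attribute syntax or overall validity.

```python
from html.parser import HTMLParser

# Elements that never take a closing tag in HTML.
VOID_ELEMENTS = {"area", "base", "br", "col", "embed", "hr", "img",
                 "input", "link", "meta", "source", "track", "wbr"}


class TagBalanceChecker(HTMLParser):
    """Track open tags and record closing tags that do not match."""

    def __init__(self):
        super().__init__()
        self.open_tags = []
        self.errors = []

    def handle_starttag(self, tag, attrs):
        if tag not in VOID_ELEMENTS:
            self.open_tags.append(tag)

    def handle_startendtag(self, tag, attrs):
        # Self-closing tags such as <br/> open and close themselves.
        pass

    def handle_endtag(self, tag):
        if self.open_tags and self.open_tags[-1] == tag:
            self.open_tags.pop()
        else:
            self.errors.append(f"unexpected </{tag}>")


def check_html(snippet):
    """Return (tags left open, mismatched closing tags) for an HTML snippet."""
    checker = TagBalanceChecker()
    checker.feed(snippet)
    checker.close()
    return checker.open_tags, checker.errors


print(check_html("<h1>Welcome to AI</h1>"))          # ([], [])
print(check_html('<div style="a"><div style="b">'))  # (['div', 'div'], [])
```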

How about SQL code?

[ ]:
question = "Fetch three employees from the Employee table"
[ ]:
prompts = [
    """Question: Fetch the companies that have fewer than five employees.
    Answer: SELECT COMPANY, COUNT(EMPLOYEE_ID) FROM Employee GROUP BY COMPANY HAVING COUNT(EMPLOYEE_ID) < 5;
    ###
    Question: Show all companies along with the number of employees in each company
    Answer: SELECT COMPANY, COUNT(COMPANY) FROM Employee GROUP BY COMPANY;
    ###
    Question: Show the last record of the Employee table
    Answer: SELECT * FROM Employee ORDER BY LAST_NAME DESC LIMIT 1;
    ###
    Question: {question}
    Answer:"""
]


parameters = {
    "max_length": 200,
    "num_return_sequences": 1,
    "top_p": 0.01,
    "do_sample": False,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{question}", question)
    print(f"{bold} For prompt{unbold}: '{each_prompt}'{newline}")

    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")
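A cheap way to catch malformed SQL from the model is to execute it against a small throwaway table. The sketch uses Python's built-in sqlite3 with a hypothetical Employee schema (EMPLOYEE_ID, LAST_NAME, COMPANY) inferred from the few-shot examples, filled with made-up rows. SQLite's dialect may differ from your production database, so a query that runs here is not guaranteed to run elsewhere.

```python
import sqlite3

# Made-up sample rows: (EMPLOYEE_ID, LAST_NAME, COMPANY).
SAMPLE_ROWS = [
    (1, "Smith", "Acme"), (2, "Jones", "Acme"), (3, "Brown", "Globex"),
    (4, "Lee", "Initech"), (5, "Khan", "Acme"),
]


def run_on_sample_table(sql):
    """Execute one SQL statement against an in-memory Employee table."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE Employee "
                     "(EMPLOYEE_ID INTEGER PRIMARY KEY, LAST_NAME TEXT, COMPANY TEXT)")
        conn.executemany("INSERT INTO Employee VALUES (?, ?, ?)", SAMPLE_ROWS)
        # Raises sqlite3.OperationalError for syntax errors or unknown columns.
        return conn.execute(sql).fetchall()
    finally:
        conn.close()


print(run_on_sample_table("SELECT * FROM Employee LIMIT 3;"))
try:
    run_on_sample_table("SELECT NAME FROM Employee;")
except sqlite3.OperationalError as err:
    print("Rejected:", err)
```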

6.3. Named entity recognition (NER)

[ ]:
sentence = "David Melvin is an investment and financial services professional at CITIC CLSA with over 30 years’ experience in investment banking and private equity. He is currently a Senior Adviser of CITIC CLSA."
[ ]:
prompts = [
    """[Text]: Fred is a serial entrepreneur. Co-founder and CEO of Platform.sh, he previously co-founded Commerce Guys, a leading Drupal ecommerce provider. His mission is to guarantee that as we continue on an ambitious journey to profoundly transform how cloud computing is used and perceived, we keep our feet well on the ground continuing the rapid growth we have enjoyed up until now.
    [Name]: Fred
    [Position]: Co-founder and CEO
    [Company]: Platform.sh
    ###
    [Text]: Microsoft (the word being a portmanteau of "microcomputer software") was founded by Bill Gates on April 4, 1975, to develop and sell BASIC interpreters for the Altair 8800. Steve Ballmer replaced Gates as CEO in 2000, and later envisioned a "devices and services" strategy.
    [Name]: Steve Ballmer
    [Position]: CEO
    [Company]: Microsoft
    ###
    [Text]: Franck Riboud was born on 7 November 1955 in Lyon. He is the son of Antoine Riboud, the previous CEO, who transformed the former European glassmaker BSN Group into a leading player in the food industry. He is the CEO at Danone.
    [Name]: Franck Riboud
    [Position]: CEO
    [Company]: Danone
    ###
    [Text]: {sentence}
    """
]


parameters = {
    "max_length": 550,
    "num_return_sequences": 1,
    "top_p": 0.1,
    "do_sample": True,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{sentence}", sentence)
    print(f"{bold} For prompt{unbold}: '{each_prompt}'{newline}")

    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")
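The NER prompt asks the model to answer in the same [Name]/[Position]/[Company] layout as its examples, so the completion can be turned into a dictionary. A minimal sketch with a hypothetical helper and a made-up sample completion. Run it on the completion only (with the prompt removed), because the echoed examples contain the same labels.

```python
def parse_entities(completion, fields=("Name", "Position", "Company")):
    """Collect "[Field]: value" lines into a dict, stopping at the next example."""
    entities = {}
    for raw_line in completion.splitlines():
        line = raw_line.strip()
        if line.startswith("###"):
            break  # the model has started inventing another example
        for field in fields:
            prefix = f"[{field}]:"
            if line.startswith(prefix) and field not in entities:
                entities[field] = line[len(prefix):].strip()
    return entities


sample_completion = "[Name]: Jane Doe\n[Position]: CTO\n[Company]: ExampleCorp\n###\n[Text]: ..."
print(parse_entities(sample_completion))
```

Missing keys in the result indicate that the model did not follow the expected format.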

6.4. Question answering

[ ]:
question = "Which plan is recommended for GPT-J?"
[ ]:
prompts = [
    """Context: NLP Cloud was founded in 2021 when the team realized there was no easy way to reliably leverage Natural Language Processing in production.
    Question: When was NLP Cloud founded?
    Answer: 2021
    ###
    Context: NLP Cloud developed their API by mid-2020 and they added many pre-trained open-source models since then.
    Question: What did NLP Cloud develop?
    Answer: API
    ###
    Context: All plans can be stopped anytime. You only pay for the time you used the service. In case of a downgrade, you will get a discount on your next invoice.
    Question: When can plans be stopped?
    Answer: Anytime
    ###
    Context: The main challenge with GPT-J is memory consumption. Using a GPU plan is recommended.
    Question: {question}
    Answer:"""
]


parameters = {
    "max_length": 350,
    "num_return_sequences": 1,
    "top_p": 0.1,
    "do_sample": True,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{question}", question)
    print(f"{bold} For prompt{unbold}: '{each_prompt}'{newline}")

    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")

6.5. Grammar and spelling correction

[ ]:
sentence = "I do not wan to go"
[ ]:
prompts = [
    """I love goin to the beach.
    Correction: I love going to the beach.
    ###
    Let me hav it!
    Correction: Let me have it!
    ###
    It have too many drawbacks.
    Correction: It has too many drawbacks.
    ###
    {sentence}
    Correction:"""
]


parameters = {
    "max_length": 250,
    "num_return_sequences": 1,
    "top_p": 0.1,
    "do_sample": True,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{sentence}", sentence)
    print(f"{bold} For prompt{unbold}: '{each_prompt}'{newline}")

    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")
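To see exactly what the model changed, compare the input and the corrected sentence word by word. The sketch uses the standard difflib.SequenceMatcher on word lists; the helper name is hypothetical, and the corrected sentence in the example is written by hand, not taken from model output.

```python
import difflib


def word_changes(original, corrected):
    """Return (old, new) pairs for every word-level edit between two sentences."""
    old_words, new_words = original.split(), corrected.split()
    matcher = difflib.SequenceMatcher(a=old_words, b=new_words)
    changes = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != "equal":  # "replace", "delete", or "insert"
            changes.append((" ".join(old_words[i1:i2]), " ".join(new_words[j1:j2])))
    return changes


print(word_changes("I do not wan to go", "I do not want to go"))  # [('wan', 'want')]
```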

6.6. Product description generation

[ ]:
sentence = """t-shirt, men, $39"""
[ ]:
prompts = [
    """Generate a product description out of keywords.

    Keywords: shoes, women, $59
    Sentence: Beautiful shoes for women at the price of $59.
    ###
    Keywords: trousers, men, $69
    Sentence: Modern trousers for men, for $69 only.
    ###
    Keywords: gloves, winter, $19
    Sentence: Amazingly hot gloves for cold winters, at $19.
    ###
    Keywords: {sentence}
    Sentence:"""
]


parameters = {
    "max_length": 150,
    "num_return_sequences": 1,
    "do_sample": False,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{sentence}", sentence)
    print(f"{bold} For prompt{unbold}: '{each_prompt}'{newline}")

    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")

6.7. Sentence / Topic Classification

Define the sentence you want to classify.

[ ]:
sentence = "I am trying to cook chicken with tomatoes."
[ ]:
prompts = [
    """Message: When the spaceship landed on Mars, all of humanity was excited
    Topic: space
    ###
    Message: I love playing tennis and golf. I'm practicing twice a week.
    Topic: sport
    ###
    Message: Managing a team of sales people is a tough but rewarding job.
    Topic: business
    ###
    Message: {sentence}
    Topic:"""
]

parameters = {
    "max_length": 150,  # max_length counts the prompt tokens too, so it must exceed the prompt length
    "num_return_sequences": 1,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{sentence}", sentence)
    print(f"{bold} For prompt{unbold}: '{input_text}'{newline}")
    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")
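The prompt leaves the set of topics open, so the model may answer with any word. If you want to restrict the output to a fixed label set, you can normalize the completion and reject anything else. A sketch; the helper and the label list are hypothetical.

```python
def normalize_label(completion, labels):
    """Map a free-text completion onto one of `labels`, or return None."""
    # Use the first non-empty line; later lines may be an invented example.
    first_line = next((line.strip() for line in completion.splitlines() if line.strip()), "")
    candidate = first_line.split("###", 1)[0].strip().rstrip(".").lower()
    for label in labels:
        if candidate == label.lower():
            return label
    return None


topics = ["space", "sport", "business", "cooking"]
print(normalize_label(" Cooking\n###\nMessage: ...", topics))  # cooking
print(normalize_label(" gardening", topics))                   # None
```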

6.8. Chatbot and Conversational AI

Define the message you want the chatbot to respond to.

[ ]:
sentence = "I caught flu"
[ ]:
prompts = [
    """This is a discussion between a [human] and a [robot].
    The [robot] is very nice and empathetic.

    [human]: Hello nice to meet you.
    [robot]: Nice to meet you too.
    ###
    [human]: How is it going today?
    [robot]: Not so bad, thank you! How about you?
    ###
    [human]: I am ok, but I am a bit sad...
    [robot]: Oh? Why that?
    ###
    [human]: {sentence}
    [robot]:"""
]

parameters = {
    "max_length": 250,
    "num_return_sequences": 1,
    "do_sample": False,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{sentence}", sentence)
    print(f"{bold} For prompt{unbold}: '{input_text}'{newline}")
    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The chatbot response is{unbold}: '{generated_texts}'{newline}")
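The chatbot prompt holds a single user turn. To carry a conversation forward, you can rebuild the prompt from the running history on each turn. A minimal sketch with a hypothetical helper. Because max_length counts the prompt tokens too, the helper can optionally keep only the most recent turns.

```python
def build_chat_prompt(header, history, user_message, max_turns=None):
    """Rebuild the [human]/[robot] prompt from earlier (human, robot) turns."""
    if max_turns is not None:
        # Keep only the latest turns so the prompt stays within max_length.
        history = history[-max_turns:] if max_turns > 0 else []
    turns = [f"[human]: {human}\n[robot]: {robot}" for human, robot in history]
    # The final turn ends with an empty [robot]: cue for the model to complete.
    turns.append(f"[human]: {user_message}\n[robot]:")
    return header + "\n\n" + "\n###\n".join(turns)


header = ("This is a discussion between a [human] and a [robot].\n"
          "The [robot] is very nice and empathetic.")
history = [("Hello nice to meet you.", "Nice to meet you too."),
           ("How is it going today?", "Not so bad, thank you! How about you?")]
print(build_chat_prompt(header, history, "I caught flu", max_turns=1))
```

After each endpoint call, append (user_message, reply) to history so that the next prompt includes the new turn.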

6.9. Tweet generation

Define the keyword you want the tweet to be about.

[ ]:
keyword = "nlp"
[ ]:
prompts = [
    """keyword: markets
    tweet: Take feedback from nature and markets, not from people
    ###
    keyword: children
    tweet: Maybe we die so we can come back as children.
    ###
    keyword: startups
    tweet: Startups should not worry about how to put out fires, they should worry about how to start them.
    ###
    keyword: {keyword}
    tweet:"""
]

parameters = {
    "max_length": 300,
    "num_return_sequences": 1,
    "do_sample": False,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{keyword}", keyword)
    print(f"{bold} For prompt{unbold}: '{input_text}'{newline}")
    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")

6.10. Machine translation

[ ]:
sentence = "NLP Cloud permet de deployer le NLP en production facilement."
[ ]:
prompts = [
    """Hugging Face a révolutionné le NLP.
    Translation: Hugging Face revolutionized NLP.
    ###
    Cela est incroyable!
    Translation: This is unbelievable!
    ###
    Désolé je ne peux pas.
    Translation: Sorry but I cannot.
    ###
    {sentence}
    Translation:"""
]

parameters = {
    "max_length": 150,
    "num_return_sequences": 1,
    "do_sample": False,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{sentence}", sentence)
    print(f"{bold} For prompt{unbold}: '{input_text}'{newline}")
    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")

6.11. Paraphrasing

[ ]:
sentence = "What is the best way to learn English?"
[ ]:
prompts = [
    """[Original]: Can you recommend some nice restaurants in New York?
[Paraphrase]: List some excellent restaurants to visit in New York City.
###
[Original]: Which course should I take to get started in data science?
[Paraphrase]: What should I learn to become a data scientist?
###
[Original]: What are the famous places we should not miss in the United States?
[Paraphrase]: Recommend some of the best places to visit in the United States.
###
[Original]: {sentence}
[Paraphrase]:"""
]


parameters = {
    "max_length": 150,
    "num_return_sequences": 1,
    "top_p": 0.5,
    "do_sample": True,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{sentence}", sentence)
    print(f"{bold} For prompt{unbold}: '{input_text}'{newline}")
    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")

6.12. Intent classification

[ ]:
sentence = "Can you please teach me Chinese next week?"
[ ]:
prompts = [
    """I want to start coding tomorrow because it seems to be so fun!
    Intent: start coding
    ###
    Show me the last pictures you have please.
    Intent: show pictures
    ###
    Search all these files as fast as possible.
    Intent: search files
    ###
    {sentence}
    Intent:"""
]


parameters = {
    "max_length": 150,
    "num_return_sequences": 1,
    "top_p": 0.5,
    "do_sample": True,
}


for each_prompt in prompts:
    input_text = each_prompt.replace("{sentence}", sentence)
    print(f"{bold} For prompt{unbold}: '{input_text}'{newline}")
    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")

7. Clean up the endpoint

[ ]:
# Delete the SageMaker model and endpoint
model_predictor.delete_model()
model_predictor.delete_endpoint()

For domain adaptation fine-tuning of text generation models such as GPT-J 6B, see the notebook domain-adaption-finetuning-gpt-j-6b.ipynb in the same directory.

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.
