SageMaker JumpStart Foundation Models - Chatbots




  1. Set up

  2. Select a pre-trained model

  3. Retrieve artifacts & deploy an endpoint

  4. Query endpoint and parse response

  5. Use a shell interpreter to interact with your deployed endpoint

  6. Clean up the endpoint

Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and in an Amazon SageMaker Notebook instance with the conda_python3 kernel.

1. Set up

Before running the notebook, install or upgrade the sagemaker and ipywidgets packages.


[ ]:
%pip install sagemaker ipywidgets --upgrade --quiet

2. Select a pre-trained model


You can continue with the default model or choose a different model from the dropdown generated by running the cells below. A complete list of SageMaker pre-trained models can also be accessed at SageMaker pre-trained Models.

[ ]:
from typing import Any, Dict, NamedTuple


class JumpStartChatbotModelConfig(NamedTuple):
    model_id: str
    model_kwargs: Dict[str, Any] = {}
    payload_kwargs: Dict[str, Any] = {}


jumpstart_chatbot_models_config = [
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration-falcon-7b-instruct-bf16",
        payload_kwargs={"return_full_text": True},
    ),
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration-falcon-40b-instruct-bf16",
        payload_kwargs={"return_full_text": True},
    ),
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration1-redpajama-incite-chat-3B-v1-fp16",
    ),
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration1-redpajama-incite-chat-7B-v1-fp16",
    ),
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16",
    ),
]
[ ]:
from IPython.display import Markdown, display
from ipywidgets import Dropdown


dropdown = Dropdown(
    options=[(config.model_id, config) for config in jumpstart_chatbot_models_config],
    value=jumpstart_chatbot_models_config[0],
    description="JumpStart Chatbot Models:",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)
display(Markdown("### Select a JumpStart chatbot model from the dropdown below"))
display(dropdown)
[ ]:
model_config = dropdown.value
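
The payload_kwargs stored on each model configuration are merged over a set of notebook-wide defaults when the chatbot shell is launched later in this notebook. The sketch below illustrates that merge order in isolation. ChatbotModelConfig and merge_payload_kwargs are hypothetical names introduced here so the example runs on its own, and the max_new_tokens override is an illustrative value only.

```python
from typing import Any, Dict, NamedTuple


class ChatbotModelConfig(NamedTuple):
    # Hypothetical stand-in for JumpStartChatbotModelConfig so this sketch is self-contained.
    model_id: str
    payload_kwargs: Dict[str, Any] = {}


def merge_payload_kwargs(defaults: Dict[str, Any], config: ChatbotModelConfig) -> Dict[str, Any]:
    """Return a new dict in which model-specific values override the defaults."""
    # Later entries in a dict display win, so config.payload_kwargs takes precedence.
    return {**defaults, **config.payload_kwargs}


defaults = {"max_new_tokens": 128, "do_sample": True}
falcon = ChatbotModelConfig(
    model_id="huggingface-textgeneration-falcon-7b-instruct-bf16",
    # The max_new_tokens override is illustrative, not a recommended setting.
    payload_kwargs={"return_full_text": True, "max_new_tokens": 64},
)
merged = merge_payload_kwargs(defaults, falcon)
print(merged)
```

Because the merge builds a new dictionary, the defaults themselves are left unchanged and can be reused for another model.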

3. Retrieve artifacts & deploy an endpoint


Using SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by deploying a JumpStartModel to host the pre-trained model. This may take a few minutes.


[ ]:
from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.predictor import Predictor


model = JumpStartModel(
    model_id=model_config.model_id, model_version="1.*", **model_config.model_kwargs
)
predictor = model.deploy()

Next, we configure the SageMaker Predictor to use a JSON serializer and a custom deserializer that works with all chatbot models supported by this notebook. The deserializer ensures the chatbot always returns a single string: the first generated text sample for each query.

[ ]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer


class JumpStartChatbotDeserializer(JSONDeserializer):
    """A deserializer to retrieve the first generated text from JumpStart text generation models."""

    def deserialize(self, stream, content_type):
        """Crawl the output of JSON deserialization to obtain first generated text model response."""
        data = super().deserialize(stream, content_type)

        while True:
            if isinstance(data, str):
                break
            elif isinstance(data, list):
                data = data[0]
            elif isinstance(data, dict):
                for key in ("generated_text", "generated_texts"):
                    if key in data:
                        data = data[key]
                        break
                else:
                    raise ValueError(f"Generated text keys not found in output {data}.")
            else:
                raise ValueError(f"Output data contains unrecognized type {type(data)}.")

        return data


predictor.serializer = JSONSerializer()
predictor.deserializer = JumpStartChatbotDeserializer()
predictor.content_type = "application/json"
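
To see what the crawl does without a deployed endpoint, the same logic can be applied to already-parsed JSON. This is a minimal sketch: extract_first_generated_text is a hypothetical helper mirroring the deserializer above, and the sample response shapes are assumptions chosen for illustration, not captured endpoint output.

```python
import json
from typing import Any


def extract_first_generated_text(data: Any) -> str:
    """Descend through lists and dicts until the first generated text string is found."""
    while not isinstance(data, str):
        if isinstance(data, list):
            # Keep only the first sample when several are returned.
            data = data[0]
        elif isinstance(data, dict):
            for key in ("generated_text", "generated_texts"):
                if key in data:
                    data = data[key]
                    break
            else:
                raise ValueError(f"Generated text keys not found in output {data}.")
        else:
            raise ValueError(f"Output data contains unrecognized type {type(data)}.")
    return data


# Hypothetical response bodies with different nesting.
samples = [
    '[{"generated_text": "Hello there"}]',
    '{"generated_texts": ["Hello there", "Hi"]}',
    '[[{"generated_text": "Hello there"}]]',
]
results = [extract_first_generated_text(json.loads(sample)) for sample in samples]
print(results)
```

All three shapes reduce to the same string, which is why the shell later in this notebook can treat every supported model's response uniformly.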

4. Query endpoint and parse response


This model also supports many advanced parameters while performing inference. They include:

  • max_length: Model generates text until the output length (which includes the input context length) reaches max_length. If specified, it must be a positive integer.

  • max_time: The maximum amount of time, in seconds, that you allow the computation to run. Generation will still finish the current pass after the allocated time has elapsed. This setting can help to generate a response before the endpoint invocation times out.

  • num_return_sequences: Number of output sequences returned. If specified, it must be a positive integer.

  • num_beams: Number of beams used in beam search. If specified, it must be an integer greater than or equal to num_return_sequences.

  • no_repeat_ngram_size: Model ensures that a sequence of words of no_repeat_ngram_size is not repeated in the output sequence. If specified, it must be a positive integer greater than 1.

  • temperature: Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If temperature -> 0, it results in greedy decoding. If specified, it must be a positive float.

  • early_stopping: If True, text generation is finished when all beam hypotheses reach the end of sentence token. If specified, it must be boolean.

  • do_sample: If True, sample the next word as per the likelihood. If specified, it must be boolean.

  • top_k: In each step of text generation, sample from only the top_k most likely words. If specified, it must be a positive integer.

  • top_p: In each step of text generation, sample from the smallest possible set of words with cumulative probability top_p. If specified, it must be a float between 0 and 1.

  • seed: Fix the randomized state for reproducibility. If specified, it must be an integer.
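
The constraints listed above can be checked on the client before a request is sent. The following is a rough sketch under stated assumptions: check_generation_params is a hypothetical helper, not part of the SageMaker SDK; it accepts integers as well as floats for temperature and top_p; and it treats the bounds on top_p as inclusive. The endpoint may validate differently.

```python
from typing import Any, Dict, List


def _is_int(value: Any) -> bool:
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(value, int) and not isinstance(value, bool)


def _is_number(value: Any) -> bool:
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def check_generation_params(params: Dict[str, Any]) -> List[str]:
    """Return a list of human-readable problems; an empty list means no problem was found."""
    problems: List[str] = []

    for name in ("max_length", "num_return_sequences", "num_beams", "top_k"):
        if name in params and not (_is_int(params[name]) and params[name] > 0):
            problems.append(f"{name} must be a positive integer")

    if "no_repeat_ngram_size" in params:
        value = params["no_repeat_ngram_size"]
        if not (_is_int(value) and value > 1):
            problems.append("no_repeat_ngram_size must be an integer greater than 1")

    if "temperature" in params:
        value = params["temperature"]
        if not (_is_number(value) and value > 0):
            problems.append("temperature must be a positive number")

    if "top_p" in params:
        value = params["top_p"]
        # Assumption: both bounds are treated as inclusive.
        if not (_is_number(value) and 0 <= value <= 1):
            problems.append("top_p must be between 0 and 1")

    for name in ("early_stopping", "do_sample"):
        if name in params and not isinstance(params[name], bool):
            problems.append(f"{name} must be a boolean")

    if "seed" in params and not _is_int(params["seed"]):
        problems.append("seed must be an integer")

    beams, sequences = params.get("num_beams"), params.get("num_return_sequences")
    if _is_int(beams) and _is_int(sequences) and beams < sequences:
        problems.append("num_beams must be greater than or equal to num_return_sequences")

    return problems


print(check_generation_params({"max_length": 100, "top_k": 50, "top_p": 0.95, "do_sample": True}))
print(check_generation_params({"num_beams": 2, "num_return_sequences": 3, "top_p": 1.5}))
```

Returning a list of problems rather than raising on the first one makes it easier to report every invalid argument at once.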

We may specify any subset of the parameters mentioned above while invoking an endpoint. Next, we show an example of how to invoke the endpoint with these arguments.


[ ]:
payload = {
    "text_inputs": "<human>: Tell me the steps to make a pizza\n<bot>:",
    "max_length": 100,
    "max_time": 50,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
    "stopping_criteria": ["<human>"],
}
print(predictor.predict(payload))

Here, we provide the payload argument "stopping_criteria": ["<human>"], so the model response ends as soon as it generates the word sequence "<human>". The SageMaker JumpStart model script accepts any list of strings as stop words, converts this list to a valid stopping_criteria keyword argument (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.generate.stopping_criteria) for the transformers generate API, and terminates text generation when the output sequence contains any specified stop word. This is useful for two reasons: first, inference time is reduced because the endpoint does not continue to generate undesired text beyond the stop words; second, it prevents the chatbot model from hallucinating additional human and bot responses until other stop criteria are met.
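
The effect of a stop word on the final text can be pictured with a small client-side function. This is only an illustration of the resulting output shape: truncate_at_stop_word is a hypothetical helper, and the JumpStart model script applies its criterion during generation rather than by post-processing a finished string.

```python
from typing import List


def truncate_at_stop_word(text: str, stop_words: List[str]) -> str:
    """Cut text right after the stop word that completes earliest, if any occurs."""
    cut = None
    for word in stop_words:
        if not word:
            # An empty stop word would match everywhere, so ignore it.
            continue
        index = text.find(word)
        if index != -1:
            end = index + len(word)
            if cut is None or end < cut:
                cut = end
    return text if cut is None else text[:cut]


# Hypothetical unconstrained output in which the model invents a further exchange.
raw = " 1. Make the dough.\n<human>: Thanks!\n<bot>: You're welcome."
trimmed = truncate_at_stop_word(raw, ["<human>"])
print(repr(trimmed))
```

The trimmed text still ends with the stop word itself, which matches the behavior described above and is why the Conversation class later splits the response on the human tag.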

5. Use a shell interpreter to interact with your deployed endpoint


OpenChatKit provides a command line shell to interact with their chatbot. In the following code blocks, we provide a bare-bones simplification of the inference scripts in the OpenChatKit repository that can interact with our deployed SageMaker endpoint. There are two main components to this:

  1. A shell interpreter (JumpStartChatbotShell) that allows for iterative inference invocations of the model endpoint.

  2. A conversation object (Conversation) that stores previous human/chatbot interactions locally within the interactive shell and appropriately formats past conversations for future inference context.


[ ]:
import cmd
import re
from typing import List, Optional


class Conversation:
    MEANINGLESS_WORDS = ["<pad>", "</s>", "<|endoftext|>"]

    def __init__(self, human_id, bot_id):
        self.human_tag = f"{human_id}:"
        self.bot_tag = f"{bot_id}:"
        self.history = ""

    def clean_response(self, response):
        for word in self.MEANINGLESS_WORDS:
            response = response.replace(word, "")
        response = response.strip("\n")
        return response

    def push_human_turn(self, query):
        self.history += f"{self.human_tag} {query}\n{self.bot_tag}"

    def push_model_response(self, response):
        bot_turn = response.split(f"{self.human_tag}")[0]
        bot_turn = self.clean_response(bot_turn)
        self.history += f"{bot_turn}\n"

    def get_last_turn(self):
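        # Use a raw f-string so that \W is a regex class, not an invalid string escape.
        turns = re.split(rf"({re.escape(self.human_tag)}|{re.escape(self.bot_tag)})\W?", self.history)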
        return turns[-1]


class JumpStartChatbotShell(cmd.Cmd):
    intro = (
        "Welcome to the SageMaker JumpStart chatbot shell! Type /help or /? to list commands. "
        "Type /quit to exit shell.\n"
    )
    prompt = ">>> "
    response_prefix = "<<< "
    human_id = "<human>"
    bot_id = "<bot>"

    def __init__(self, predictor: Predictor, cmd_queue: Optional[List[str]] = None, **kwargs):
        super().__init__()
        self.predictor = predictor
        self.payload_kwargs = kwargs
        self.payload_kwargs["stopping_criteria"] = [self.human_id]
        if cmd_queue is not None:
            self.cmdqueue = cmd_queue

    def preloop(self):
        self.conversation = Conversation(self.human_id, self.bot_id)

    def precmd(self, line):
        command = line[1:] if line.startswith("/") else "say " + line
        return command

    def do_say(self, arg):
        self.conversation.push_human_turn(arg)
        history = self.conversation.history
        payload = {"text_inputs": history, **self.payload_kwargs}
        response = self.predictor.predict(payload)[len(history) :]
        self.conversation.push_model_response(response)
        print(f"{self.response_prefix}{self.conversation.get_last_turn()}")

    def do_reset(self, arg):
        self.conversation = Conversation(self.human_id, self.bot_id)

    def do_hyperparameters(self, arg):
        print(f"Hyperparameters: {self.payload_kwargs}\n")

    def do_quit(self, arg):
        return True
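
The prompt that the shell sends on each turn is simply the accumulated history ending in an open bot tag. The sketch below rebuilds that format from a list of turns. build_prompt is a hypothetical helper written for illustration; unlike Conversation, it normalizes the whitespace around each stored response.

```python
from typing import List, Optional, Tuple

HUMAN_TAG = "<human>:"
BOT_TAG = "<bot>:"


def build_prompt(turns: List[Tuple[str, Optional[str]]]) -> str:
    """Format (query, response) pairs; a None response leaves the prompt open for the model."""
    prompt = ""
    for query, response in turns:
        prompt += f"{HUMAN_TAG} {query}\n{BOT_TAG}"
        if response is not None:
            # Assumption: store each response with a single leading space and no stray newlines.
            prompt += f" {response.strip()}\n"
    return prompt


prompt = build_prompt([("Hello!", "Hi there."), ("How are you?", None)])
print(repr(prompt))
# → '<human>: Hello!\n<bot>: Hi there.\n<human>: How are you?\n<bot>:'
```

Since the whole history is resent on every turn, long conversations grow the input length, which is one reason the shell offers a /reset command.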

We can now launch this shell as a command loop. This will repeatedly issue a prompt, accept input, parse the input command, and dispatch actions. Because the resulting shell may be utilized in an infinite loop, this notebook provides a default command queue (cmdqueue) as a queued list of input lines; when the last command in the queue, /quit, is executed, the shell will terminate. To dynamically interact with this chatbot, please remove the cmdqueue.


[ ]:
cmd_queue = [
    "Hello!",
    "Make a markdown table of national parks with the state they are located in and date established.",
    "/hyperparameters",
    "/quit",
]
payload_kwargs_default = {
    "max_new_tokens": 128,
    "do_sample": True,
    "temperature": 0.6,
    "top_k": 40,
}
JumpStartChatbotShell(
    predictor=predictor,
    cmd_queue=cmd_queue,
    **{**payload_kwargs_default, **model_config.payload_kwargs},
).cmdloop()

And that’s it! Just a quick reminder: to have an interactive dialog with the chatbot, remove the cmd_queue argument in the cell above.
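
The command queue mechanism comes from the standard library cmd module and can be tried without an endpoint. In this minimal sketch, EchoShell is a hypothetical stand-in for JumpStartChatbotShell that echoes the input instead of calling a predictor, and output is written to an in-memory buffer so the run is easy to inspect.

```python
import cmd
import io


class EchoShell(cmd.Cmd):
    """Hypothetical stand-in for the chatbot shell that echoes instead of invoking an endpoint."""

    prompt = ">>> "

    def __init__(self, cmd_queue, stdout):
        super().__init__(stdout=stdout)
        # Lines in cmdqueue are consumed before any interactive input is requested.
        self.cmdqueue = list(cmd_queue)
        self.seen = []

    def precmd(self, line):
        # "/name" dispatches to do_name; anything else is routed to do_say.
        return line[1:] if line.startswith("/") else "say " + line

    def do_say(self, arg):
        self.seen.append(arg)
        self.stdout.write(f"<<< echo: {arg}\n")

    def do_quit(self, arg):
        # Returning True tells cmdloop to stop.
        return True


buffer = io.StringIO()
shell = EchoShell(["Hello!", "/quit"], stdout=buffer)
shell.cmdloop(intro="")
print(buffer.getvalue())
```

If the queue did not end with /quit, the loop would fall through to reading interactive input once the queue was empty, which is the behavior you get when no queue is supplied.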

6. Clean up the endpoint

[ ]:
predictor.delete_model()
predictor.delete_endpoint()
