SageMaker JumpStart Foundation Models - Chatbots
This notebook’s CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.
Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and in an Amazon SageMaker Notebook instance with the conda_python3 kernel.
1. Set up
Before running the notebook, install or upgrade the packages it requires.
[ ]:
%pip install sagemaker ipywidgets --upgrade --quiet
2. Select a pre-trained model
You can continue with the default model or choose a different model from the dropdown generated by running the cells below. A complete list of SageMaker pre-trained models is also available at SageMaker pre-trained Models.
[ ]:
from typing import Any, Dict, NamedTuple


class JumpStartChatbotModelConfig(NamedTuple):
    model_id: str
    model_kwargs: Dict[str, Any] = {}
    payload_kwargs: Dict[str, Any] = {}


jumpstart_chatbot_models_config = [
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration-falcon-7b-instruct-bf16",
        payload_kwargs={"return_full_text": True},
    ),
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration-falcon-40b-instruct-bf16",
        payload_kwargs={"return_full_text": True},
    ),
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration1-redpajama-incite-chat-3B-v1-fp16",
    ),
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration1-redpajama-incite-chat-7B-v1-fp16",
    ),
    JumpStartChatbotModelConfig(
        model_id="huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16",
    ),
]
[ ]:
from IPython.display import Markdown, display
from ipywidgets import Dropdown

dropdown = Dropdown(
    options=[(config.model_id, config) for config in jumpstart_chatbot_models_config],
    value=jumpstart_chatbot_models_config[0],
    description="JumpStart Chatbot Models:",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)
display(Markdown("### Select a JumpStart chatbot model from the dropdown below"))
display(dropdown)
[ ]:
model_config = dropdown.value
3. Retrieve Artifacts & Deploy an Endpoint
Using SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by deploying a JumpStartModel to host the pre-trained model. This may take a few minutes.
[ ]:
from sagemaker.jumpstart.model import JumpStartModel
from sagemaker.predictor import Predictor
model = JumpStartModel(
    model_id=model_config.model_id, model_version="1.*", **model_config.model_kwargs
)
predictor = model.deploy()
Next, the SageMaker Predictor is configured to use a JSON serializer and a custom deserializer that works with all chatbot models supported by this notebook. The deserializer ensures that the chatbot always returns a string representing a single generated text sample per query.
[ ]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer


class JumpStartChatbotDeserializer(JSONDeserializer):
    """A deserializer to retrieve the first generated text from JumpStart text generation models."""

    def deserialize(self, stream, content_type):
        """Crawl the output of JSON deserialization to obtain first generated text model response."""
        data = super().deserialize(stream, content_type)
        while True:
            if isinstance(data, str):
                break
            elif isinstance(data, list):
                data = data[0]
            elif isinstance(data, dict):
                for key in ("generated_text", "generated_texts"):
                    if key in data:
                        data = data[key]
                        break
                else:
                    raise ValueError(f"Generated text keys not found in output {data}.")
            else:
                raise ValueError(f"Output data contains unrecognized type {type(data)}.")
        return data


predictor.serializer = JSONSerializer()
predictor.deserializer = JumpStartChatbotDeserializer()
predictor.content_type = "application/json"
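To try the crawl without a deployed endpoint, the same traversal can be sketched as a standalone function. The name `first_generated_text` and the sample response shapes below are illustrative assumptions, not outputs captured from a real model.

```python
from typing import Any


def first_generated_text(data: Any) -> str:
    """Walk nested lists/dicts until the first generated text string is found."""
    while not isinstance(data, str):
        if isinstance(data, list):
            # Keep only the first sample when several are returned.
            data = data[0]
        elif isinstance(data, dict):
            for key in ("generated_text", "generated_texts"):
                if key in data:
                    data = data[key]
                    break
            else:
                raise ValueError(f"Generated text keys not found in output {data}.")
        else:
            raise ValueError(f"Output data contains unrecognized type {type(data)}.")
    return data


# Hypothetical response shapes, chosen only to exercise each branch.
print(first_generated_text([{"generated_text": "Hello!"}]))
print(first_generated_text({"generated_texts": ["Hi there", "Hey"]}))
```

Whatever the nesting, the caller receives a single string, which is what the shell later in this notebook relies on.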
4. Query endpoint and parse response
The model also supports many advanced parameters during inference. They include:
max_length: The model generates text until the output length (which includes the input context length) reaches max_length. If specified, it must be a positive integer.
max_time: The maximum amount of time, in seconds, that the computation is allowed to run. Generation still finishes the current pass after the allotted time has elapsed. This setting can help return a response before the endpoint invocation times out.
num_return_sequences: Number of output sequences returned. If specified, it must be a positive integer.
num_beams: Number of beams used in beam search. If specified, it must be an integer greater than or equal to num_return_sequences.
no_repeat_ngram_size: The model ensures that no sequence of no_repeat_ngram_size words is repeated in the output sequence. If specified, it must be a positive integer greater than 1.
temperature: Controls the randomness of the output. A higher temperature makes low-probability words more likely to appear, and a lower temperature favors high-probability words. As temperature approaches 0, generation approaches greedy decoding. If specified, it must be a positive float.
early_stopping: If True, text generation finishes when all beam hypotheses reach the end-of-sentence token. If specified, it must be a boolean.
do_sample: If True, sample the next word according to its likelihood. If specified, it must be a boolean.
top_k: In each step of text generation, sample from only the top_k most likely words. If specified, it must be a positive integer.
top_p: In each step of text generation, sample from the smallest possible set of words with cumulative probability top_p. If specified, it must be a float between 0 and 1.
seed: Fix the randomized state for reproducibility. If specified, it must be an integer.
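The constraints in this list can be checked on the client before a request is sent. The following is a minimal sketch of such a check. The helper `validate_payload` is hypothetical and not part of the SageMaker SDK. Treating a missing num_return_sequences as 1 and accepting integers where a float is expected are assumptions made for this example.

```python
from typing import Any, Dict


def validate_payload(payload: Dict[str, Any]) -> None:
    """Raise ValueError if a generation parameter violates the constraints listed above."""

    def is_int(value: Any) -> bool:
        # bool is a subclass of int in Python, so exclude it explicitly.
        return isinstance(value, int) and not isinstance(value, bool)

    def is_number(value: Any) -> bool:
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    for name in ("max_length", "num_return_sequences", "top_k"):
        if name in payload and not (is_int(payload[name]) and payload[name] > 0):
            raise ValueError(f"{name} must be a positive integer.")

    if "num_beams" in payload:
        # Assumption: a missing num_return_sequences is treated as 1.
        minimum = payload.get("num_return_sequences", 1)
        if not (is_int(payload["num_beams"]) and payload["num_beams"] >= minimum):
            raise ValueError("num_beams must be an integer >= num_return_sequences.")

    if "no_repeat_ngram_size" in payload:
        value = payload["no_repeat_ngram_size"]
        if not (is_int(value) and value > 1):
            raise ValueError("no_repeat_ngram_size must be an integer greater than 1.")

    if "temperature" in payload:
        value = payload["temperature"]
        if not (is_number(value) and value > 0):
            raise ValueError("temperature must be a positive float.")

    if "top_p" in payload:
        value = payload["top_p"]
        if not (is_number(value) and 0 <= value <= 1):
            raise ValueError("top_p must be a float between 0 and 1.")

    for name in ("early_stopping", "do_sample"):
        if name in payload and not isinstance(payload[name], bool):
            raise ValueError(f"{name} must be a boolean.")

    if "seed" in payload and not is_int(payload["seed"]):
        raise ValueError("seed must be an integer.")


# Passes silently: every supplied parameter satisfies its constraint.
validate_payload({"max_length": 100, "top_k": 50, "top_p": 0.95, "do_sample": True})
```

Parameters that are absent are not checked, which matches the "if specified" wording of each entry.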
We may specify any subset of the parameters mentioned above when invoking an endpoint. Next, we show an example of how to invoke the endpoint with these arguments.
[ ]:
payload = {
    "text_inputs": "<human>: Tell me the steps to make a pizza\n<bot>:",
    "max_length": 100,
    "max_time": 50,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
    "stopping_criteria": ["<human>"],
}
print(predictor.predict(payload))
Here, we have provided the payload argument "stopping_criteria": ["<human>"], which has resulted in the model response ending with the generation of the word sequence "<human>". The SageMaker JumpStart model script will accept any list of strings as desired stop words, convert this list to a valid stopping_criteria keyword argument (https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.generate.stopping_criteria) to the transformers generate API, and terminate text generation when the output sequence contains any specified stop words. This is useful for two reasons: first, inference time is reduced because the endpoint does not continue to generate undesired text beyond the stop words, and, second, this prevents the chatbot model from hallucinating additional human and bot responses until other stop criteria are met.
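To see what stop words protect against, consider the text a model might produce without them. The sketch below trims such text on the client side. It is not the JumpStart model script's implementation, and both the helper name `truncate_at_stop_words` and the sample text are made up for illustration.

```python
from typing import Iterable


def truncate_at_stop_words(text: str, stop_words: Iterable[str]) -> str:
    """Return text cut just before the earliest occurrence of any stop word."""
    cut = len(text)
    for word in stop_words:
        if not word:
            continue  # an empty stop word would match at position 0
        index = text.find(word)
        if index != -1:
            cut = min(cut, index)
    return text[:cut]


# Hypothetical raw generation that runs on into imagined extra turns.
raw = " Roll out the dough, add toppings, and bake.\n<human>: Thanks!\n<bot>: You're welcome."
print(repr(truncate_at_stop_words(raw, ["<human>"])))
```

Trimming after the fact removes the imagined turns but not the time spent generating them, which is why stopping generation on the endpoint is preferable.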
5. Use a shell interpreter to interact with your deployed endpoint
OpenChatKit provides a command line shell to interact with their chatbot. In the following code blocks, we provide a bare-bones simplification of the inference scripts in this OpenChatKit repository that can interact with our deployed SageMaker endpoint. There are two main components to this:
1. A shell interpreter (JumpStartChatbotShell) that allows for iterative inference invocations of the model endpoint.
2. A conversation object (Conversation) that stores previous human/chatbot interactions locally within the interactive shell and appropriately formats past conversations for future inference context.
[ ]:
import cmd
import re
from typing import List, Optional


class Conversation:
    # Special tokens that are stripped from model responses.
    MEANINGLESS_WORDS = ["<pad>", "</s>", "<|endoftext|>"]

    def __init__(self, human_id, bot_id):
        self.human_tag = f"{human_id}:"
        self.bot_tag = f"{bot_id}:"
        self.history = ""

    def clean_response(self, response):
        for word in self.MEANINGLESS_WORDS:
            response = response.replace(word, "")
        response = response.strip("\n")
        return response

    def push_human_turn(self, query):
        self.history += f"{self.human_tag} {query}\n{self.bot_tag}"

    def push_model_response(self, response):
        # Keep only the text generated before any new human turn.
        bot_turn = response.split(f"{self.human_tag}")[0]
        bot_turn = self.clean_response(bot_turn)
        self.history += f"{bot_turn}\n"

    def get_last_turn(self):
        # Raw f-string so that \W is a regex escape rather than an invalid string escape.
        turns = re.split(rf"({self.human_tag}|{self.bot_tag})\W?", self.history)
        return turns[-1]


class JumpStartChatbotShell(cmd.Cmd):
    intro = (
        "Welcome to the SageMaker JumpStart chatbot shell! Type /help or /? to list commands. "
        "Type /quit to exit shell.\n"
    )
    prompt = ">>> "
    response_prefix = "<<< "
    human_id = "<human>"
    bot_id = "<bot>"

    def __init__(self, predictor: Predictor, cmd_queue: Optional[List[str]] = None, **kwargs):
        super().__init__()
        self.predictor = predictor
        self.payload_kwargs = kwargs
        self.payload_kwargs["stopping_criteria"] = [self.human_id]
        if cmd_queue is not None:
            self.cmdqueue = cmd_queue

    def preloop(self):
        self.conversation = Conversation(self.human_id, self.bot_id)

    def precmd(self, line):
        # "/command" runs a shell command; any other input is sent to the chatbot.
        command = line[1:] if line.startswith("/") else "say " + line
        return command

    def do_say(self, arg):
        self.conversation.push_human_turn(arg)
        history = self.conversation.history
        payload = {"text_inputs": history, **self.payload_kwargs}
        # Drop the echoed prompt so that only the newly generated text remains.
        response = self.predictor.predict(payload)[len(history) :]
        self.conversation.push_model_response(response)
        print(f"{self.response_prefix}{self.conversation.get_last_turn()}")

    def do_reset(self, arg):
        self.conversation = Conversation(self.human_id, self.bot_id)

    def do_hyperparameters(self, arg):
        print(f"Hyperparameters: {self.payload_kwargs}\n")

    def do_quit(self, arg):
        return True
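The bookkeeping for one conversation turn can be traced without an endpoint. In the sketch below, `fake_endpoint` is a hypothetical stand-in for `predictor.predict` that echoes the prompt and appends a canned reply. It is an assumption made for illustration, not real model output.

```python
def fake_endpoint(prompt: str) -> str:
    """Stand-in for predictor.predict: echoes the prompt, then appends a canned reply."""
    return prompt + " Hi! How can I help?\n<human>: (imagined follow-up)"


human_tag, bot_tag = "<human>:", "<bot>:"
history = ""

# Same steps as push_human_turn, do_say, and push_model_response.
history += f"{human_tag} Hello!\n{bot_tag}"
full_text = fake_endpoint(history)
completion = full_text[len(history):]      # drop the echoed prompt
bot_turn = completion.split(human_tag)[0]  # drop any imagined human turn
bot_turn = bot_turn.strip("\n")
history += f"{bot_turn}\n"

print(repr(history))
```

After the turn, the history ends with the bot's reply and a newline, ready for the next human turn to be appended.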
We can now launch this shell as a command loop. This will repeatedly issue a prompt, accept input, parse the input command, and dispatch actions. Because the resulting shell may be utilized in an infinite loop, this notebook provides a default command queue (cmdqueue) as a queued list of input lines; when the last command in the queue, /quit, is executed, the shell will terminate. To dynamically interact with this chatbot, please remove the cmdqueue.
[ ]:
cmd_queue = [
    "Hello!",
    "Make a markdown table of national parks with the state they are located in and date established.",
    "/hyperparameters",
    "/quit",
]

payload_kwargs_default = {
    "max_new_tokens": 128,
    "do_sample": True,
    "temperature": 0.6,
    "top_k": 40,
}

JumpStartChatbotShell(
    predictor=predictor,
    cmd_queue=cmd_queue,
    **{**payload_kwargs_default, **model_config.payload_kwargs},
).cmdloop()
And that’s it! Just a quick reminder: you can comment out the cmd_queue in the above cell to have an interactive dialog with the chatbot.
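The command-queue behavior can also be seen in isolation. The following sketch uses a hypothetical `EchoShell` that follows the same `/command` convention but simply echoes its input, so no endpoint is needed. Output is written to an in-memory buffer only to keep the example self-contained.

```python
import cmd
import io


class EchoShell(cmd.Cmd):
    """Tiny shell that mirrors the /command convention with no endpoint behind it."""

    prompt = ""  # keep the captured output free of prompt text

    def __init__(self, cmd_queue, stdout):
        super().__init__(stdout=stdout)
        # cmd.Cmd consumes queued lines before it ever waits for user input.
        self.cmdqueue = list(cmd_queue)

    def precmd(self, line):
        # "/quit" becomes "quit"; anything else is routed to do_say.
        return line[1:] if line.startswith("/") else "say " + line

    def do_say(self, arg):
        self.stdout.write(f"<<< echo: {arg}\n")

    def do_quit(self, arg):
        return True  # a true return value stops cmdloop


buffer = io.StringIO()
EchoShell(["Hello!", "/quit"], stdout=buffer).cmdloop()
print(buffer.getvalue())
```

Because the queue ends with /quit, the loop exits before reading from standard input. Without that final command, the shell would wait for typed input once the queue is empty.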
6. Clean up the endpoint
[ ]:
predictor.delete_model()
predictor.delete_endpoint()
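If an error interrupts the notebook before this cell runs, the endpoint keeps running. One possible pattern, sketched below, wraps the session in a context manager so that both delete calls run even on failure. The helper `cleanup_on_exit` and the `FakePredictor` class are hypothetical and exist only for this illustration; a real Predictor would be passed in their place.

```python
from contextlib import contextmanager


@contextmanager
def cleanup_on_exit(predictor):
    """Yield the predictor, then delete its model and endpoint even if an error occurs."""
    try:
        yield predictor
    finally:
        predictor.delete_model()
        predictor.delete_endpoint()


class FakePredictor:
    """Hypothetical stand-in that records calls instead of touching AWS."""

    def __init__(self):
        self.calls = []

    def delete_model(self):
        self.calls.append("delete_model")

    def delete_endpoint(self):
        self.calls.append("delete_endpoint")


fake = FakePredictor()
try:
    with cleanup_on_exit(fake):
        raise RuntimeError("simulated failure mid-session")
except RuntimeError:
    pass

print(fake.calls)
```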
Notebook CI Test Results
This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.