Retrieval-Augmented Generation: Question Answering based on Custom Dataset


This notebook’s CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

(us-west-2 CI badge)


Many use cases, such as building a chatbot, require text-to-text generation models like BloomZ 7B1, Flan T5 XXL, and Flan T5 UL2 to respond to user questions with insightful answers. These models have picked up a lot of general knowledge during training, but we often need to ingest and use a large library of more specific information.

In this notebook we will demonstrate how to use BloomZ 7B1, Flan T5 XXL, and Flan T5 UL2 to answer questions using a library of documents as a reference, by using document embeddings and retrieval. The embeddings are generated with the GPT-J-6B embedding model.

This notebook serves as a template so that you can easily replace the example dataset with your own to build a custom question answering application.

Step 1. Deploy large language models (LLMs) in SageMaker JumpStart

To better illustrate the idea, let’s first deploy all the models that are required to perform the demo. You can either deploy all three models, Flan T5 XXL, BloomZ 7B1, and Flan UL2, as the large language model (LLM) to compare their performance, or select a subset of the models based on your preference. To do that, you need to modify the _MODEL_CONFIG_ Python dictionary defined below.

[ ]:
!pip install --upgrade sagemaker --quiet
!pip install ipywidgets==7.0.0 --quiet
[ ]:
import time
import sagemaker, boto3, json
from sagemaker.session import Session
from sagemaker.model import Model
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()
model_version = "1.*"
[ ]:
def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type="application/json"):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json
    )
    return response


def parse_response_model_flan_t5(query_response):
    model_predictions = json.loads(query_response["Body"].read())
    generated_text = model_predictions["generated_texts"]
    return generated_text


def parse_response_multiple_texts_bloomz(query_response):
    generated_text = []
    model_predictions = json.loads(query_response["Body"].read())
    for x in model_predictions[0]:
        generated_text.append(x["generated_text"])
    return generated_text
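
As a quick sanity check, the parsers above can be exercised with mock responses. The response shapes below are inferred from the parse functions themselves, not from the service documentation, and the actual container output may vary across model versions.

[ ]:
import io
import json

# Mock responses that mimic the payload shapes handled by the parsers above.
mock_flan_response = {"Body": io.BytesIO(json.dumps({"generated_texts": ["42"]}).encode("utf-8"))}
print(parse_response_model_flan_t5(mock_flan_response))  # -> ['42']

mock_bloomz_response = {
    "Body": io.BytesIO(json.dumps([[{"generated_text": "42"}]]).encode("utf-8"))
}
print(parse_response_multiple_texts_bloomz(mock_bloomz_response))  # -> ['42']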

Uncomment the entries below if you want to deploy multiple LLMs and compare their performance.

[ ]:
_MODEL_CONFIG_ = {
    "huggingface-text2text-flan-t5-xxl": {
        "instance type": "ml.g5.12xlarge",
        "env": {"SAGEMAKER_MODEL_SERVER_WORKERS": "1", "TS_DEFAULT_WORKERS_PER_MODEL": "1"},
        "parse_function": parse_response_model_flan_t5,
        "prompt": """Answer based on context:\n\n{context}\n\n{question}""",
    },
    # "huggingface-textgeneration1-bloomz-7b1-fp16": {
    #     "instance type": "ml.g5.12xlarge",
    #     "env": {},
    #     "parse_function": parse_response_multiple_texts_bloomz,
    #     "prompt": """question: \"{question}"\\n\nContext: \"{context}"\\n\nAnswer:""",
    # },
    # "huggingface-text2text-flan-ul2-bf16": {
    #     "instance type": "ml.g5.24xlarge",
    #     "env": {
    #         "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
    #         "TS_DEFAULT_WORKERS_PER_MODEL": "1"
    #     },
    #     "parse_function": parse_response_model_flan_t5,
    #     "prompt": """Answer based on context:\n\n{context}\n\n{question}""",
    # }
}
[ ]:
newline, bold, unbold = "\n", "\033[1m", "\033[0m"

for model_id in _MODEL_CONFIG_:
    endpoint_name = name_from_base(f"jumpstart-example-ragknn-{model_id}")
    inference_instance_type = _MODEL_CONFIG_[model_id]["instance type"]

    # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.
    deploy_image_uri = image_uris.retrieve(
        region=None,
        framework=None,  # automatically inferred from model_id
        image_scope="inference",
        model_id=model_id,
        model_version=model_version,
        instance_type=inference_instance_type,
    )
    # Retrieve the model uri.
    model_uri = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="inference"
    )
    model_inference = Model(
        image_uri=deploy_image_uri,
        model_data=model_uri,
        role=aws_role,
        predictor_cls=Predictor,
        name=endpoint_name,
        env=_MODEL_CONFIG_[model_id]["env"],
    )
    model_predictor_inference = model_inference.deploy(
        initial_instance_count=1,
        instance_type=inference_instance_type,
        predictor_cls=Predictor,
        endpoint_name=endpoint_name,
    )
    print(f"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}")
    _MODEL_CONFIG_[model_id]["endpoint_name"] = endpoint_name

Step 2. Ask a question to the LLM without providing the context

To illustrate why we need a retrieval-augmented generation (RAG) based approach to solve the question answering problem, let’s directly ask the models a question and see how they respond.

[ ]:
question = "Which instances can I use with Managed Spot Training in SageMaker?"
[ ]:
payload = {
    "text_inputs": question,
    "max_length": 100,
    "num_return_sequences": 1,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
}


for model_id in _MODEL_CONFIG_:
    endpoint_name = _MODEL_CONFIG_[model_id]["endpoint_name"]
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = _MODEL_CONFIG_[model_id]["parse_function"](query_response)
    print(f"For model: {model_id}, the generated output is: {generated_texts[0]}\n")

You can see the generated answer is wrong or doesn’t make much sense.

Step 3. Improve the answer to the same question using prompt engineering with insightful context

To answer the question well, we provide extra contextual information, combine it with a prompt, and send it to the model together with the question. Below is an example.

[ ]:
context = """Managed Spot Training can be used with all instances supported in Amazon SageMaker. Managed Spot Training is supported in all AWS Regions where Amazon SageMaker is currently available."""
[ ]:
parameters = {
    "max_length": 200,
    "num_return_sequences": 1,
    "top_k": 250,
    "top_p": 0.95,
    "do_sample": False,
    "temperature": 1,
}

for model_id in _MODEL_CONFIG_:
    endpoint_name = _MODEL_CONFIG_[model_id]["endpoint_name"]

    prompt = _MODEL_CONFIG_[model_id]["prompt"]

    text_input = prompt.replace("{context}", context)
    text_input = text_input.replace("{question}", question)
    payload = {"text_inputs": text_input, **parameters}

    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = _MODEL_CONFIG_[model_id]["parse_function"](query_response)
    print(
        f"{bold}For model: {model_id}, the generated output is: {generated_texts[0]}{unbold}{newline}"
    )

Step 4. Use a RAG-based approach to identify the most relevant documents, and use them along with the prompt and question to query the LLM

We will use document embeddings to fetch the most relevant documents from our knowledge library and combine them with the prompt that we provide to the LLM.

To achieve that, we will do the following (a toy sketch of this retrieval flow follows the note below).

  • Generate embeddings for each document in the knowledge library with the GPT-J-6B embedding model.

  • Identify top K most relevant documents based on user query.

    • For a query of your interest, generate the embedding of the query using the same embedding model.

    • Search the indexes of top K most relevant documents in the embedding space using the SageMaker KNN algorithm.

    • Use the indexes to retrieve the corresponding documents.

  • Combine the retrieved documents with the prompt and question and send them to the LLM.

Note: The retrieved document/text should be large enough to contain enough information to answer the question, but small enough to fit into the LLM prompt, whose maximum sequence length is 1024 tokens.
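
Before wiring up the SageMaker endpoints, here is a small, self-contained sketch of this retrieval flow using random NumPy vectors in place of the GPT-J-6B embeddings and the SageMaker KNN endpoint. The toy documents, dimensions, and distance metric below are purely illustrative.

[ ]:
import numpy as np

# Toy stand-ins for the document embeddings produced by the embedding model.
rng = np.random.default_rng(0)
toy_docs = [
    "Managed Spot Training can be used with all instances supported in Amazon SageMaker.",
    "SageMaker endpoints serve real-time predictions.",
    "SageMaker Autopilot automatically builds ML models.",
]
toy_doc_embeddings = rng.normal(size=(len(toy_docs), 8))  # one 8-dim vector per document
toy_query_embedding = toy_doc_embeddings[0] + 0.01 * rng.normal(size=8)  # a query close to doc 0

# Nearest neighbours by Euclidean distance; the row indexes play the role of the
# labels returned by the SageMaker KNN endpoint later in this section.
distances = np.linalg.norm(toy_doc_embeddings - toy_query_embedding, axis=1)
top_k_indexes = np.argsort(distances)[:2]

# Concatenate the retrieved documents and fill the prompt template.
toy_context = "\n* ".join(toy_docs[i] for i in top_k_indexes)
toy_question = "Which instances can I use with Managed Spot Training in SageMaker?"
print(f"Answer based on context:\n\n{toy_context}\n\n{toy_question}")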

4.1 Deploy the model endpoint for the GPT-J-6B embedding model

[ ]:
model_id, model_version = "huggingface-textembedding-gpt-j-6b", "1.*"

endpoint_name_embed = name_from_base(f"jumpstart-example-ragknn-{model_id}")

inference_instance_type = "ml.g5.24xlarge"

# Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)
# Retrieve the model uri.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)
model_inference = Model(
    image_uri=deploy_image_uri,
    model_data=model_uri,
    role=aws_role,
    name=endpoint_name_embed,
    env={"TS_DEFAULT_WORKERS_PER_MODEL": "1"},
)

model_predictor_embed = model_inference.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    endpoint_name=endpoint_name_embed,
)
[ ]:
from tqdm import tqdm
import pandas as pd  # used by build_embed_table below


def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type="application/json"):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json
    )
    return response


def parse_response_multiple_texts(query_response):
    model_predictions = json.loads(query_response["Body"].read())
    embeddings = model_predictions["embedding"]
    return embeddings


def build_embed_table(df_knowledge, endpoint_name_embed, col_name_4_embed, batch_size=10):
    res_embed = []
    N = df_knowledge.shape[0]
    for idx in tqdm(range(0, N, batch_size)):
        content = df_knowledge.loc[idx : (idx + batch_size - 1)][
            col_name_4_embed
        ].tolist()  # subtract 1 because pandas .loc slicing is end-inclusive
        payload = {"text_inputs": content}
        query_response = query_endpoint_with_json_payload(
            json.dumps(payload).encode("utf-8"), endpoint_name_embed
        )
        generated_embed = parse_response_multiple_texts(query_response)
        res_embed.extend(generated_embed)
    res_embed_df = pd.DataFrame(res_embed)
    return res_embed_df

4.2. Generate embeddings for each document in the knowledge library with the GPT-J-6B embedding model

For the purpose of the demo, we will use the Amazon SageMaker FAQs as the knowledge library. The data is formatted in a CSV file with two columns, Question and Answer. We use only the Answer column as the documents of the knowledge library, from which relevant documents are retrieved based on a query.

Each row in the CSV dataset corresponds to a textual document. We will iterate over the documents to get their embedding vectors via the GPT-J-6B embedding model. For your own use case, you can replace the example dataset with your own to build a custom question answering application.

First, we download the dataset from our S3 bucket to the local machine.

[ ]:
s3_path = f"s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/Amazon_SageMaker_FAQs.csv"
[ ]:
# Download the dataset
!aws s3 cp $s3_path Amazon_SageMaker_FAQs.csv
[ ]:
import pandas as pd

df_knowledge = pd.read_csv("Amazon_SageMaker_FAQs.csv", header=None, names=["Question", "Answer"])
[ ]:
df_knowledge.head(5)

Drop the Question column since it is not used in this notebook.

[ ]:
df_knowledge.drop(["Question"], axis=1, inplace=True)
[ ]:
df_knowledge.shape
[ ]:
df_knowledge_embed = build_embed_table(
    df_knowledge, endpoint_name_embed, col_name_4_embed="Answer", batch_size=20
)
[ ]:
df_knowledge_embed.head(5)
[ ]:
assert df_knowledge_embed.shape[0] == df_knowledge.shape[0]

Save the embedding data for further usage.

[ ]:
df_knowledge_embed.to_csv("Amazon_SageMaker_FAQs_embedding.csv", header=None, index=False)

4.3. Index the embedding knowledge library using the SageMaker KNN algorithm

The SageMaker KNN algorithm does the following.

  1. Start a training job to index the embedding knowledge data. The underlying algorithm used to index the data is Faiss.

  2. Start an endpoint to take the embedding of the query as input and return the top K nearest indexes of the documents.

Note: For the KNN training job, the features form an N-by-P matrix, where N is the number of documents in the knowledge library and P is the embedding dimension; each row is the embedding of one document. The labels are ordinal integers starting from 0. During inference, given the embedding of a query, the labels of the top K nearest documents are used as indexes to retrieve the corresponding textual documents.
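
For intuition, here is a tiny, purely illustrative example of this layout and of how the returned labels map back to documents. The shapes and values are made up; the real embedding dimension is determined by the GPT-J-6B model.

[ ]:
import numpy as np
import pandas as pd

toy_features = np.random.rand(3, 4)  # 3 documents, illustrative embedding dimension of 4
toy_labels = np.arange(3)            # ordinal labels 0, 1, 2 double as row indexes
toy_docs = pd.Series(["doc A", "doc B", "doc C"])
print("features shape:", toy_features.shape, "| labels:", toy_labels)

# At inference time the KNN endpoint returns the labels of the nearest neighbours,
# which map straight back to the corresponding documents.
returned_labels = [2, 0]
print(toy_docs.loc[returned_labels].tolist())  # -> ['doc C', 'doc A']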

We first upload the prepared dataset to the S3 bucket.

[ ]:
import numpy as np
import os
import io
import sagemaker.amazon.common as smac


train_features = np.array(df_knowledge_embed)

# Assign each answer embedding an ordinal label that doubles as its row index
train_labels = np.array([i for i in range(len(train_features))])

print("train_features shape = ", train_features.shape)
print("train_labels shape = ", train_labels.shape)

buf = io.BytesIO()
smac.write_numpy_to_dense_tensor(buf, train_features, train_labels)
buf.seek(0)


bucket = sess.default_bucket()  # modify to your bucket name
prefix = "RAGDatabase"
key = "Amazon-SageMaker-RAG"

boto3.resource("s3").Bucket(bucket).Object(os.path.join(prefix, "train", key)).upload_fileobj(buf)
s3_train_data = f"s3://{bucket}/{prefix}/train/{key}"
print(f"uploaded training data location: {s3_train_data}")

We want to retrieve the top 5 most relevant documents.

[ ]:
TOP_K = 5
[ ]:
from sagemaker.amazon.amazon_estimator import get_image_uri


def trained_estimator_from_hyperparams(s3_train_data, hyperparams, output_path):
    """
    Create an Estimator from the given hyperparams, fit to training data,
    and return a deployed predictor

    """
    # set up the estimator
    knn = sagemaker.estimator.Estimator(
        get_image_uri(boto3.Session().region_name, "knn"),
        aws_role,
        instance_count=1,
        instance_type="ml.m5.2xlarge",
        output_path=output_path,
        sagemaker_session=sess,
    )
    knn.set_hyperparameters(**hyperparams)

    # train a model. fit_input contains the locations of the train data
    fit_input = {"train": s3_train_data}
    knn.fit(fit_input)
    return knn


hyperparams = {
    "feature_dim": train_features.shape[1],
    "k": TOP_K,
    "sample_size": train_features.shape[0],
    "predictor_type": "classifier",
}
output_path = f"s3://{bucket}/{prefix}/default_example/output"
knn_estimator = trained_estimator_from_hyperparams(s3_train_data, hyperparams, output_path)

Deploy the KNN endpoint for retrieving the indexes of the top K most relevant documents.

[ ]:
from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import JSONDeserializer


def predictor_from_estimator(knn_estimator, instance_type, endpoint_name=None):
    knn_predictor = knn_estimator.deploy(
        initial_instance_count=1, instance_type=instance_type, endpoint_name=endpoint_name
    )
    knn_predictor.serializer = CSVSerializer()
    knn_predictor.deserializer = JSONDeserializer()
    return knn_predictor


instance_type = "ml.m4.xlarge"
endpoint_name = name_from_base(f"jumpstart-example-ragknn")

knn_predictor = predictor_from_estimator(knn_estimator, instance_type, endpoint_name=endpoint_name)

4.4 Retrieve the most relevant documents

Given the embedding of a query, we query the endpoint to get the indexes of the top K most relevant documents and use the indexes to retrieve the corresponding textual documents.

Next, the textual documents are concatenated up to a maximum length of MAX_SECTION_LEN characters. This ensures that the context we send into the prompt contains enough information while not exceeding the model’s capacity.

[ ]:
MAX_SECTION_LEN = 500
SEPARATOR = "\n* "


def construct_context(context_predictions_arr, df_knowledge) -> str:
    chosen_sections = []
    chosen_sections_len = 0

    for index in context_predictions_arr:
        # Add sections until we run out of space (lengths are measured in characters, not tokens).
        document_section = df_knowledge.loc[index]
        chosen_sections_len += len(document_section) + 2
        if chosen_sections_len > MAX_SECTION_LEN:
            break

        chosen_sections.append(SEPARATOR + document_section.replace("\n", " "))
    concatenated_doc = "".join(chosen_sections)
    print(
        f"With maximum sequence length {MAX_SECTION_LEN}, selected top {len(chosen_sections)} document sections: {concatenated_doc}"
    )

    return concatenated_doc
[ ]:
start = time.time()

# The embedding endpoint also accepts raw text when called with the application/x-text content type.
query_response = query_endpoint_with_json_payload(
    question, endpoint_name_embed, content_type="application/x-text"
)
question_embedding = parse_response_multiple_texts(query_response)

# Getting the most relevant context using KNN
context_predictions_arr = knn_predictor.predict(
    np.array(question_embedding),
    initial_args={"ContentType": "text/csv", "Accept": "application/json; verbose=true"},
)["predictions"][0]["labels"]

context_embed_retrieve = construct_context(context_predictions_arr, df_knowledge["Answer"])

print(
    f"{newline}{bold}Elastic time for computing the embedding of a query and retrieved the top K most relevant docuemnts: {time.time() - start} seconds.{unbold}{newline}"
)

4.5 Combine the retrieved documents, prompt, and question to query the LLM

[ ]:
for model_id in _MODEL_CONFIG_:
    endpoint_name = _MODEL_CONFIG_[model_id]["endpoint_name"]

    prompt = _MODEL_CONFIG_[model_id]["prompt"]

    text_input = prompt.replace("{context}", context_embed_retrieve)
    text_input = text_input.replace("{question}", question)

    payload = {"text_inputs": text_input, **parameters}

    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = _MODEL_CONFIG_[model_id]["parse_function"](query_response)
    print(f"For model: {model_id}, the generated output is: {generated_texts[0]}\n")

Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

(CI badges for us-east-1, us-east-2, us-west-1, ca-central-1, sa-east-1, eu-west-1, eu-west-2, eu-west-3, eu-central-1, eu-north-1, ap-southeast-1, ap-southeast-2, ap-northeast-1, ap-northeast-2, and ap-south-1)