Retrieval-Augmented Generation: Question Answering using Llama-2 and Text Embedding Models




In this notebook we will demonstrate how to use Llama-2-7b to answer questions using a library of documents as a reference, by using document embeddings and retrieval. Unlike other RAG solutions, in this solution the document embeddings are generated during a training job and stored alongside the embedding model, so nearest-neighbor retrieval is served from a single endpoint.

To perform inference on the Llama models, you need to pass custom_attributes='accept_eula=true' as part of the request header. This indicates that you have read and accepted the end-user license agreement (EULA) of the model. The EULA can be found in the model card description or from this webpage.

Note: custom_attributes used to pass the EULA are key/value pairs. The key and value are separated by '=' and pairs are separated by ';'. If the user passes the same key more than once, the last value is kept and passed to the script handler (i.e., in this case, used for conditional logic). For example, if 'accept_eula=false; accept_eula=true' is passed to the server, then 'accept_eula=true' is kept and passed to the script handler.
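
As a quick illustration of how such a key/value string is interpreted (a minimal sketch; the parsing helper below is hypothetical and not part of the SageMaker API), later entries override earlier ones for the same key:

[ ]:
# Hypothetical helper: parse a custom_attributes string into a dict.
# Pairs are separated by ';', key and value by '='; the last value seen for a key wins.
def parse_custom_attributes(attrs: str) -> dict:
    parsed = {}
    for pair in attrs.split(";"):
        if "=" in pair:
            key, value = pair.split("=", 1)
            parsed[key.strip()] = value.strip()
    return parsed


print(parse_custom_attributes("accept_eula=false; accept_eula=true"))  # {'accept_eula': 'true'}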

Other Retrieval-Augmented Generation Solutions

  • Question Answering using LangChain and Cohere’s Generate and Embedding Models from SageMaker JumpStart

  • Question Answering based on Custom Dataset

  • Question Answering based on Custom Dataset with Open-sourced LangChain Library

  • Question Answering using LLama-2, Pinecone & Custom Dataset

Step 1. Deploy Llama-2 7 Billion Chat Model in SageMaker JumpStart

[ ]:
!pip install -qU \
    sagemaker \
    pinecone-client==2.2.1 \
    ipywidgets==7.0.0

To begin, we will initialize all of the SageMaker session variables we’ll need to use throughout the walkthrough.

[ ]:
import sagemaker
from sagemaker.jumpstart.model import JumpStartModel

role = sagemaker.get_execution_role()

my_model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b-f")

We will use an ml.g5.4xlarge instance to deploy our Llama-2 7B model. Pricing for all instances can be found here.

[ ]:
predictor_llm = my_model.deploy(initial_instance_count=1, instance_type="ml.g5.4xlarge")

To understand why a retrieval-augmented generation (RAG) approach is needed for the question answering problem, please refer to question_answering_pinecone_llama-2_jumpstart.ipynb.

Step 2. Use the Text Embedding Model to identify the relevant documents, and use them along with the prompt and question to query the LLM

We plan to use document embeddings to fetch the most relevant documents from our document knowledge library and combine them with the prompt that we provide to the LLM.

To achieve that, we will do the following:

  • Run a text embedding model training job. The training job generates embeddings for the provided dataset and saves them along with the model. These embeddings are used during inference to find the nearest neighbors for an input sentence, based on the cosine similarity between the input sentence embedding and the sentence embeddings computed during the training job. For more information, please refer to text-embedding-sentence-similarity.ipynb

  • Query the text embedding model endpoint created above to identify the top K most relevant documents based on the user query

  • Combine the retrieved documents with the prompt and question and send them to the LLM.

Note: Unlike the other RAG solutions, here the dataset is saved with the model so that the same endpoint can return the most similar documents.

Note: The retrieved document/text should be large enough to contain enough information to answer a question, but small enough to fit into the LLM prompt, which has a maximum sequence length of 1024 tokens.
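
As a rough sanity check that a retrieved passage will fit, one can approximate the token count from the character count. This is a minimal sketch; the ~4 characters per token ratio and the reserved-token budget are assumptions, not the actual Llama-2 tokenizer:

[ ]:
# Rough heuristic: assume ~4 characters per token for English text (an approximation,
# not an exact tokenizer). Reserve part of the 1024-token budget for the prompt
# template, the question, and the generated answer.
MAX_TOKENS = 1024


def fits_prompt(text: str, reserved_tokens: int = 400) -> bool:
    approx_tokens = len(text) // 4
    return approx_tokens + reserved_tokens <= MAX_TOKENS


print(fits_prompt("Amazon SageMaker is a fully managed service. " * 10))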

To train and host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook instance as the AWS account role with SageMaker access. It has the necessary permissions, including access to your data in S3.

[ ]:
import boto3
from sagemaker.session import Session

# SageMaker session, execution role ARN, and AWS region used throughout the walkthrough
sess = Session()
aws_role = sess.get_caller_identity_arn()
aws_region = boto3.Session().region_name
[ ]:
# We are using the huggingface-sentencesimilarity-gte-small model to get embeddings. A different model can also be used.
model_id = "huggingface-sentencesimilarity-gte-small"

2.1. Preparing Dataset

[ ]:
# In this section, we fetch and prepare the Amazon_SageMaker_FAQs dataset so it can be used to find the nearest neighbour to an input question.

import pandas as pd

!aws s3 cp s3://jumpstart-cache-prod-us-west-2/training-datasets/Amazon_SageMaker_FAQs/Amazon_SageMaker_FAQs.csv Amazon_SageMaker_FAQs.csv

# Preparing the Data in the required format

data = pd.read_csv("Amazon_SageMaker_FAQs.csv", names=["Questions", "Answers"])
data["id"] = data.index

data_req = data[["id", "Answers"]]

data_req.to_csv("data.csv", index=False, header=False)

# Uploading the data in the required format to the S3 bucket
output_bucket = sess.default_bucket()
output_prefix = "jumpstart-example-ss-training"

s3_output_location = f"s3://{output_bucket}/{output_prefix}/output"
training_dataset_s3_path = f"s3://{output_bucket}/{output_prefix}/data/data.csv"

!aws s3 cp data.csv {training_dataset_s3_path}

2.2. Getting the Embeddings for the Input data and Training Job

[ ]:
from sagemaker import hyperparameters
from sagemaker.jumpstart.estimator import JumpStartEstimator

# Retrieve the default hyper-parameters for the model
hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version="*")

# [Optional] Override default hyperparameters with custom values.
# The max_seq_length hyperparameter sets the maximum input sequence length processed by the
# embedding model; the default value of None uses the model's own default max_seq_length.
hyperparameters["batch_size"] = "64"
print(hyperparameters)

estimator = JumpStartEstimator(
    model_id=model_id, hyperparameters=hyperparameters, output_path=s3_output_location
)

# Launch a SageMaker training job by passing the S3 path of the data
estimator.fit({"training": f"s3://{output_bucket}/{output_prefix}/data"}, logs=True)

# Use the estimator from the previous step to deploy to a SageMaker endpoint
predictor_nn = estimator.deploy()

2.3. Deploy & run Inference on the model to get nearest neighbor

You can query the endpoint with a JSON payload containing a batch of input texts to find the nearest neighbors of each input text within the dataset provided during the training job.

  • queries: The list of input texts for which to find the closest matches in the training data

  • top_k: The number of closest matches to return from the training data

  • mode: Set to "nn_train_data" to get the nearest neighbors to the input queries from the dataset provided during training

[ ]:
from sagemaker.serializers import JSONSerializer

newline = "\n"
predictor_nn.serializer = JSONSerializer()
predictor_nn.content_type = "application/json"

payload_nearest_neighbour = {
    "queries": ["Is R supported with Amazon SageMaker?"],
    "top_k": 1,
    "mode": "nn_train_data",
    "return_text": True,
}

response = predictor_nn.predict(payload_nearest_neighbour)

question = payload_nearest_neighbour["queries"][0]
answer = response[0][0]["text"]
# Relating the Input Question with the Answer
print(f"The input Question is: {question}{newline}" f"The Corresponding Answer is: {answer}")

2.4. Combine the retrieved documents, prompt, and question to query the LLM

Now we’re ready to begin querying our LLM with a Retrieval-Augmented Generation (RAG) pipeline. Let’s first see how this works step by step.

[ ]:
# Get the nearest neighbour for an input question
question = "Which instances can I use with Managed Spot Training in SageMaker?"

payload_nearest_neighbour = {
    "queries": [question],
    "top_k": 2,
    "mode": "nn_train_data",
    "return_text": True,
}

response = predictor_nn.predict(payload_nearest_neighbour)[0]

# We get multiple relevant contexts here. We can use these to construct a single `context` to feed into our LLM prompt.
contexts = [ans["text"] for ans in response]
[ ]:
max_section_len = 1000
separator = "\n"

from typing import List


def construct_context(contexts: List[str]) -> str:
    chosen_sections = []
    chosen_sections_len = 0

    for text in contexts:
        text = text.strip()
        # Add contexts until we run out of space.
        chosen_sections_len += len(text) + 2
        if chosen_sections_len > max_section_len:
            break
        chosen_sections.append(text)
    concatenated_doc = separator.join(chosen_sections)
    print(
        f"With a maximum context length of {max_section_len} characters, selected top {len(chosen_sections)} document sections: \n{concatenated_doc}"
    )
    return concatenated_doc
[ ]:
context_str = construct_context(contexts=contexts)
[ ]:
def create_payload(question, context_str) -> dict:
    # The question itself is sent as the user message below, so the system prompt
    # only needs the retrieved CONTEXT filled in.
    prompt_template = """Answer the following QUESTION based on the CONTEXT
    given. If you do not know the answer and the CONTEXT doesn't
    contain the answer truthfully say "I don't know".

    CONTEXT:
    {context}


    ANSWER:
    """

    text_input = prompt_template.replace("{context}", context_str)

    payload = {
        "inputs": [
            [
                {"role": "system", "content": text_input},
                {"role": "user", "content": question},
            ]
        ],
        "parameters": {
            "max_new_tokens": 256,
            "top_p": 0.9,
            "temperature": 0.6,
            "return_full_text": False,
        },
    }
    return payload
[ ]:
payload = create_payload(question, context_str)
out = predictor_llm.predict(payload, custom_attributes="accept_eula=true")
generated_text = out[0]["generation"]["content"]
print(f"[Input]: {question}\n[Output]: {generated_text}")

Let’s place all of this logic into a single RAG query function:

[ ]:
def rag_query(question: str) -> str:
    # Get nearest neighbor
    payload_nearest_neighbour = {
        "queries": [question],
        "top_k": 5,
        "mode": "nn_train_data",
        "return_text": True,
    }
    response = predictor_nn.predict(payload_nearest_neighbour)[0]
    # get contexts
    contexts = [ans["text"] for ans in response]
    # build the multiple contexts string
    context_str = construct_context(contexts=contexts)
    # create our retrieval augmented prompt
    payload = create_payload(question, context_str)
    # make prediction
    out = predictor_llm.predict(payload, custom_attributes="accept_eula=true")
    return out[0]["generation"]["content"]

We can now ask the question:

[ ]:
rag_query("Does SageMaker support spot instances?")

We can also ask questions about things that are out of context (not contained within our dataset). In this case we expect the model not to hallucinate, and instead to honestly tell us that it does not know the answer:

[ ]:
rag_query("Can I deploy a model trained outside of SageMaker?")
