Fine-tuning and deploying the Mixtral 8x7B LLM in SageMaker with Hugging Face, using QLoRA Parameter-Efficient Fine-Tuning




The Mixtral 8x7B Large Language Model by Mistral AI has 46.7 billion parameters, but uses only 12.9 billion per token, thanks to its Mixture of Experts architecture. The model supports five languages (French, Spanish, Italian, English and German) and, according to Mistral AI, outperforms the much larger Llama 2 70B model from Meta on most benchmarks. An instruct version of the model, trained to follow instructions, is also available.
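
The two parameter counts above can be reproduced with a little arithmetic. The sketch below assumes the architecture values from the published Mixtral-8x7B-v0.1 configuration (hidden size, expert size, layer count and so on); none of these values are stated in this notebook, so treat them as assumptions.

```python
# Architecture values assumed from the published Mixtral-8x7B-v0.1 config.json;
# they are not stated in this notebook.
hidden = 4096          # hidden_size
ffn = 14336            # intermediate_size of each expert MLP
layers = 32            # num_hidden_layers
vocab = 32000          # vocab_size
n_experts = 8          # num_local_experts
experts_per_token = 2  # num_experts_per_tok
kv_dim = 8 * 128       # num_key_value_heads * head_dim (grouped-query attention)

attention = 2 * hidden * hidden + 2 * hidden * kv_dim  # q/o plus k/v projections
expert = 3 * hidden * ffn                              # gate, up and down projections
router = hidden * n_experts                            # linear gate that picks the experts
norms = 2 * hidden                                     # two RMSNorm weights per layer
shared = 2 * vocab * hidden + hidden                   # embeddings, LM head, final norm


def count(experts_used):
    """Parameter count when `experts_used` experts per layer are included."""
    per_layer = attention + experts_used * expert + router + norms
    return layers * per_layer + shared


total = count(n_experts)            # every expert is stored
active = count(experts_per_token)   # only two experts run for each token
print(f"total:  {total / 1e9:.1f}B")   # total:  46.7B
print(f"active: {active / 1e9:.1f}B")  # active: 12.9B
```

The attention layers, embeddings and router are shared by all tokens; only the expert MLPs are selected per token, which is why the active count is well below the total.
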
QLoRA is a parameter-efficient fine-tuning technique that fine-tunes LLMs in less memory. Instead of changing the pretrained weights, it trains a small set of additional weights on top of them. This not only leads to good performance, but also mitigates the risk of catastrophic forgetting that comes with regular full fine-tuning. QLoRA:
  1. Freezes model weights, and quantizes the pretrained model to 4 bits.

  2. Attaches additional trainable adapter layers.

  3. Fine-tunes only these layers. The frozen, quantized model is left unchanged, but it is still used in the forward and backward passes.
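
The three steps above can be sketched on a single toy linear layer. This is a dependency-free illustration, not the implementation used by the training script: the 4-bit step uses plain absmax rounding instead of the NF4 data type that QLoRA uses, and the dimensions, rank and alpha values are made up for the example.

```python
import random

random.seed(0)
d_in, d_out, rank, alpha = 8, 6, 2, 4  # toy sizes, chosen for illustration only


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


# Step 1: freeze the pretrained weights and store them in 4 bits.
# Simple absmax rounding to integers in [-7, 7]; real QLoRA uses NF4 instead.
W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
scales = [max(abs(w) for w in row) / 7 for row in W]
W_q = [[round(w / s) for w in row] for row, s in zip(W, scales)]


def base_forward(x):
    # Dequantize on the fly; W_q and scales are never updated.
    W_deq = [[q * s for q in row] for row, s in zip(W_q, scales)]
    return matvec(W_deq, x)


# Step 2: attach a low-rank adapter. B starts at zero, so the adapter
# contributes nothing until it has been trained.
A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(rank)]
B = [[0.0] * rank for _ in range(d_out)]


def forward(x):
    # Step 3: during training, only A and B would receive gradient updates.
    delta = matvec(B, matvec(A, x))
    return [y + (alpha / rank) * d for y, d in zip(base_forward(x), delta)]


x = [random.gauss(0, 1) for _ in range(d_in)]
trainable = rank * (d_in + d_out)
frozen = d_in * d_out
print(f"trainable adapter parameters: {trainable}, frozen base parameters: {frozen}")
```

At these toy sizes the adapter is not much smaller than the base layer. With real hidden sizes in the thousands and a small rank, the adapter is a tiny fraction of the frozen weights.
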

In this notebook, you will learn how to fine-tune the Mixtral 8x7B model using Hugging Face on Amazon SageMaker. You’ll use the Hugging Face Transformers framework and the Hugging Face extension to the SageMaker Python SDK to fine-tune Mixtral with QLoRA on an example instruction dataset, and run the tuned model in a Hugging Face Deep Learning Container (DLC) on a SageMaker real-time inference endpoint. You can run this notebook from an Amazon SageMaker Studio notebook or a SageMaker notebook instance, or outside SageMaker (for example, on your laptop or development machine). Outside SageMaker, you’ll need to handle authentication to SageMaker and the other AWS services used in the notebook yourself. On SageMaker, this is handled for you.

Files

scripts/run_clm.py: The entry point script that is passed to the Hugging Face estimator later in this notebook when launching the QLoRA fine-tuning job.
scripts/requirements.txt: Installs dependencies for the fine-tuning job, such as Hugging Face Transformers and the PEFT library.

Prerequisites

You need an S3 bucket to store the input data for training. This bucket must be located in the same AWS Region in which you launch your training job. To learn how to create an S3 bucket, see Create your first S3 bucket in the Amazon S3 documentation. Alternatively, you can use the default bucket of the SageMaker session you create, without specifying a bucket name.

Launching Environment

Amazon SageMaker Notebook

You can run the notebook on an Amazon SageMaker Studio notebook or a SageMaker notebook instance without manually setting your AWS credentials.

Create a new SageMaker notebook instance and open it. Zip the contents of this folder and upload the archive to the instance with the Upload button at the top right. Open a new terminal with New -> Terminal. In the terminal, change to the correct directory and unzip the file: cd SageMaker && unzip .zip

Locally

You can run the notebook locally by launching a Jupyter notebook server with jupyter notebook. This requires you to set your AWS credentials in the environment manually. See Configure the AWS CLI for more details.

Amazon SageMaker Initialization

Run the following cell to upgrade the SageMaker SDK, the Transformers framework, and the datasets and s3fs libraries to their latest versions. You may need to restart the notebook kernel for the changes to take effect.

[ ]:
%pip install --quiet --upgrade transformers datasets sagemaker s3fs

Import the SageMaker modules and retrieve information about your current SageMaker work environment, such as the AWS Region and the ARN of your Amazon SageMaker execution role.

[ ]:
import sagemaker
import boto3

sess = sagemaker.Session()

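# Use the notebook's execution role on SageMaker; otherwise look the role up by name
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    # outside SageMaker: replace RoleName with a SageMaker execution role from your own account
    role = iam.get_role(RoleName="AmazonSageMaker-ExecutionRole-20231209T154667")["Role"]["Arn"]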

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

Here we load the Dolly-15k dataset, a high-quality set of human-generated prompt/response pairs that is well suited to instruction fine-tuning LLMs like Mixtral.

[ ]:
from datasets import load_dataset
from random import randrange

# Load dataset from the hub
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

print(f"dataset size: {len(dataset)}")
print(dataset[randrange(len(dataset))])

Next, define a formatting function that converts our data into task prompts. The function takes a sample of the dataset and returns a prompt string.

[ ]:
def format_dolly(sample):
    instruction = f"### Instruction\n{sample['instruction']}"
    context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    response = f"### Answer\n{sample['response']}"
    # join all the parts together
    prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
    return prompt
[ ]:
from random import randrange

print(format_dolly(dataset[randrange(len(dataset))]))

Now, we load the tokenizer from the pre-trained Mixtral model, add an EOS token to each sample, tokenize the data and pack it in chunks of 2048 tokens.

[ ]:
from transformers import AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"  # sharded weights
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
[ ]:
from random import randint
from itertools import chain
from functools import partial


# template dataset to add prompt to each sample
def template_dataset(sample):
    sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
    return sample


# apply prompt template per sample
dataset = dataset.map(template_dataset, remove_columns=list(dataset.features))
# print a random sample (randint includes both endpoints, so the upper bound is len - 1)
print(dataset[randint(0, len(dataset) - 1)]["text"])

# empty lists to save the remainder of each batch for use in the next batch
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}


def chunk(sample, chunk_length=2048):
    # define global remainder variable to save remainder from batches to use in next batch
    global remainder
    # Concatenate all texts and add remainder from previous batch
    concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
    concatenated_examples = {
        k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()
    }
    # get total number of tokens for batch
    batch_total_length = len(concatenated_examples[list(sample.keys())[0]])

    # largest multiple of chunk_length that fits in the batch
    # (0 if the batch is shorter than one chunk, so everything is carried over)
    batch_chunk_length = (batch_total_length // chunk_length) * chunk_length

    # Split by chunks of max_len.
    result = {
        k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
        for k, t in concatenated_examples.items()
    }
    # add remainder to global variable for next batch
    remainder = {
        k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()
    }
    # prepare labels
    result["labels"] = result["input_ids"].copy()
    return result


# tokenize and chunk dataset
lm_dataset = dataset.map(
    lambda sample: tokenizer(sample["text"]),
    batched=True,
    remove_columns=list(dataset.features),
).map(
    partial(chunk, chunk_length=2048),
    batched=True,
)

# Print total number of samples
print(f"Total number of samples: {len(lm_dataset)}")
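
To make the packing step easier to follow, here is a small stand-alone sketch of the same idea with a chunk length of 4 instead of 2048. The pack function and the token ids are hypothetical and exist only for illustration; it keeps the remainder in a local variable instead of a global one.

```python
def pack(batches, chunk_length):
    """Yield fixed-length chunks from batches of token-id lists."""
    remainder = []
    for batch in batches:
        # Prepend the leftover from the previous batch, then flatten this batch.
        tokens = remainder + [tok for sample in batch for tok in sample]
        usable = (len(tokens) // chunk_length) * chunk_length
        for start in range(0, usable, chunk_length):
            yield tokens[start : start + chunk_length]
        remainder = tokens[usable:]
    # Tokens still in `remainder` here are dropped, as in the cell above.


batches = [
    [[1, 2, 3], [4, 5]],       # 5 tokens: one chunk of 4, one token carried over
    [[6, 7], [8, 9, 10, 11]],  # 1 + 6 tokens: one chunk of 4, three left over
]
chunks = list(pack(batches, chunk_length=4))
print(chunks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

Note that chunks can span sample boundaries, which is why each sample ends with an EOS token before packing.
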

Save our processed data to S3 for use in the training job.

[ ]:
import s3fs  # not called directly; it must be installed so that datasets can write to s3:// paths

# save train_dataset to s3
training_input_path = f"s3://{sess.default_bucket()}/processed/mixtral/dolly/train"
lm_dataset.save_to_disk(training_input_path)

print("uploaded data to:")
print(f"training dataset to: {training_input_path}")

run_clm.py is the entry point script for the training job. It implements QLoRA using PEFT to train our model. It merges the fine-tuned LoRA weights into the model weights after training, so you can use the resulting model as normal. Make sure that requirements.txt is in your source_dir folder, so that SageMaker installs the needed libraries, including peft (which provides the LoRA API) and bitsandbytes (which quantizes the pre-trained model for the QLoRA training job).

We use a single g5.24xlarge instance (with four 24 GB A10G GPUs) for the training job. The quantization that QLoRA provides reduces the memory requirements of the job so that it fits on that instance and doesn’t need an instance type with 8 GPUs. Training for 3 epochs took about 9.5 hours in my case. If you’re in a hurry and just want a proof of concept, reduce the number of epochs.
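
A back-of-the-envelope check of the memory claim above, counting weights only. It ignores quantization constants, the LoRA adapters, activations and optimizer state, so the numbers are lower bounds rather than measurements.

```python
params = 46.7e9  # total Mixtral 8x7B parameters, as stated in the introduction

bf16_gb = params * 2 / 1e9    # 16-bit weights: 2 bytes per parameter
int4_gb = params * 0.5 / 1e9  # 4-bit weights: half a byte per parameter
gpu_gb = 4 * 24               # g5.24xlarge: four A10G GPUs with 24 GB each

print(f"16-bit weights: {bf16_gb:.1f} GB")
print(f"4-bit weights:  {int4_gb:.1f} GB")
print(f"GPU memory:     {gpu_gb} GB")
```

In 16 bits, the weights alone would take up nearly all of the GPU memory on the instance, leaving almost nothing for activations and gradients. In 4 bits, they take roughly a quarter of it.
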

These large GPU instances aren’t available in every AWS Region, so make sure that you’re in a Region that offers g5.24xlarge instances, and that your AWS account has enough quota to launch one more.

[ ]:
import time
from sagemaker.huggingface import HuggingFace

# define Training Job Name
job_name = f'mixtral-8x7b-qlora-{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())}'

# hyperparameters, which are passed into the training job
hyperparameters = {
    "model_id": model_id,  # pre-trained model
    "dataset_path": "/opt/ml/input/data/training",  # path where SageMaker will save the training dataset
    "epochs": 3,  # number of training epochs
    "per_device_train_batch_size": 2,  # batch size for training
    "lr": 2e-4,  # learning rate used during training
    "merge_weights": True,  # whether to merge LoRA into the model (needs more memory)
}

# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point="run_clm.py",  # train script
    source_dir="scripts",  # directory which includes the entrypoint script and the requirements.txt for our training environment
    instance_type="ml.g5.24xlarge",  # instances type used for the training job
    instance_count=1,  # the number of instances used for training
    base_job_name=job_name,  # prefix for the name of the training job
    role=role,  # IAM role used in the training job to access AWS resources, e.g. S3
    volume_size=300,  # the size of the EBS volume in GB
    transformers_version="4.28",  # the transformers version used in the training job
    pytorch_version="2.0",  # the PyTorch version used in the training job
    py_version="py310",  # the python version used in the training job
    hyperparameters=hyperparameters,  # the hyperparameters passed to the training job
    environment={
        "HUGGINGFACE_HUB_CACHE": "/tmp/.cache"
    },  # set env variable to cache models in /tmp
)
[ ]:
# define a data input dictionary with our uploaded S3 URIs
data = {"training": training_input_path}

# starting the train job with our uploaded datasets as input
huggingface_estimator.fit(data, wait=True)
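
SageMaker passes the hyperparameters dictionary to the entry point script as command-line arguments, with every value converted to a string. The sketch below shows one way a script could read them with argparse. run_clm.py itself is not shown in this notebook, so this parser is an assumption about its interface and not a copy of it; the str_to_bool helper is hypothetical.

```python
import argparse


def str_to_bool(value):
    # Hyperparameters arrive as strings, so "True"/"False" need explicit parsing.
    return str(value).lower() in ("true", "1", "yes")


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str)
    parser.add_argument("--dataset_path", type=str)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--merge_weights", type=str_to_bool, default=True)
    # parse_known_args ignores any extra arguments the platform may add.
    args, _ = parser.parse_known_args(argv)
    return args


# Simulate the argument list built from the hyperparameters dictionary above.
args = parse_args([
    "--model_id", "mistralai/Mixtral-8x7B-v0.1",
    "--dataset_path", "/opt/ml/input/data/training",
    "--epochs", "3",
    "--per_device_train_batch_size", "2",
    "--lr", "0.0002",
    "--merge_weights", "True",
])
print(args.epochs, args.lr, args.merge_weights)  # 3 0.0002 True
```

Using a plain type=bool for merge_weights would not work here, because bool("False") is True in Python.
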

Load the Hugging Face LLM inference container that will run the model as a real-time SageMaker inference endpoint.

[ ]:
# As of January 2024, TGI 1.3.3 is not yet available through the SDK.
# Once it is, use the call below instead of the static image mapping.
# from sagemaker.huggingface import get_huggingface_llm_image_uri

## retrieve the llm image uri
# llm_image = get_huggingface_llm_image_uri(
#  "huggingface",
#  version = "1.3.3"
# )

region_mapping = {
    "af-south-1": "626614931356",
    "il-central-1": "780543022126",
    "ap-east-1": "871362719292",
    "ap-northeast-1": "763104351884",
    "ap-northeast-2": "763104351884",
    "ap-northeast-3": "364406365360",
    "ap-south-1": "763104351884",
    "ap-south-2": "772153158452",
    "ap-southeast-1": "763104351884",
    "ap-southeast-2": "763104351884",
    "ap-southeast-3": "907027046896",
    "ap-southeast-4": "457447274322",
    "ca-central-1": "763104351884",
    "cn-north-1": "727897471807",
    "cn-northwest-1": "727897471807",
    "eu-central-1": "763104351884",
    "eu-central-2": "380420809688",
    "eu-north-1": "763104351884",
    "eu-west-1": "763104351884",
    "eu-west-2": "763104351884",
    "eu-west-3": "763104351884",
    "eu-south-1": "692866216735",
    "eu-south-2": "503227376785",
    "me-south-1": "217643126080",
    "me-central-1": "914824155844",
    "sa-east-1": "763104351884",
    "us-east-1": "763104351884",
    "us-east-2": "763104351884",
    "us-gov-east-1": "446045086412",
    "us-gov-west-1": "442386744353",
    "us-iso-east-1": "886529160074",
    "us-isob-east-1": "094389454867",
    "us-west-1": "763104351884",
    "us-west-2": "763104351884",
}

llm_image = f"{region_mapping[sess.boto_region_name]}.dkr.ecr.{sess.boto_region_name}.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.3.3-gpu-py310-cu121-ubuntu20.04-v1.0"

print(f"llm image uri: {llm_image}")

Now take the instruction-tuned model from S3 and deploy it. Make sure that you’re in an AWS Region that offers g5.48xlarge instances, and that your AWS account has enough quota to launch one more.

[ ]:
s3_uri = huggingface_estimator.model_data
print(s3_uri)
[ ]:
import json
from sagemaker.huggingface import HuggingFaceModel

# sagemaker config
instance_type = "ml.g5.48xlarge"
number_of_gpu = 8
health_check_timeout = 300  # seconds

# Define Model and Endpoint configuration parameters
config = {
    "HF_MODEL_ID": "/opt/ml/model",  # load the model from the path where SageMaker extracts model_data
    "SM_NUM_GPUS": json.dumps(number_of_gpu),  # Number of GPUs used per replica
    "MAX_INPUT_LENGTH": json.dumps(24000),  # Max number of input tokens per request
    "MAX_BATCH_PREFILL_TOKENS": json.dumps(32000),  # Max number of tokens for the prefill operation
    "MAX_TOTAL_TOKENS": json.dumps(32000),  # Max number of tokens per request (input plus generated)
    "MAX_BATCH_TOTAL_TOKENS": json.dumps(
        512000
    ),  # Limits the number of tokens that can be processed in parallel during generation
}

# create HuggingFaceModel with the image uri
llm_model = HuggingFaceModel(model_data=s3_uri, role=role, image_uri=llm_image, env=config)
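
The four token limits in the configuration depend on each other. The sketch below checks the relationships as I understand them from the text-generation-inference launcher options; the exact rules may differ between container versions, so treat the checks as assumptions. The check_tgi_limits helper is hypothetical and not part of any SDK.

```python
def check_tgi_limits(env):
    """Return a list of problems found in the token-limit settings."""
    max_input = int(env["MAX_INPUT_LENGTH"])
    max_total = int(env["MAX_TOTAL_TOKENS"])
    prefill = int(env["MAX_BATCH_PREFILL_TOKENS"])
    batch_total = int(env["MAX_BATCH_TOTAL_TOKENS"])

    problems = []
    if max_input >= max_total:
        problems.append("MAX_INPUT_LENGTH must be smaller than MAX_TOTAL_TOKENS")
    if prefill < max_input:
        problems.append("MAX_BATCH_PREFILL_TOKENS must be at least MAX_INPUT_LENGTH")
    if batch_total < max_total:
        problems.append("MAX_BATCH_TOTAL_TOKENS must be at least MAX_TOTAL_TOKENS")
    return problems


# Same values as the config above; json.dumps turns each number into a string.
example_env = {
    "MAX_INPUT_LENGTH": "24000",
    "MAX_BATCH_PREFILL_TOKENS": "32000",
    "MAX_TOTAL_TOKENS": "32000",
    "MAX_BATCH_TOTAL_TOKENS": "512000",
}
print(check_tgi_limits(example_env))  # []

# Tokens left for generation when a request uses the full input length.
print(int(example_env["MAX_TOTAL_TOKENS"]) - int(example_env["MAX_INPUT_LENGTH"]))  # 8000
```

Running a check like this before deploying is cheaper than waiting for the endpoint to fail its health check.
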
[ ]:
# Deploy model to an endpoint
# https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy

endpoint_name = sagemaker.utils.name_from_base("Mixtral-8x7B")

llm = llm_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout,  # seconds to wait for the container to load the model
)

Let’s send a prompt! The resulting completion is well-aligned to instructions, quite accurate and concise.

[ ]:
# Prompt to generate
prompt = "What is Amazon SageMaker?"

# Generation arguments
payload = {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.1,
    "top_k": 50,
    "max_new_tokens": 1024,
    "repetition_penalty": 1.03,
    "return_full_text": False,
    "stop": ["</s>"],
}
[ ]:
chat = llm.predict({"inputs": prompt, "parameters": payload})

print(chat[0]["generated_text"])
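
The model was fine-tuned on prompts in the "### Instruction / ### Context / ### Answer" layout produced by format_dolly. A request that follows the same layout, with the answer left open, may match the training distribution more closely than a bare question. This is an assumption that the notebook does not verify, and the build_inference_prompt helper below is hypothetical.

```python
def build_inference_prompt(instruction, context=""):
    """Format a request like the training samples, leaving the answer open."""
    parts = [f"### Instruction\n{instruction}"]
    if context:
        parts.append(f"### Context\n{context}")
    parts.append("### Answer\n")
    return "\n\n".join(parts)


prompt = build_inference_prompt("What is Amazon SageMaker?")
print(prompt)
```

The resulting string can be sent in place of the plain prompt, for example with llm.predict({"inputs": prompt, "parameters": payload}).
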

Finally, clean up: delete the SageMaker model and endpoint so that the endpoint instance stops incurring charges.

[ ]:
llm.delete_model()
llm.delete_endpoint()
