# Generate fun images of your dog




  1. Set Up

  2. Fine-tune the pre-trained model on a custom dataset

## 1. Set Up

Set up credentials and create a local training directory where you will upload your training images.


[ ]:
import botocore
import sagemaker, boto3, json
from sagemaker import get_execution_role
import os


aws_role = get_execution_role()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

# If uploading to a different folder, change this variable.
local_training_dataset_folder = "training_images"
os.makedirs(local_training_dataset_folder, exist_ok=True)

Upload images of your dog to the local training_images folder and set use_local_images=True.
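Before you set use_local_images=True, it can help to confirm that the folder actually contains images. The following is a minimal sketch. It assumes your images are JPEG or PNG files, and the helper name list_training_images and its extension set are illustrative, not part of JumpStart.

```python
import os

# Assumption: training images are JPEG or PNG files; adjust if yours differ.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def list_training_images(folder):
    """Return the sorted names of image files directly inside `folder`."""
    if not os.path.isdir(folder):
        return []
    return sorted(
        name
        for name in os.listdir(folder)
        if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
    )


images = list_training_images("training_images")
print(f"Found {len(images)} image(s) in training_images")
```

Non-image files such as dataset_info.json are ignored by this check, so it only counts the files you uploaded.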


[ ]:
use_local_images = False  # If False, notebook will use the example dataset provided by JumpStart


if not use_local_images:
    # Downloading example dog images from JumpStart S3 bucket

    s3_resource = boto3.resource("s3")
    bucket = s3_resource.Bucket(f"jumpstart-cache-prod-{aws_region}")
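    for obj in bucket.objects.filter(Prefix="training-datasets/dogs_sd_finetuning/"):
        file_name = obj.key.split("/")[-1]
        if not file_name:  # skip "folder" placeholder keys that end in "/"
            continue
        # Save under the original file name inside the local training folder
        bucket.download_file(obj.key, os.path.join(local_training_dataset_folder, file_name))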
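The download loop maps each S3 object key to a file name in the local folder. The sketch below shows that mapping as a pure function. It also skips keys that end in a slash, on the assumption that the prefix may contain such "folder" placeholder objects. local_path_for_key is a hypothetical helper, and the example key name is illustrative.

```python
import os


def local_path_for_key(key, local_folder):
    """Map an S3 object key to a path inside local_folder, keeping only the file name.

    Returns None for keys ending in "/", which have no file name to save.
    """
    file_name = key.split("/")[-1]
    if not file_name:
        return None
    return os.path.join(local_folder, file_name)


print(local_path_for_key("training-datasets/dogs_sd_finetuning/dog1.jpg", "training_images"))
```

Because only the last path component is kept, two objects with the same file name under different sub-prefixes would overwrite each other locally. That is fine for a flat dataset like this one.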
[ ]:
# The instance prompt is a textual description of the images in the training dataset. Make it as detailed and accurate as possible.
# In addition to the textual description, we need a unique tag (Doppler in the example below) that identifies your dog in prompts.

instance_prompt = "A photo of a Doppler dog"
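The inference prompts later in this notebook reuse the same "Doppler dog" phrase as the instance prompt. One way to keep the tag and the class noun consistent is to build every prompt from them. The sketch below assumes this approach, and make_prompt is a hypothetical helper.

```python
TAG = "Doppler"     # unique identifier used in instance_prompt
CLASS_NOUN = "dog"  # the kind of subject shown in the training images


def make_prompt(style, context="", tag=TAG, class_noun=CLASS_NOUN):
    """Build a prompt such as 'A photo of a Doppler dog on a beach'."""
    prompt = f"{style} of a {tag} {class_noun}"
    if context:
        prompt += f" {context}"
    return prompt


print(make_prompt("A photo", "with a hat"))
```

For example, make_prompt("A photo") reproduces the instance prompt above, and make_prompt("A pencil sketch") produces one of the inference prompts used later.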
[ ]:
# Instance prompt is fed into the training script via dataset_info.json present in the training folder. Here, we write that file.
import os
import json

with open(os.path.join(local_training_dataset_folder, "dataset_info.json"), "w") as f:
    f.write(json.dumps({"instance_prompt": instance_prompt}))
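To double-check what the training script will receive, you can read the file back. The sketch below assumes only the single instance_prompt key written above, and read_instance_prompt is a hypothetical helper.

```python
import json
import os


def read_instance_prompt(folder):
    """Load dataset_info.json from `folder` and return its instance_prompt."""
    with open(os.path.join(folder, "dataset_info.json")) as f:
        info = json.load(f)
    prompt = info.get("instance_prompt")
    if not isinstance(prompt, str) or not prompt.strip():
        raise ValueError("dataset_info.json has no non-empty 'instance_prompt'")
    return prompt
```

In the notebook, read_instance_prompt(local_training_dataset_folder) should return the same string as instance_prompt.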

### Upload the dataset to S3

[ ]:
mySession = boto3.session.Session()
AwsRegion = mySession.region_name
account_id = boto3.client("sts").get_caller_identity().get("Account")

training_bucket = f"stable-diffusion-jumpstart-{AwsRegion}-{account_id}"

If you would like to use an existing bucket, set training_bucket to its name where it is defined above, and skip the cell that calls create_bucket_if_not_exists.
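Whether you use the generated name or your own bucket, training_bucket must satisfy the S3 bucket naming rules. The sketch below checks only the core rules for general purpose buckets: 3 to 63 characters; lowercase letters, digits, dots, and hyphens; a letter or digit at both ends; no adjacent dots; and no IP-address format. It does not check reserved prefixes or suffixes, so the S3 documentation remains authoritative. is_valid_bucket_name is a hypothetical helper.

```python
import re

# Core S3 naming rules only; reserved prefixes and suffixes are NOT checked here.
_BUCKET_PATTERN = re.compile(r"[a-z0-9][a-z0-9.-]{1,61}[a-z0-9]")
_IP_ADDRESS_PATTERN = re.compile(r"\d{1,3}(\.\d{1,3}){3}")


def is_valid_bucket_name(name):
    """Return True if `name` passes the core S3 bucket naming rules."""
    return (
        _BUCKET_PATTERN.fullmatch(name) is not None
        and ".." not in name
        and _IP_ADDRESS_PATTERN.fullmatch(name) is None
    )


# Placeholder account ID; the notebook inserts your real one.
print(is_valid_bucket_name("stable-diffusion-jumpstart-us-west-2-123456789012"))
```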


[ ]:
assets_bucket = f"jumpstart-cache-prod-{AwsRegion}"

# Download a helper module from the JumpStart assets bucket
s3 = boto3.client("s3")
s3.download_file(
    assets_bucket,
    "ai_services_assets/custom_labels/cl_jumpstart_ic_notebook_utils.py",
    "utils.py",
)


from utils import create_bucket_if_not_exists

create_bucket_if_not_exists(training_bucket)

Next, upload the training dataset (the images and dataset_info.json) to the S3 bucket.


[ ]:
train_s3_path = f"s3://{training_bucket}/custom_dog_stable_diffusion_dataset/"

!aws s3 cp --recursive $local_training_dataset_folder $train_s3_path
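train_s3_path combines the bucket name and a key prefix into one S3 URI. If you later want to list or clean up the uploaded objects with an API that takes the bucket and prefix separately, you can split the URI again. The sketch below does this, and split_s3_uri is a hypothetical helper.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/prefix/' into ('bucket', 'prefix/').

    urlparse treats '?' and '#' specially, so this sketch assumes keys
    without those characters.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


bucket_name, key_prefix = split_s3_uri("s3://example-bucket/custom_dog_stable_diffusion_dataset/")
print(bucket_name, key_prefix)
```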

## 2. Fine-tune the pre-trained model on a custom dataset

### 2.1. Retrieve Training Artifacts

[ ]:
from sagemaker import image_uris, model_uris, script_uris

train_model_id, train_model_version, train_scope = (
    "model-txt2img-stabilityai-stable-diffusion-v2-1-base",
    "1.*",
    "training",
)

# Tested with ml.g4dn.2xlarge (16GB GPU memory) and ml.g5.2xlarge (24GB GPU memory) instances. Other instances may work as well.
# If the ml.g5.2xlarge instance type is available in your account and region, switch to it to speed up training.
training_instance_type = "ml.g4dn.2xlarge"

# Retrieve the docker image
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    model_id=train_model_id,
    model_version=train_model_version,
    image_scope=train_scope,
    instance_type=training_instance_type,
)

# Retrieve the training script. This contains all the necessary files including data processing, model training etc.
train_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope=train_scope
)
# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = model_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, model_scope=train_scope
)
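Setting train_model_version = "1.*" requests a version that matches the wildcard instead of a pinned release. The sketch below illustrates the idea of picking the highest version that matches such a pattern. It is not the SDK's actual resolution logic, and the version list and the pick_version helper are hypothetical.

```python
from fnmatch import fnmatch


def pick_version(available, spec):
    """Return the highest version in `available` that matches the wildcard `spec`.

    Assumes purely numeric, dot-separated versions such as "1.10.0".
    """
    matches = [v for v in available if fnmatch(v, spec)]
    if not matches:
        raise ValueError(f"No version matches {spec!r}")
    # Compare numerically so that 1.10.0 sorts after 1.9.0.
    return max(matches, key=lambda v: tuple(int(part) for part in v.split(".")))


print(pick_version(["1.0.0", "1.9.0", "1.10.0", "2.0.0"], "1.*"))
```

A wildcard picks up compatible updates automatically. Pinning an exact version instead makes reruns reproducible.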

### 2.2. Set Training Parameters

[ ]:
output_bucket = sess.default_bucket()
output_prefix = "jumpstart-example-sd-training"

s3_output_location = f"s3://{output_bucket}/{output_prefix}/output"
[ ]:
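# Import the function directly so that the hyperparameters dictionary below does not
# shadow the sagemaker.hyperparameters module (which would break re-running this cell).
from sagemaker.hyperparameters import retrieve_default

# Retrieve the default hyperparameters for fine-tuning the model
hyperparameters = retrieve_default(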
    model_id=train_model_id, model_version=train_model_version
)

# [Optional] Override the default hyperparameters with custom values. These control the duration of training and the quality of the output.
# If max_steps is too small, training will be fast, but the model will not be able to generate custom images for your use case.
# If max_steps is too large, training will be very slow.
hyperparameters["max_steps"] = "200"
print(hyperparameters)
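The cell above overrides a single value, and it uses the string "200" rather than the integer 200. If you override several values, a small guard against misspelled keys can save you from a wasted training job. The sketch below is a minimal version of such a guard. apply_overrides is hypothetical, and the defaults dictionary contains placeholder values, not the model's real defaults.

```python
def apply_overrides(defaults, overrides):
    """Return a copy of `defaults` with `overrides` applied, rejecting unknown keys."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise KeyError(f"Unknown hyperparameter(s): {sorted(unknown)}")
    merged = dict(defaults)
    # SageMaker training jobs receive hyperparameters as strings, so convert here.
    merged.update({key: str(value) for key, value in overrides.items()})
    return merged


# Placeholder defaults for illustration only; use the retrieved dictionary in practice.
example_defaults = {"max_steps": "400", "learning_rate": "1e-06"}
print(apply_overrides(example_defaults, {"max_steps": 200}))
```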

### 2.3. Start Training

We start by creating the estimator object with all the required assets and then launch the training job. It takes less than 10 minutes on the default dataset.


[ ]:
%%time
from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base

training_job_name = name_from_base(f"jumpstart-example-{train_model_id}-transfer-learning")

# Create SageMaker Estimator instance
sd_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",  # Entry-point file in source_dir and present in train_source_uri.
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    base_job_name=training_job_name,
)

# Launch a SageMaker Training job by passing s3 path of the training data
sd_estimator.fit({"training": train_s3_path}, logs=True)
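The training job name is built from a long model ID, but SageMaker limits training job names to 63 characters. name_from_base handles this by trimming the base name before it appends a timestamp. The sketch below shows the same idea. It is not the SDK's implementation, and make_job_name and its timestamp format are assumptions.

```python
import re
from datetime import datetime, timezone

MAX_JOB_NAME_LENGTH = 63  # SageMaker limit for training job names


def make_job_name(base, now=None):
    """Trim `base` so that '<base>-<timestamp>' fits within 63 characters."""
    now = now or datetime.now(timezone.utc)
    timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")
    # Replace characters not allowed in job names, trim, and drop trailing hyphens.
    cleaned = re.sub(r"[^a-zA-Z0-9-]", "-", base)
    trimmed = cleaned[: MAX_JOB_NAME_LENGTH - len(timestamp) - 1].rstrip("-")
    return f"{trimmed}-{timestamp}"


print(make_job_name(
    "jumpstart-example-model-txt2img-stabilityai-stable-diffusion-v2-1-base-transfer-learning"
))
```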

### 2.4. Deploy and run inference on the fine-tuned model


A trained model does nothing on its own, so we now use it to run inference. For this model, inference means generating images from text prompts. We start by retrieving the JumpStart inference artifacts (the container image and the inference scripts) and then deploy the fine-tuned sd_estimator to a SageMaker endpoint.


[ ]:
%%time

inference_instance_type = "ml.g4dn.2xlarge"

# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=train_model_id,
    model_version=train_model_version,
    instance_type=inference_instance_type,
)
# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope="inference"
)

endpoint_name = name_from_base(f"jumpstart-example-FT-{train_model_id}-")

# Use the estimator from the previous step to deploy to a SageMaker endpoint
finetuned_predictor = sd_estimator.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    endpoint_name=endpoint_name,
)
[ ]:
import matplotlib.pyplot as plt
import numpy as np


def query(model_predictor, text):
    """Query the model predictor."""

    encoded_text = json.dumps(text).encode("utf-8")

    query_response = model_predictor.predict(
        encoded_text,
        {
            "ContentType": "application/x-text",
            "Accept": "application/json",
        },
    )
    return query_response


def parse_response(query_response):
    """Parse response and return generated image and the prompt"""

    response_dict = json.loads(query_response)
    return response_dict["generated_image"], response_dict["prompt"]


def display_img_and_prompt(img, prmpt):
    """Display the generated image with its prompt as the title."""
    plt.figure(figsize=(12, 12))
    plt.imshow(np.array(img))
    plt.axis("off")
    plt.title(prmpt)
    plt.show()
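You can exercise the request and response format without a live endpoint by passing in a stand-in object that has a predict method. The sketch below assumes the response is JSON with a generated_image key (nested lists of RGB pixel values) and a prompt key, which is what parse_response and display_img_and_prompt expect. FakePredictor and image_shape are hypothetical helpers. The encoding and decoding steps mirror query() and parse_response() above so that this block runs on its own.

```python
import json


class FakePredictor:
    """Hypothetical stand-in for a SageMaker predictor, for offline testing only."""

    def predict(self, data, initial_args=None):
        prompt = json.loads(data.decode("utf-8"))
        # A 2x2 black RGB image stands in for the real generated image.
        image = [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]]
        return json.dumps({"generated_image": image, "prompt": prompt}).encode("utf-8")


def image_shape(img):
    """Return (height, width, channels) of an image stored as nested lists."""
    return len(img), len(img[0]), len(img[0][0])


# Same encoding as query() and the same decoding as parse_response().
request = json.dumps("A photo of a Doppler dog on a beach").encode("utf-8")
response = FakePredictor().predict(
    request, {"ContentType": "application/x-text", "Accept": "application/json"}
)
result = json.loads(response)
print(image_shape(result["generated_image"]), result["prompt"])
```

In the notebook itself, calling parse_response(query(FakePredictor(), prompt)) runs the same code path as the real endpoint call.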
[ ]:
all_prompts = [
    "A photo of a Doppler dog on a beach",
    "A pencil sketch of a Doppler dog",
    "A photo of a Doppler dog with a hat",
]
for prompt in all_prompts:
    query_response = query(finetuned_predictor, prompt)
    img, _ = parse_response(query_response)
    display_img_and_prompt(img, prompt)
[ ]:
# Delete the SageMaker endpoint
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()

## Conclusion

In this notebook, we walked through a simple workflow for fine-tuning the Stable Diffusion text-to-image model on a small set of your own images. You can adapt the notebook to your own dataset by uploading images of the desired subject and changing the prompts. For instance, to generate images of your cat, upload cat images in the first step, and change dog to cat in the instance_prompt before training and in the prompts you send to the endpoint with the fine-tuned model.
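Switching subjects mostly means swapping the class noun wherever it follows the tag. The sketch below assumes your prompts follow the "<tag> <class noun>" pattern used in this notebook, and retarget_prompt is a hypothetical helper.

```python
def retarget_prompt(prompt, tag, old_noun, new_noun):
    """Replace '<tag> <old_noun>' with '<tag> <new_noun>' in a prompt."""
    return prompt.replace(f"{tag} {old_noun}", f"{tag} {new_noun}")


dog_prompts = [
    "A photo of a Doppler dog",
    "A photo of a Doppler dog on a beach",
    "A pencil sketch of a Doppler dog",
]
cat_prompts = [retarget_prompt(p, "Doppler", "dog", "cat") for p in dog_prompts]
print(cat_prompts)
```

Matching on the tag and noun together leaves any other occurrence of the word (for example "hot dog" in a scene description) unchanged.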

This notebook contains bare-bones code to train and deploy the Stable Diffusion model. Refer to Introduction to JumpStart - Text to Image for additional features, such as (i) how to deploy a pre-trained Stable Diffusion model (more than 80 are available in JumpStart), (ii) how to set inference parameters such as num_steps and guidance scale, (iii) prompt engineering, and (iv) how to set training-related parameters.

