Introduction to JumpStart Image editing - Stable Diffusion Inpainting




  1. Set Up

  2. Select a model

  3. Retrieve JumpStart Artifacts & Deploy an Endpoint

  4. Query endpoint and parse response

  5. Clean up the endpoint

Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and on an Amazon SageMaker Notebook instance with the conda_python3 kernel.

Note: After you finish running the notebook, delete all the resources you created so that billing stops. The code in Clean up the endpoint deletes the model and the endpoint that are created.

1. Set Up

[ ]:
!pip install ipywidgets==7.0.0 --quiet
!pip install --upgrade sagemaker

Permissions and environment variables

[ ]:
import sagemaker, boto3, json
from sagemaker import get_execution_role

aws_role = get_execution_role()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

2. Select a model


You can continue with the default model or choose a different model from the dropdown generated by running the cells below. A complete list of SageMaker pre-trained models is available at SageMaker pre-trained Models.

[ ]:
model_id, model_version = "model-inpainting-stabilityai-stable-diffusion-2-inpainting-fp16", "1.*"

[Optional] Here, we filter for all the inpainting models and select a model for inference.

[ ]:
from ipywidgets import Dropdown
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

# Retrieve all inpainting models.
filter_value = "task == inpainting"
inpainting_models = list_jumpstart_models(filter=filter_value)

# Display the model ids in a dropdown to select a model for inference.
model_dropdown = Dropdown(
    options=inpainting_models,
    value=model_id,
    description="Select a model",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)

Choose a model for inference

[ ]:
display(model_dropdown)
[ ]:
# model_version="*" fetches the latest version of the model; "1.*" fetches the latest 1.x version.
model_id, model_version = model_dropdown.value, "1.*"
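
The version string "1.*" is a wildcard: it asks for the newest available release whose major version is 1. The sketch below illustrates that idea on a made-up list of version strings. The helper resolve_version and the version list are hypothetical, and this is not the SageMaker SDK's actual resolution logic.

```python
from fnmatch import fnmatch


def resolve_version(pattern, available):
    """Return the highest version in `available` that matches `pattern`.

    Hypothetical helper for illustration only; the SageMaker SDK performs
    its own resolution internally.
    """
    matches = [v for v in available if fnmatch(v, pattern)]
    if not matches:
        raise ValueError(f"no version matches {pattern!r}")
    # Compare numerically so that 1.10.0 sorts above 1.9.0.
    return max(matches, key=lambda v: tuple(int(p) for p in v.split(".")))


# Made-up version list, not taken from any real model.
available_versions = ["1.0.0", "1.9.0", "1.10.0", "2.0.0"]
print(resolve_version("1.*", available_versions))  # 1.10.0
print(resolve_version("*", available_versions))  # 2.0.0
```

Pinning the major version this way keeps the notebook on releases that are expected to stay compatible, while "*" would also pick up a new major version.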

3. Retrieve JumpStart Artifacts & Deploy an Endpoint


Using SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the deploy_image_uri and model_uri for the pre-trained model. To host the pre-trained model, we create an instance of sagemaker.model.Model (https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take up to ten minutes for the default model_id.


[ ]:
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base


endpoint_name = name_from_base(f"jumpstart-example-{model_id}")

# Instances with more GPU memory support generation of larger images.
# Select an instance type such as ml.g5.2xlarge if you want to generate a very large image.
inference_instance_type = "ml.p3.2xlarge"

# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)

# Retrieve the model uri. This includes the pre-trained model and parameters as well as the inference scripts,
# along with all dependencies and scripts for model loading, inference handling, etc.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)

# Create the SageMaker model instance
model = Model(
    image_uri=deploy_image_uri,
    model_data=model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)

# Deploy the model. Note that we need to pass the Predictor class when we deploy the model through the Model class
# to be able to run inference through the SageMaker API.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
)

4. Query endpoint and parse response

We start by downloading an example image and the mask image.


[ ]:
from IPython.display import Image

region = boto3.Session().region_name
s3_bucket = f"jumpstart-cache-prod-{region}"
key_prefix = "model-metadata/assets"
input_img_file_name = "dog_suit.jpg"

s3 = boto3.client("s3")

s3.download_file(s3_bucket, f"{key_prefix}/{input_img_file_name}", input_img_file_name)

# Displaying the original image
Image(filename=input_img_file_name, width=632, height=632)

Mask


The mask is an image in which the region to be replaced is all white and the region to keep unchanged is all black.
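
To make the white/black convention concrete, here is a minimal sketch that builds a rectangular mask as plain rows of pixel values, with 255 (white) marking the region to replace and 0 (black) marking the region to keep. The helper make_rect_mask and the tiny 6x4 size are hypothetical and chosen only for illustration; turning such an array into an image file would need an imaging library such as Pillow, which is not shown here.

```python
def make_rect_mask(width, height, box):
    """Build a mask as rows of pixel values.

    Pixels inside `box` are 255 (white, to be replaced); all others are
    0 (black, kept). `box` is (left, top, right, bottom), with right and
    bottom exclusive.
    """
    left, top, right, bottom = box
    return [
        [255 if left <= x < right and top <= y < bottom else 0 for x in range(width)]
        for y in range(height)
    ]


mask = make_rect_mask(width=6, height=4, box=(1, 1, 4, 3))
for row in mask:
    # '#' marks white (replace), '.' marks black (keep).
    print("".join("#" if value else "." for value in row))
# ......
# .###..
# .###..
# ......
```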


[ ]:
input_img_mask_file_name = "dog_suit_mask.jpg"
s3.download_file(s3_bucket, f"{key_prefix}/{input_img_mask_file_name}", input_img_mask_file_name)
Image(filename=input_img_mask_file_name, width=632, height=632)

Next, we write some helper functions for querying the endpoint, parsing the response, and displaying the generated image.

[ ]:
import matplotlib.pyplot as plt
import numpy as np


def query(model_predictor, payload, content_type, accept):
    """Query the model predictor."""

    query_response = model_predictor.predict(
        payload,
        {
            "ContentType": content_type,
            "Accept": accept,
        },
    )
    return query_response


def parse_response(query_response):
    """Parse response and return the generated images."""

    response_dict = json.loads(query_response)
    return response_dict["generated_images"]


def display_img_and_prompt(img, prmpt):
    """Display the generated image."""
    plt.figure(figsize=(12, 12))
    plt.imshow(np.array(img))
    plt.axis("off")
    plt.title(prmpt)
    plt.show()
[ ]:
import base64
from PIL import Image
from io import BytesIO


# For content_type = 'application/json;jpeg', the endpoint expects the payload to be a JSON object with the original image and the mask image as base64-encoded bytes.
# To send a raw image to the endpoint, set content_type = 'application/json' and pass encoded_image as np.array(PIL.Image.open(input_img_file_name)).tolist()
content_type = "application/json;jpeg"


with open(input_img_file_name, "rb") as f:
    input_img_image_bytes = f.read()
with open(input_img_mask_file_name, "rb") as f:
    input_img_mask_image_bytes = f.read()

encoded_input_image = base64.b64encode(bytearray(input_img_image_bytes)).decode()
encoded_mask = base64.b64encode(bytearray(input_img_mask_image_bytes)).decode()


payload = {
    "prompt": "a white cat, blue eyes, wearing a sweater, lying in park",
    "image": encoded_input_image,
    "mask_image": encoded_mask,
    "num_inference_steps": 50,
    "guidance_scale": 7.5,
    "seed": 0,
    "negative_prompt": "poorly drawn feet",
}


# For accept = 'application/json;jpeg', the endpoint returns the JPEG image as base64-encoded bytes.
# To receive a raw image with RGB values, set accept = 'application/json'.
accept = "application/json;jpeg"

# Note that sending or receiving a payload with raw RGB values may hit the default limits on input payload and response size.

query_response = query(model_predictor, json.dumps(payload).encode("utf-8"), content_type, accept)
generated_images = parse_response(query_response)


# For accept = 'application/json;jpeg' as set above, each returned image is a JPEG as base64-encoded bytes.
# Here, we decode each image and display it.
for generated_image in generated_images:
    generated_image_decoded = BytesIO(base64.b64decode(generated_image.encode()))
    generated_image_rgb = Image.open(generated_image_decoded).convert("RGB")
    # You can save the generated image by calling generated_image_rgb.save('inpainted_image.jpg')
    display_img_and_prompt(generated_image_rgb, "Inpainted image generated by the model")
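
The request and response handling above comes down to a JSON body whose image fields are base64 text. The following sketch isolates that round trip using only the standard library. The helpers build_payload and decode_images are hypothetical, the byte strings are placeholders rather than real JPEG data, and the fake response only mimics the "generated_images" key that the notebook parses.

```python
import base64
import json


def build_payload(image_bytes, mask_bytes, prompt, **params):
    """Serialize the request body: images as base64 text, other fields as-is."""
    body = {
        "prompt": prompt,
        "image": base64.b64encode(image_bytes).decode("ascii"),
        "mask_image": base64.b64encode(mask_bytes).decode("ascii"),
    }
    body.update(params)
    return json.dumps(body).encode("utf-8")


def decode_images(response_bytes):
    """Return the raw image bytes for every entry in "generated_images"."""
    response = json.loads(response_bytes)
    return [base64.b64decode(item) for item in response["generated_images"]]


# Placeholder bytes stand in for real JPEG files.
request = build_payload(b"fake-image", b"fake-mask", "a white cat", seed=0)
print(sorted(json.loads(request)))  # ['image', 'mask_image', 'prompt', 'seed']

# A made-up response with the same shape the notebook parses.
fake_response = json.dumps(
    {"generated_images": [base64.b64encode(b"fake-output").decode("ascii")]}
).encode("utf-8")
print(decode_images(fake_response))  # [b'fake-output']
```

Keeping the encoding in one helper makes it easier to add or drop optional parameters without touching the image handling.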

Supported Parameters


This model supports many parameters while performing inference. They include:

  • prompt: prompt to guide the image generation. Must be specified and can be a string or a list of strings.

  • num_inference_steps: number of denoising steps during image generation. More steps lead to a higher-quality image. If specified, it must be a positive integer.

  • guidance_scale: a higher guidance scale results in an image more closely related to the prompt, at the expense of image quality. If specified, it must be a float. guidance_scale<=1 is ignored.

  • negative_prompt: guides image generation away from this prompt. If specified, it must be a string or a list of strings and is used together with guidance_scale. If guidance_scale is disabled, negative_prompt is also disabled. Moreover, if prompt is a list of strings, then negative_prompt must also be a list of strings.

  • num_images_per_prompt: number of images returned per prompt. If specified, it must be a positive integer.

  • seed: fixes the randomized state for reproducibility. If specified, it must be an integer.

  • batch_size: number of images to generate in a single forward pass. If you are using a smaller instance or generating many images, reduce batch_size to a small number (1-2). Number of images = number of prompts * num_images_per_prompt.
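
Some of the rules in this list can be checked on the client before a request is sent. The sketch below covers a subset of them and works out the image count from the formula above. Both helpers are hypothetical, and the endpoint performs its own validation. Two points are assumptions rather than documented behavior: that num_images_per_prompt defaults to 1 when omitted, and that the number of forward passes is the total image count divided by batch_size, rounded up (with a single pass when batch_size is omitted).

```python
import math


def check_params(payload):
    """Raise ValueError if `payload` breaks one of the rules checked here.

    Hypothetical client-side check covering only part of the parameter list.
    """
    prompt = payload.get("prompt")
    prompt_is_list = (
        isinstance(prompt, list)
        and len(prompt) > 0
        and all(isinstance(p, str) for p in prompt)
    )
    if not (isinstance(prompt, str) or prompt_is_list):
        raise ValueError("prompt must be a string or a list of strings")

    for key in ("num_inference_steps", "num_images_per_prompt"):
        value = payload.get(key)
        if value is not None and (
            isinstance(value, bool) or not isinstance(value, int) or value <= 0
        ):
            raise ValueError(f"{key} must be a positive integer")

    seed = payload.get("seed")
    if seed is not None and (isinstance(seed, bool) or not isinstance(seed, int)):
        raise ValueError("seed must be an integer")

    negative = payload.get("negative_prompt")
    if negative is not None and prompt_is_list and not isinstance(negative, list):
        raise ValueError("negative_prompt must be a list when prompt is a list")


def image_counts(payload):
    """Return (total images, forward passes) under the assumptions stated above."""
    prompt = payload["prompt"]
    num_prompts = len(prompt) if isinstance(prompt, list) else 1
    total = num_prompts * payload.get("num_images_per_prompt", 1)  # assumed default
    batch_size = payload.get("batch_size", total)  # assumed: one pass if omitted
    return total, math.ceil(total / batch_size)


example = {
    "prompt": ["a white cat", "a red fox"],
    "negative_prompt": ["blurry", "blurry"],
    "num_images_per_prompt": 3,
    "batch_size": 2,
}
check_params(example)
print(image_counts(example))  # (6, 3)
```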


5. Clean up the endpoint

[ ]:
# Delete the SageMaker model and endpoint
model_predictor.delete_model()
model_predictor.delete_endpoint()
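
If the first delete call raises, the second never runs and the endpoint keeps accruing charges. A possible safeguard, sketched below, is to attempt both deletions and report any failures afterwards. The helper cleanup and the stub predictor are hypothetical; only the delete_model and delete_endpoint method names come from the cell above.

```python
def cleanup(predictor):
    """Delete the model and the endpoint, attempting both even if one fails.

    Returns a list of (step name, exception) for the steps that failed.
    """
    failures = []
    for step in (predictor.delete_model, predictor.delete_endpoint):
        try:
            step()
        except Exception as error:  # keep going so the other resource is still removed
            failures.append((step.__name__, error))
    return failures


class _StubPredictor:
    """Stand-in for model_predictor so the sketch runs without AWS."""

    def __init__(self):
        self.calls = []

    def delete_model(self):
        self.calls.append("delete_model")
        raise RuntimeError("simulated failure")

    def delete_endpoint(self):
        self.calls.append("delete_endpoint")


stub = _StubPredictor()
failures = cleanup(stub)
print(stub.calls)  # ['delete_model', 'delete_endpoint']
print([name for name, _ in failures])  # ['delete_model']
```

In the notebook, cleanup(model_predictor) would take the place of the two direct calls, and a non-empty return value signals that something still needs to be removed by hand.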

Conclusion


In this tutorial, we learned how to deploy a pre-trained Stable Diffusion inpainting model on SageMaker using JumpStart. We saw that Stable Diffusion models are very effective at replacing part of an image and generating highly photo-realistic images. JumpStart provides both Stable Diffusion 1 and Stable Diffusion 2 and their FP16 revisions for inpainting.

You can tweak the image generation process by selecting the appropriate parameters during inference. Guidance on how to set these parameters is provided in the Supported Parameters section.


This notebook was tested in multiple regions. In addition to us-west-2, these are: us-east-1, us-east-2, us-west-1, ca-central-1, sa-east-1, eu-west-1, eu-west-2, eu-west-3, eu-central-1, eu-north-1, ap-southeast-1, ap-southeast-2, ap-northeast-1, ap-northeast-2, and ap-south-1.