{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [
  "# Deploy an MLflow Model to SageMaker"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n",
  "\n",
  "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Setup environment"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "Upgrade SageMaker Python SDK"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# The requirement is quoted so the shell does not treat \">=2.215.0\" as an output redirect\n",
  "%pip install --upgrade --quiet \"sagemaker>=2.215.0\""
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "Install MLflow and AWS MLflow plugin"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "%pip install mlflow==2.13.2 sagemaker-mlflow==0.1.0"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "Import necessary libraries"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import os\n",
  "\n",
  "import boto3\n",
  "import mlflow\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "import sagemaker\n",
  "from mlflow import MlflowClient\n",
  "from sagemaker import get_execution_role\n",
  "from sagemaker.serve import ModelBuilder\n",
  "from sagemaker.serve import SchemaBuilder\n",
  "from sagemaker.serve.mode.function_pointers import Mode\n",
  "from sagemaker.sklearn.estimator import SKLearn"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "Declare some variables used later"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Define session, role, and region so we can\n",
  "# perform any SageMaker tasks we need\n",
  "sagemaker_session = sagemaker.Session()\n",
  "role = get_execution_role()\n",
  "region = sagemaker_session.boto_region_name\n",
  "\n",
  "# S3 prefix for the training dataset to be uploaded to\n",
  "prefix = \"DEMO-scikit-iris\"\n",
  "\n",
  "# Provide the ARN of the Tracking Server that you want to track your training job with\n",
  "tracking_server_arn = \"your tracking server arn here\""
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Directory that will hold the training script and its requirements.txt\n",
  "os.makedirs(\"training_code\", exist_ok=True)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Get some training data"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "Let's download and save the Iris dataset"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "os.makedirs(\"./data\", exist_ok=True)\n",
  "\n",
  "# Fetch the raw Iris data from the SageMaker example-files bucket of the current region\n",
  "s3_client = boto3.client(\"s3\")\n",
  "s3_client.download_file(\n",
  "    f\"sagemaker-example-files-prod-{region}\", \"datasets/tabular/iris/iris.data\", \"./data/iris.csv\"\n",
  ")\n",
  "\n",
  "# Encode the species names as integers and move the label into the first column,\n",
  "# which is where the training script reads it from\n",
  "df_iris = pd.read_csv(\"./data/iris.csv\", header=None)\n",
  "df_iris[4] = df_iris[4].map({\"Iris-setosa\": 0, \"Iris-versicolor\": 1, \"Iris-virginica\": 2})\n",
  "iris = df_iris[[4, 0, 1, 2, 3]].to_numpy()\n",
  "np.savetxt(\"./data/iris.csv\", iris, delimiter=\",\", fmt=\"%1.1f, %1.3f, %1.3f, %1.3f, %1.3f\")"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "And now let's upload that data to S3"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "WORK_DIRECTORY = \"data\"\n",
  "\n",
  "# Upload the whole local data directory under <prefix>/<WORK_DIRECTORY>\n",
  "train_input = sagemaker_session.upload_data(\n",
  "    WORK_DIRECTORY, key_prefix=f\"{prefix}/{WORK_DIRECTORY}\"\n",
  ")"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Write your training script"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "Let's write the code to train a Decision Tree model using the scikit-learn framework"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "%%writefile training_code/train.py\n",
  "\n",
  "from __future__ import print_function\n",
  "\n",
  "import argparse\n",
  "import joblib\n",
  "import os\n",
  "import pandas as pd\n",
  "\n",
  "from sklearn import tree\n",
  "\n",
  "import mlflow\n",
  "\n",
  "if __name__ == '__main__':\n",
  "    parser = argparse.ArgumentParser()\n",
  "\n",
  "    # Hyperparameters are described here. In this simple example we are just including one hyperparameter.\n",
  "    # The default is None (no limit on the number of leaves); scikit-learn rejects -1 for max_leaf_nodes.\n",
  "    parser.add_argument('--max_leaf_nodes', type=int, default=None)\n",
  "\n",
  "    # Sagemaker specific arguments. Defaults are set in the environment variables.\n",
  "    parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])\n",
  "    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])\n",
  "    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])\n",
  "\n",
  "    args = parser.parse_args()\n",
  "\n",
  "    # Take the set of files and read them all into a single pandas dataframe\n",
  "    input_files = [ os.path.join(args.train, file) for file in os.listdir(args.train) if os.path.isfile(os.path.join(args.train, file))]\n",
  "    if len(input_files) == 0:\n",
  "        raise ValueError(('There are no files in {}.\\n' +\n",
  "                          'This usually indicates that the channel ({}) was incorrectly specified,\\n' +\n",
  "                          'the data specification in S3 was incorrectly specified or the role specified\\n' +\n",
  "                          'does not have permission to access the data.').format(args.train, \"train\"))\n",
  "    raw_data = [ pd.read_csv(file, header=None, engine=\"python\") for file in input_files ]\n",
  "    train_data = pd.concat(raw_data)\n",
  "\n",
  "    # Set the Tracking Server URI using the ARN of the Tracking Server you created\n",
  "    mlflow.set_tracking_uri(os.environ['MLFLOW_TRACKING_ARN'])\n",
  "\n",
  "    # Enable autologging in MLflow\n",
  "    mlflow.autolog()\n",
  "\n",
  "    # labels are in the first column\n",
  "    train_y = train_data.iloc[:, 0]\n",
  "    train_X = train_data.iloc[:, 1:]\n",
  "\n",
  "    # Here we support a single hyperparameter, 'max_leaf_nodes'. Note that you can add as many\n",
  "    # as your training may require in the ArgumentParser above.\n",
  "    max_leaf_nodes = args.max_leaf_nodes\n",
  "\n",
  "    # Now use scikit-learn's decision tree classifier to train the model.\n",
  "    clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)\n",
  "    clf = clf.fit(train_X, train_y)\n",
  "\n",
  "    # Save the trained classifier into the model directory\n",
  "    joblib.dump(clf, os.path.join(args.model_dir, \"model.joblib\"))\n",
  "\n",
  "    # Register the model with MLflow, using the artifacts logged by the most recent run\n",
  "    run_id = mlflow.last_active_run().info.run_id\n",
  "    artifact_path = \"model\"\n",
  "    model_uri = \"runs:/{run_id}/{artifact_path}\".format(run_id=run_id, artifact_path=artifact_path)\n",
  "    model_details = mlflow.register_model(model_uri=model_uri, name=\"sm-job-experiment-model\")"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "Since we're using MLflow in our training script, let's make sure the container installs `mlflow` along with our MLflow plugin before running our training script. We can do this by creating a `requirements.txt` file and putting it in the same directory as our training script."
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "%%writefile training_code/requirements.txt\n",
  "mlflow==2.13.2\n",
  "sagemaker-mlflow==0.1.0\n",
  "cloudpickle==2.2.1 # Required for Sagemaker Python SDK"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## SageMaker Training and MLflow"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "Train your Decision Tree model by launching a SageMaker Training job."
] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Named sklearn_estimator (not sklearn) so the variable is not confused with the library\n",
  "sklearn_estimator = SKLearn(\n",
  "    entry_point=\"train.py\",\n",
  "    source_dir=\"training_code\",\n",
  "    framework_version=\"1.2-1\",\n",
  "    instance_type=\"ml.c4.xlarge\",\n",
  "    role=role,\n",
  "    sagemaker_session=sagemaker_session,\n",
  "    hyperparameters={\"max_leaf_nodes\": 30},\n",
  "    keep_alive_period_in_seconds=3600,\n",
  "    # train.py reads this variable to point MLflow at the tracking server\n",
  "    environment={\"MLFLOW_TRACKING_ARN\": tracking_server_arn},\n",
  ")"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [
  "sklearn_estimator.fit({\"train\": train_input})"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Access the model in MLflow UI and SageMaker Studio UI\n",
  "\n",
  "After the execution completes, you can find the trained model in both the MLflow UI and SageMaker Studio UI.\n",
  "\n",
  "To view the model in the MLflow UI, select the \"Models\" tab:\n",
  "\n",
  "![sagemaker-mlflow-model-registry.png](./images/sagemaker-mlflow-model-registry.png)\n",
  "\n",
  "To view the model in SageMaker Studio UI, you will need to navigate to SageMaker Studio:\n",
  "\n",
  "1. Choose a domain and launch Studio from one of the user profiles associated with it\n",
  "2. Select \"Models\" in the menu to see the SageMaker Model Registry. From here you will see your `sm-job-experiment-model` model\n",
  "\n",
  "![sagemaker-model-registry.png](./images/sagemaker-model-registry.png)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Deploy MLflow Model to SageMaker"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "mlflow.set_tracking_uri(tracking_server_arn)\n",
  "client = MlflowClient()\n",
  "registered_model = client.get_registered_model(name=\"sm-job-experiment-model\")\n",
  "\n",
  "# latest_versions can contain more than one entry, so pick the highest\n",
  "# version number explicitly instead of relying on list order\n",
  "latest_model_version = max(registered_model.latest_versions, key=lambda mv: int(mv.version))\n",
  "source_path = latest_model_version.source"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Artifact URI of the model\n",
  "source_path"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "#### Define the schema of the sklearn model\n",
  "\n",
  "Model Builder requires the definition of the model schema, that is, the input and output of the model.\n",
  "In this case it is a [1x4] vector for the input and an integer for the output."
] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# One sample row with four feature values, and a sample integer prediction\n",
  "sklearn_input = np.array([1.0, 2.0, 3.0, 4.0]).reshape(1, -1)\n",
  "sklearn_output = 1\n",
  "sklearn_schema_builder = SchemaBuilder(\n",
  "    sample_input=sklearn_input,\n",
  "    sample_output=sklearn_output,\n",
  ")"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "#### Build and deploy the model"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Create model builder with the schema builder.\n",
  "model_builder = ModelBuilder(\n",
  "    mode=Mode.SAGEMAKER_ENDPOINT,\n",
  "    schema_builder=sklearn_schema_builder,\n",
  "    role_arn=role,\n",
  "    model_metadata={\"MLFLOW_MODEL_PATH\": source_path},\n",
  ")"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "built_model = model_builder.build()"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "predictor = built_model.deploy(initial_instance_count=1, instance_type=\"ml.m5.large\")"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Inference on Deployed Model"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "predictor.predict(sklearn_input)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Cleanup Resources"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Delete the endpoint first so its configuration is no longer in use, then the\n",
  "# endpoint configuration, and finally the model.\n",
  "# NOTE(review): the endpoint configuration is looked up under the endpoint's name - confirm\n",
  "sagemaker_session.delete_endpoint(endpoint_name=built_model.endpoint_name)\n",
  "sagemaker_session.delete_endpoint_config(endpoint_config_name=built_model.endpoint_name)\n",
  "sagemaker_session.delete_model(model_name=built_model.name)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "## Notebook CI Test Results\n",
  "\n",
  "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n",
  "\n",
  "\n",
  "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)\n",
  "\n",
  "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/sagemaker-mlflow|sagemaker_deployment_mlflow.ipynb)"
 ] }
],
"metadata": {
 "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" },
 "language_info": {
  "codemirror_mode": { "name": "ipython", "version": 3 },
  "file_extension": ".py",
  "mimetype": "text/x-python",
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": "3.10.14"
 }
},
"nbformat": 4,
"nbformat_minor": 4
}