{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# SageMaker Training with MLflow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup environment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import necessary libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker import get_execution_role\n", "from sagemaker.sklearn.estimator import SKLearn\n", "\n", "import boto3\n", "import numpy as np\n", "import pandas as pd\n", "import os" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Declare some variables used later" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define session, role, and region so we can\n", "# perform any SageMaker tasks we need\n", "sagemaker_session = sagemaker.Session()\n", "role = get_execution_role()\n", "region = sagemaker_session.boto_region_name\n", "\n", "# S3 prefix for the training dataset to be uploaded to\n", "prefix = \"DEMO-scikit-iris\"\n", "\n", "# MLflow (replace these values with your own)\n", "tracking_server_arn = \"your tracking server arn\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!mkdir -p training_code" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get some training data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's download and save the Iris dataset" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [
  "os.makedirs(\"./data\", exist_ok=True)\n",
  "\n",
  "s3_client = boto3.client(\"s3\")\n",
  "s3_client.download_file(\n",
  "    f\"sagemaker-example-files-prod-{region}\", \"datasets/tabular/iris/iris.data\", \"./data/iris.csv\"\n",
  ")\n",
  "\n",
  "# Encode the species names as integer class ids and move the label to column 0,\n",
  "# which is where training_code/train.py reads it from. The processed data\n",
  "# overwrites the raw download at ./data/iris.csv.\n",
  "df_iris = pd.read_csv(\"./data/iris.csv\", header=None)\n",
  "df_iris[4] = df_iris[4].map({\"Iris-setosa\": 0, \"Iris-versicolor\": 1, \"Iris-virginica\": 2})\n",
  "iris = df_iris[[4, 0, 1, 2, 3]].to_numpy()\n",
  "np.savetxt(\"./data/iris.csv\", iris, delimiter=\",\", fmt=\"%1.1f, %1.3f, %1.3f, %1.3f, %1.3f\")"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And now let's upload that data to S3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "WORK_DIRECTORY = \"data\"\n",
  "\n",
  "train_input = sagemaker_session.upload_data(\n",
  "    WORK_DIRECTORY, key_prefix=\"{}/{}\".format(prefix, WORK_DIRECTORY)\n",
  ")"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Write your training script" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's write the code to train a Decision Tree model using the scikit-learn framework" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "%%writefile training_code/train.py\n",
  "\n",
  "# Training entry point: fits a scikit-learn decision tree on the CSV files in the\n",
  "# training channel, logs the run to MLflow via autologging, and saves the fitted\n",
  "# model as model.joblib in the model directory.\n",
  "from __future__ import print_function\n",
  "\n",
  "import argparse\n",
  "import joblib\n",
  "import os\n",
  "import pandas as pd\n",
  "\n",
  "from sklearn import tree\n",
  "\n",
  "import mlflow\n",
  "\n",
  "if __name__ == '__main__':\n",
  "    parser = argparse.ArgumentParser()\n",
  "\n",
  "    # Hyperparameters are described here. In this simple example we are just including one hyperparameter.\n",
  "    # The default is None (no limit on leaf nodes): scikit-learn only accepts None or an int >= 2\n",
  "    # for max_leaf_nodes, so a sentinel such as -1 would make the classifier fail whenever the\n",
  "    # hyperparameter is not supplied.\n",
  "    parser.add_argument('--max_leaf_nodes', type=int, default=None)\n",
  "\n",
  "    # Sagemaker specific arguments. Defaults are set in the environment variables.\n",
  "    parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])\n",
  "    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])\n",
  "    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])\n",
  "\n",
  "    args = parser.parse_args()\n",
  "\n",
  "    # Take the set of files and read them all into a single pandas dataframe\n",
  "    input_files = [ os.path.join(args.train, file) for file in os.listdir(args.train) if os.path.isfile(os.path.join(args.train, file))]\n",
  "    if len(input_files) == 0:\n",
  "        raise ValueError(('There are no files in {}.\\n' +\n",
  "                          'This usually indicates that the channel ({}) was incorrectly specified,\\n' +\n",
  "                          'the data specification in S3 was incorrectly specified or the role specified\\n' +\n",
  "                          'does not have permission to access the data.').format(args.train, \"train\"))\n",
  "    raw_data = [ pd.read_csv(file, header=None, engine=\"python\") for file in input_files ]\n",
  "    train_data = pd.concat(raw_data)\n",
  "\n",
  "    # Set the Tracking Server URI using the ARN of the Tracking Server you created\n",
  "    mlflow.set_tracking_uri(os.environ['MLFLOW_TRACKING_ARN'])\n",
  "\n",
  "    # Enable autologging in MLflow (must happen before the model is fitted below)\n",
  "    mlflow.autolog()\n",
  "\n",
  "    # labels are in the first column\n",
  "    train_y = train_data.iloc[:, 0]\n",
  "    train_X = train_data.iloc[:, 1:]\n",
  "\n",
  "    # Here we support a single hyperparameter, 'max_leaf_nodes'. Note that you can add as many\n",
  "    # as your training may require in the ArgumentParser above.\n",
  "    max_leaf_nodes = args.max_leaf_nodes\n",
  "\n",
  "    # Now use scikit-learn's decision tree classifier to train the model.\n",
  "    clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)\n",
  "    clf = clf.fit(train_X, train_y)\n",
  "\n",
  "    # Persist the trained classifier to the model directory (args.model_dir,\n",
  "    # which defaults to SM_MODEL_DIR)\n",
  "    joblib.dump(clf, os.path.join(args.model_dir, 'model.joblib'))"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we're using MLflow in our training script, let's make sure the container installs `mlflow` along with our MLflow AWS plugin before running our training script. We can do this by creating a `requirements.txt` file and putting it in the same directory as our training script." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "%%writefile training_code/requirements.txt\n",
  "mlflow==2.13.2\n",
  "sagemaker-mlflow==0.1.0"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## SageMaker Training and MLflow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train your Decision tree model by launching a SageMaker Training job." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "sklearn = SKLearn(\n",
  "    # Training code: train.py (and the requirements.txt next to it) in training_code/\n",
  "    source_dir=\"training_code\",\n",
  "    entry_point=\"train.py\",\n",
  "    hyperparameters={\"max_leaf_nodes\": 30},\n",
  "    # train.py reads the tracking server ARN from this environment variable\n",
  "    environment={\"MLFLOW_TRACKING_ARN\": tracking_server_arn},\n",
  "    # Container, instance, and session settings\n",
  "    framework_version=\"1.2-1\",\n",
  "    instance_type=\"ml.c4.xlarge\",\n",
  "    keep_alive_period_in_seconds=3600,\n",
  "    role=role,\n",
  "    sagemaker_session=sagemaker_session,\n",
  ")"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [
  "# The \"train\" channel is what train.py receives through SM_CHANNEL_TRAIN\n",
  "sklearn.fit(inputs={\"train\": train_input})"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
  "## Notebook CI Test Results\n",
  "\n",
  "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n",
  "\n",
  "\n",
  "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n",
  "\n",
  "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n",
  "\n",
  "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n",
  "\n",
  "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n",
  "\n",
"![This sa-east-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/sagemaker-mlflow|sagemaker_training_mlflow.ipynb)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" }, "vscode": { "interpreter": { "hash": "3b41de70bedc0e302a3aeb58a0c77b854f2e56c8930e61a4aaa3340c96b01f1d" } } }, "nbformat": 4, "nbformat_minor": 4 }