Detect Stalled Training and Invoke Actions Using SageMaker Debugger Rule

This notebook shows you how to use the StalledTrainingRule built-in rule. This rule can take an action to stop your training job or send you an email or SMS when it detects inactivity in your training job for a specified period of time. This functionality helps you monitor training job status and reduce redundant resource usage.

How the StalledTrainingRule Built-in Rule Works

Amazon SageMaker Debugger captures tensors that you want to watch from training jobs running on AWS Deep Learning Containers or your local machine. If you use one of the Debugger-integrated Deep Learning Containers, you don’t need to make any changes to your training script to use the functionality of built-in rules. For information about Debugger-supported SageMaker frameworks and versions, see Debugger-supported framework versions for zero script change.

If you want to run a training script that uses a framework Debugger only partially supports, or your own custom container, you need to manually register the Debugger hook in your training script. The smdebug library provides tools to help with hook registration, and the sample script provided in the src folder includes the hook registration code as commented lines. For more information about how to manually register the Debugger hooks in this case, see the training script at ./src/simple_stalled_training.py, and the documentation at smdebug TensorFlow hook, smdebug PyTorch hook, smdebug MXNet hook, and smdebug XGBoost hook.

The Debugger StalledTrainingRule watches tensor updates from your training job. If the rule doesn’t find new tensors written to the default S3 URI within a threshold period of time, it takes an action: triggering the StopTrainingJob API operation. The following code cells set up a SageMaker TensorFlow estimator with the Debugger StalledTrainingRule to watch the losses pre-built tensor collection.
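The core idea of the rule can be illustrated with a small, self-contained sketch: compare the timestamp of the most recent tensor update against a threshold. This is illustrative only; the actual rule evaluation job inspects the trial data Debugger writes to S3.

```python
import time


def is_stalled(last_tensor_write_time, threshold_seconds, now=None):
    """Return True if no new tensors have arrived within the threshold.

    Mirrors the idea behind StalledTrainingRule: if the gap since the last
    tensor update exceeds the threshold, the training job is considered
    stalled. (Illustrative sketch, not the actual rule implementation.)
    """
    now = now if now is not None else time.time()
    return (now - last_tensor_write_time) > threshold_seconds


# Last tensor arrived 120 seconds ago; with a 60-second threshold the job
# counts as stalled:
print(is_stalled(last_tensor_write_time=1000.0, threshold_seconds=60, now=1120.0))  # True
```

When the check returns True, Debugger changes the rule's evaluation status to IssuesFound, which is what fires the configured action.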

Install custom packages

These packages were built manually with the changes needed to run rules with actions, because those changes have not been released yet. Remember to restart the kernel after installing these packages.

[ ]:
! pip install -q -U sagemaker

Import SageMaker Python SDK

[ ]:
import sagemaker
from sagemaker.tensorflow import TensorFlow

print(sagemaker.__version__)

Import SageMaker Debugger classes for rule configuration

[ ]:
from sagemaker.debugger import Rule, CollectionConfig, rule_configs

Create the actions to be used in the rules

The following code cells include:

* a code line to create the action objects
* a stalled training job rule configuration object that uses these actions
* a SageMaker TensorFlow estimator configuration with the Debugger rules parameter to run the built-in rule

Valid action objects are individual actions (StopTraining, Email, SMS) or an ActionList with a combination of these.
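As a hedged illustration of combining several actions into one ActionList (the email address and phone number below are placeholders, and this notebook's example uses only StopTraining):

```python
from sagemaker.debugger import rule_configs

# Combine multiple actions; each one fires when the rule's evaluation status
# changes to IssuesFound. The address and number are placeholder values.
actions = rule_configs.ActionList(
    rule_configs.StopTraining(),
    rule_configs.Email("abc@abc.com"),
    rule_configs.SMS("+1234567890"),
)
```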

Note: Debugger collects loss tensors by default every 500 steps.
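Keep this default in mind when choosing the rule's threshold: if losses are only saved every 500 steps, a threshold shorter than the time between saves could fire during healthy training. A back-of-the-envelope check, where the per-step time is an assumption about your model:

```python
# Rough check that the rule threshold exceeds the interval between tensor saves.
save_interval_steps = 500   # Debugger's default save interval for losses
step_seconds = 0.1          # assumed wall-clock time per training step

seconds_between_saves = save_interval_steps * step_seconds
print(seconds_between_saves)  # 50.0

# Pick a threshold comfortably above this gap, e.g. double it:
safe_threshold = 2 * seconds_between_saves
print(safe_threshold)  # 100.0
```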

[ ]:
training_job_prefix = None  # Feel free to customize this if desired.
[ ]:
stop_training_action = rule_configs.StopTraining(
    training_job_prefix
)  # with training_job_prefix=None, the rule stops the training job it monitors
actions = stop_training_action
[ ]:
# Configure a StalledTrainingRule rule parameter object
stalled_training_job_rule = [
    Rule.sagemaker(
        base_config=rule_configs.stalled_training_rule(),
        rule_parameters={
            "threshold": "60",  # seconds of inactivity before the rule triggers
        },
        actions=actions,
    )
]

# Configure a SageMaker TensorFlow estimator
estimator = TensorFlow(
    role=sagemaker.get_execution_role(),
    base_job_name="stalled-training-test",
    instance_count=1,
    instance_type="ml.m5.4xlarge",
    entry_point="src/simple_stalled_training.py",  # This sample script forces the training job to sleep for 10 minutes
    framework_version="1.15.0",
    py_version="py3",
    max_run=3600,
    ## Debugger-specific parameter
    rules=stalled_training_job_rule,
)
[ ]:
estimator.fit(wait=False)

Monitoring Training and Rule Evaluation Status

Once you execute the estimator.fit() API, SageMaker initiates a training job in the background, and Debugger initiates a StalledTrainingRule rule evaluation job in parallel. Because the training script has a few lines of code at the end that force it to sleep for 10 minutes, the RuleEvaluationStatus for StalledTrainingRule changes to IssuesFound within about two minutes after the sleep begins, which triggers the StopTrainingJob API operation.

Output the current job status and the rule evaluation status

The following cell sets up a SageMaker client and tracks the status of the training job until the SecondaryStatus changes to Stopped or Completed. While the job runs, Debugger collects output tensors from it and monitors the training job with the rules.

[ ]:
import time

# Set up a SageMaker client and look up the name of the current training job
client = estimator.sagemaker_session.sagemaker_client
job_name = estimator.latest_training_job.name
description = client.describe_training_job(TrainingJobName=job_name)

if description["TrainingJobStatus"] != "Completed":
    while description["SecondaryStatus"] not in {"Stopped", "Completed"}:
        description = client.describe_training_job(TrainingJobName=job_name)
        primary_status = description["TrainingJobStatus"]
        secondary_status = description["SecondaryStatus"]
        print(
            "Current job status: [PrimaryStatus: {}, SecondaryStatus: {}] | {} Rule Evaluation Status: {}".format(
                primary_status,
                secondary_status,
                estimator.latest_training_job.rule_job_summary()[0]["RuleConfigurationName"],
                estimator.latest_training_job.rule_job_summary()[0]["RuleEvaluationStatus"],
            )
        )
        time.sleep(15)
[ ]:
description = client.describe_training_job(TrainingJobName=job_name)
print(description)
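The describe response includes a DebugRuleEvaluationStatuses list. A minimal sketch of pulling out each rule's status, using a hypothetical response fragment (field names follow the DescribeTrainingJob API; real responses contain more fields, such as RuleEvaluationJobArn):

```python
# Hypothetical DescribeTrainingJob response fragment for illustration only
description = {
    "TrainingJobStatus": "Stopped",
    "DebugRuleEvaluationStatuses": [
        {
            "RuleConfigurationName": "StalledTrainingRule",
            "RuleEvaluationStatus": "IssuesFound",
        }
    ],
}

# Map each configured rule to its current evaluation status
statuses = {
    s["RuleConfigurationName"]: s["RuleEvaluationStatus"]
    for s in description["DebugRuleEvaluationStatuses"]
}
print(statuses)  # {'StalledTrainingRule': 'IssuesFound'}
```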

Get a direct Amazon CloudWatch URL to find the current rule processing job log

The following script returns a CloudWatch URL. Copy the URL and paste it into a browser; it leads directly to the rule job's log page.

[ ]:
import boto3

# This utility gives the link to monitor the CW event
def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):
    """Helper function to get the rule job name"""
    return "{}-{}-{}".format(
        training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]
    )


def _get_cw_url_for_rule_job(rule_job_name, region):
    return "https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix".format(
        region, region, rule_job_name
    )


def get_rule_jobs_cw_urls(estimator):
    region = boto3.Session().region_name
    training_job = estimator.latest_training_job
    training_job_name = training_job.describe()["TrainingJobName"]
    rule_eval_statuses = training_job.describe()["DebugRuleEvaluationStatuses"]

    result = {}
    for status in rule_eval_statuses:
        if status.get("RuleEvaluationJobArn", None) is not None:
            rule_job_name = _get_rule_job_name(
                training_job_name, status["RuleConfigurationName"], status["RuleEvaluationJobArn"]
            )
            result[status["RuleConfigurationName"]] = _get_cw_url_for_rule_job(
                rule_job_name, region
            )
    return result


print(
    "The direct CloudWatch URL to the current rule job:",
    get_rule_jobs_cw_urls(estimator)[
        estimator.latest_training_job.rule_job_summary()[0]["RuleConfigurationName"]
    ],
)

Conclusion

This notebook showed how you can use the Debugger StalledTrainingRule built-in rule for your training job to take action on rule evaluation status changes. To find more information about Debugger, see Amazon SageMaker Debugger Developer Guide and the smdebug GitHub documentation.
