Skip to content

feature: support RetryStrategy for training jobs #2316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
profiler_config=None,
disable_profiler=False,
environment=None,
max_retry_attempts=None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -269,6 +270,13 @@ def __init__(
will be disabled (default: ``False``).
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
max_retry_attempts (int): The number of times to move a job to the STARTING status.
You can specify between 1 and 30 attempts.
If the value of attempts is greater than zero,
the job is retried on InternalServerFailure
up to that many times.
If ``None``, no ``RetryStrategy`` is included in the training job request.
You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
(default: ``None``)

"""
instance_count = renamed_kwargs(
Expand Down Expand Up @@ -357,6 +365,8 @@ def __init__(

self.environment = environment

self.max_retry_attempts = max_retry_attempts

if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
self.disable_profiler = True

Expand Down Expand Up @@ -1114,6 +1124,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
if max_wait:
init_params["max_wait"] = max_wait

if job_details.get("RetryStrategy", False):
init_params["max_retry_attempts"] = job_details.get("RetryStrategy", {}).get(
"MaximumRetryAttempts"
)
max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
if max_wait:
init_params["max_wait"] = max_wait
return init_params

def transformer(
Expand Down Expand Up @@ -1489,6 +1506,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
if estimator.enable_network_isolation():
train_args["enable_network_isolation"] = True

if estimator.max_retry_attempts is not None:
train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
else:
train_args["retry_strategy"] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it necessary to set this to None? Can we just leave it out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per this test, if I have to assert that retry_strategy is None when max_retry_attempts is not set, then yes, we have to explicitly set it to None: https://github.com/aws/sagemaker-python-sdk/blob/master/tests/unit/test_estimator.py#L1060 . Given this, my understanding is that we have to set it to None explicitly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, k


if estimator.encrypt_inter_container_traffic:
train_args["encrypt_inter_container_traffic"] = True

Expand Down Expand Up @@ -1666,6 +1688,7 @@ def __init__(
profiler_config=None,
disable_profiler=False,
environment=None,
max_retry_attempts=None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -1816,6 +1839,13 @@ def __init__(
will be disabled (default: ``False``).
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
max_retry_attempts (int): The number of times to move a job to the STARTING status.
You can specify between 1 and 30 attempts.
If the value of attempts is greater than zero,
the job is retried on InternalServerFailure
up to that many times.
If ``None``, no ``RetryStrategy`` is included in the training job request.
You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
(default: ``None``)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please explain here that if this is set to None there will be no retries.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i am wondering if we should set the default to 3 or 2.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't set the default to 3 or 2, as we want customers to explicitly opt in. However, I think I need to re-word it like " If the value of attempts is greater than zero, the job is retried on InternalServerFailure the same number of attempts as the value.
You can cap the total duration for your job by setting max_wait and max_run"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, that is more clear. Let's just do None for the default for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
self.image_uri = image_uri
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -1850,6 +1880,7 @@ def __init__(
profiler_config=profiler_config,
disable_profiler=disable_profiler,
environment=environment,
max_retry_attempts=max_retry_attempts,
**kwargs,
)

Expand Down
12 changes: 12 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ def train( # noqa: C901
profiler_rule_configs=None,
profiler_config=None,
environment=None,
retry_strategy=None,
):
"""Create an Amazon SageMaker training job.

Expand Down Expand Up @@ -529,6 +530,9 @@ def train( # noqa: C901
with SageMaker Profiler. (default: ``None``).
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy (dict): Defines RetryStrategy for InternalServerFailures.
* max_retry_attempts (int): Number of times a job should be retried.
The key in RetryStrategy is 'MaximumRetryAttempts'.

Returns:
str: ARN of the training job, if it is created.
Expand Down Expand Up @@ -561,6 +565,7 @@ def train( # noqa: C901
profiler_rule_configs=profiler_rule_configs,
profiler_config=profiler_config,
environment=environment,
retry_strategy=retry_strategy,
)
LOGGER.info("Creating training-job with name: %s", job_name)
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
Expand Down Expand Up @@ -594,6 +599,7 @@ def _get_train_request( # noqa: C901
profiler_rule_configs=None,
profiler_config=None,
environment=None,
retry_strategy=None,
):
"""Constructs a request compatible for creating an Amazon SageMaker training job.

Expand Down Expand Up @@ -665,6 +671,9 @@ def _get_train_request( # noqa: C901
SageMaker Profiler. (default: ``None``).
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy (dict): Defines RetryStrategy for InternalServerFailures.
* max_retry_attempts (int): Number of times a job should be retried.
The key in RetryStrategy is 'MaximumRetryAttempts'.

Returns:
Dict: a training request dict
Expand Down Expand Up @@ -749,6 +758,9 @@ def _get_train_request( # noqa: C901
if profiler_config is not None:
train_request["ProfilerConfig"] = profiler_config

if retry_strategy is not None:
train_request["RetryStrategy"] = retry_strategy

return train_request

def update_training_job(
Expand Down
10 changes: 10 additions & 0 deletions tests/integ/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def test_mnist_with_checkpoint_config(
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
environment=ENV_INPUT,
max_wait=24 * 60 * 60,
max_retry_attempts=2,
)
inputs = estimator.sagemaker_session.upload_data(
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
Expand Down Expand Up @@ -89,8 +91,16 @@ def test_mnist_with_checkpoint_config(
"Environment"
]
)

expected_retry_strategy = {
"MaximumRetryAttempts": 2,
}
actual_retry_strategy = sagemaker_session.sagemaker_client.describe_training_job(
TrainingJobName=training_job_name
)["RetryStrategy"]
assert actual_training_checkpoint_config == expected_training_checkpoint_config
assert actual_training_environment_variable_config == ENV_INPUT
assert actual_retry_strategy == expected_retry_strategy


def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/huggingface/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _create_train_job(version, base_framework_version):
"vpc_config": None,
"metric_definitions": None,
"environment": None,
"retry_strategy": None,
"experiment_config": None,
"debugger_hook_config": {
"CollectionConfigurations": [],
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/tensorflow/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd
},
"hyperparameters": _hyperparameters(horovod, smdataparallel),
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
"retry_strategy": None,
"tags": None,
"vpc_config": None,
"metric_definitions": None,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def _create_train_job(version, py_version):
"sagemaker_region": '"us-west-2"',
},
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
"retry_strategy": None,
"tags": None,
"vpc_config": None,
"metric_definitions": None,
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def test_framework_all_init_args(sagemaker_session):
enable_sagemaker_metrics=True,
enable_network_isolation=True,
environment=ENV_INPUT,
max_retry_attempts=2,
)
_TrainingJob.start_new(f, "s3://mydata", None)
sagemaker_session.train.assert_called_once()
Expand All @@ -269,6 +270,7 @@ def test_framework_all_init_args(sagemaker_session):
"output_config": {"KmsKeyId": "outputkms", "S3OutputPath": "outputpath"},
"vpc_config": {"Subnets": ["123", "456"], "SecurityGroupIds": ["789", "012"]},
"stop_condition": {"MaxRuntimeInSeconds": 456},
"retry_strategy": {"MaximumRetryAttempts": 2},
"role": sagemaker_session.expand_role(),
"job_name": None,
"resource_config": {
Expand Down Expand Up @@ -1092,6 +1094,7 @@ def test_framework_with_spot_and_checkpoints(sagemaker_session):
"checkpoint_local_path": "/tmp/checkpoints",
"environment": None,
"experiment_config": None,
"retry_strategy": None,
}


Expand Down Expand Up @@ -2392,6 +2395,7 @@ def test_unsupported_type_in_dict():
"VolumeSizeInGB": 30,
},
"stop_condition": {"MaxRuntimeInSeconds": 86400},
"retry_strategy": None,
"tags": None,
"vpc_config": None,
"metric_definitions": None,
Expand Down Expand Up @@ -2703,6 +2707,24 @@ def test_add_environment_variables_to_train_args(sagemaker_session):
assert args["environment"] == ENV_INPUT


def test_add_retry_strategy_to_train_args(sagemaker_session):
    """Fitting an ``Estimator`` built with ``max_retry_attempts`` forwards a retry strategy.

    The estimator is created with ``max_retry_attempts=2``; after ``fit()`` the
    session's ``train`` must have been called exactly once, with a
    ``retry_strategy`` keyword equal to ``{"MaximumRetryAttempts": 2}``.
    """
    estimator = Estimator(
        IMAGE_URI,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        output_path=OUTPUT_PATH,
        sagemaker_session=sagemaker_session,
        max_retry_attempts=2,
    )
    estimator.fit()

    sagemaker_session.train.assert_called_once()

    # Index 1 of ``call_args`` holds the keyword arguments of the recorded call.
    train_kwargs = sagemaker_session.train.call_args[1]
    assert train_kwargs["retry_strategy"] == {"MaximumRetryAttempts": 2}


def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
e = Estimator(
IMAGE_URI,
Expand Down Expand Up @@ -3159,6 +3181,25 @@ def test_prepare_init_params_from_job_description_with_spot_training():
assert init_params["max_wait"] == 87000


def test_prepare_init_params_from_job_description_with_retry_strategy():
    """A job description carrying ``RetryStrategy`` yields ``max_retry_attempts``.

    Starts from a shallow copy of ``RETURNED_JOB_DESCRIPTION``, overrides its
    ``RetryStrategy`` and ``StoppingCondition`` entries, and checks the init
    params derived from that description.
    """
    # ``dict(mapping, **overrides)`` builds a new dict, leaving the shared
    # module-level description untouched.
    described_job = dict(
        RETURNED_JOB_DESCRIPTION,
        RetryStrategy={"MaximumRetryAttempts": 2},
        StoppingCondition={
            "MaxRuntimeInSeconds": 86400,
            "MaxWaitTimeInSeconds": 87000,
        },
    )

    params = EstimatorBase._prepare_init_params_from_job_description(
        job_details=described_job
    )

    expected_params = {
        "role": "arn:aws:iam::366:role/SageMakerRole",
        "instance_count": 1,
        "max_run": 86400,
        "max_wait": 87000,
        "max_retry_attempts": 2,
    }
    for key, expected_value in expected_params.items():
        assert params[key] == expected_value


def test_prepare_init_params_from_job_description_with_invalid_training_job():

invalid_job_description = RETURNED_JOB_DESCRIPTION.copy()
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def _get_train_args(job_name):
"vpc_config": None,
"metric_definitions": None,
"environment": None,
"retry_strategy": None,
"experiment_config": None,
"debugger_hook_config": {
"CollectionConfigurations": [],
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def _create_train_job(version, py_version):
"vpc_config": None,
"metric_definitions": None,
"environment": None,
"retry_strategy": None,
"experiment_config": None,
"debugger_hook_config": {
"CollectionConfigurations": [],
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _create_train_job(toolkit, toolkit_version, framework):
"profiler_config": {
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
},
"retry_strategy": None,
}


Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
}

stop_cond = {"MaxRuntimeInSeconds": MAX_TIME}
RETRY_STRATEGY = {"MaximumRetryAttempts": 2}
hyperparameters = {"foo": "bar"}

sagemaker_session.train(
Expand All @@ -1254,6 +1255,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
checkpoint_local_path="/tmp/checkpoints",
enable_sagemaker_metrics=True,
environment=ENV_INPUT,
retry_strategy=RETRY_STRATEGY,
)

_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
Expand All @@ -1268,6 +1270,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/"
assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints"
assert actual_train_args["Environment"] == ENV_INPUT
assert actual_train_args["RetryStrategy"] == RETRY_STRATEGY


def test_transform_pack_to_request(sagemaker_session):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _create_train_job(version):
"sagemaker_region": '"us-west-2"',
},
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
"retry_strategy": None,
"metric_definitions": None,
"tags": None,
"vpc_config": None,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def _create_train_job(version, instance_count=1, instance_type="ml.c4.4xlarge"):
"sagemaker_region": '"us-west-2"',
},
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
"retry_strategy": None,
"metric_definitions": None,
"tags": None,
"vpc_config": None,
Expand Down