Skip to content

Commit 206c807

Browse files
author
Shikha Panghal
committed
feature: support RetryStrategy for training jobs
1 parent 7b93234 commit 206c807

File tree

13 files changed

+108
-3
lines changed

13 files changed

+108
-3
lines changed

src/sagemaker/estimator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
profiler_config=None,
125125
disable_profiler=False,
126126
environment=None,
127+
max_retry_attempts=None,
127128
**kwargs,
128129
):
129130
"""Initialize an ``EstimatorBase`` instance.
@@ -269,6 +270,11 @@ def __init__(
269270
will be disabled (default: ``False``).
270271
environment (dict[str, str]) : Environment variables to be set for
271272
use during training job (default: ``None``)
273+
max_retry_attempts (int): The number of times to move a job to the STARTING status.
274+
You can specify between 1 and 30 attempts.
275+
If the value of attempts is greater than one, the job is retried on InternalServerFailure the same number of attempts as the value.
276+
You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
277+
(default: ``None``)
272278
273279
"""
274280
instance_count = renamed_kwargs(
@@ -357,6 +363,8 @@ def __init__(
357363

358364
self.environment = environment
359365

366+
self.max_retry_attempts = max_retry_attempts
367+
360368
if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
361369
self.disable_profiler = True
362370

@@ -1114,6 +1122,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
11141122
if max_wait:
11151123
init_params["max_wait"] = max_wait
11161124

1125+
if job_details.get("RetryStrategy", False):
1126+
init_params["max_retry_attempts"] = job_details.get("RetryStrategy", {}).get("MaximumRetryAttempts")
1127+
max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
1128+
if max_wait:
1129+
init_params["max_wait"] = max_wait
1130+
1131+
11171132
return init_params
11181133

11191134
def transformer(
@@ -1489,6 +1504,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
14891504
if estimator.enable_network_isolation():
14901505
train_args["enable_network_isolation"] = True
14911506

1507+
if estimator.max_retry_attempts is not None:
1508+
train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts}
1509+
else:
1510+
train_args["retry_strategy"] = None
1511+
14921512
if estimator.encrypt_inter_container_traffic:
14931513
train_args["encrypt_inter_container_traffic"] = True
14941514

@@ -1522,6 +1542,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
15221542

15231543
return train_args
15241544

1545+
15251546
@classmethod
15261547
def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
15271548
"""Placeholder docstring"""
@@ -1666,6 +1687,7 @@ def __init__(
16661687
profiler_config=None,
16671688
disable_profiler=False,
16681689
environment=None,
1690+
max_retry_attempts=None,
16691691
**kwargs,
16701692
):
16711693
"""Initialize an ``Estimator`` instance.
@@ -1816,6 +1838,11 @@ def __init__(
18161838
will be disabled (default: ``False``).
18171839
environment (dict[str, str]) : Environment variables to be set for
18181840
use during training job (default: ``None``)
1841+
max_retry_attempts (int): The number of times to move a job to the STARTING status.
1842+
You can specify between 1 and 30 attempts.
1843+
If the value of attempts is greater than one, the job is retried on InternalServerFailure the same number of attempts as the value.
1844+
You can cap the total duration for your job by setting ``max_wait`` and ``max_run``
1845+
(default: ``None``)
18191846
"""
18201847
self.image_uri = image_uri
18211848
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
@@ -1850,6 +1877,7 @@ def __init__(
18501877
profiler_config=profiler_config,
18511878
disable_profiler=disable_profiler,
18521879
environment=environment,
1880+
max_retry_attempts=max_retry_attempts,
18531881
**kwargs,
18541882
)
18551883

src/sagemaker/session.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def train( # noqa: C901
457457
profiler_rule_configs=None,
458458
profiler_config=None,
459459
environment=None,
460+
retry_strategy=None,
460461
):
461462
"""Create an Amazon SageMaker training job.
462463
@@ -529,6 +530,9 @@ def train( # noqa: C901
529530
with SageMaker Profiler. (default: ``None``).
530531
environment (dict[str, str]) : Environment variables to be set for
531532
use during training job (default: ``None``)
533+
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
534+
* max_retry_attsmpts (int): Number of times a job should be retried.
535+
The key in RetryStrategy is 'MaxRetryAttempts'.
532536
533537
Returns:
534538
str: ARN of the training job, if it is created.
@@ -561,6 +565,7 @@ def train( # noqa: C901
561565
profiler_rule_configs=profiler_rule_configs,
562566
profiler_config=profiler_config,
563567
environment=environment,
568+
retry_strategy=retry_strategy,
564569
)
565570
LOGGER.info("Creating training-job with name: %s", job_name)
566571
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
@@ -594,6 +599,7 @@ def _get_train_request( # noqa: C901
594599
profiler_rule_configs=None,
595600
profiler_config=None,
596601
environment=None,
602+
retry_strategy=None,
597603
):
598604
"""Constructs a request compatible for creating an Amazon SageMaker training job.
599605
@@ -665,6 +671,9 @@ def _get_train_request( # noqa: C901
665671
SageMaker Profiler. (default: ``None``).
666672
environment (dict[str, str]) : Environment variables to be set for
667673
use during training job (default: ``None``)
674+
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
675+
* max_retry_attsmpts (int): Number of times a job should be retried.
676+
The key in RetryStrategy is 'MaxRetryAttempts'.
668677
669678
Returns:
670679
Dict: a training request dict
@@ -749,6 +758,9 @@ def _get_train_request( # noqa: C901
749758
if profiler_config is not None:
750759
train_request["ProfilerConfig"] = profiler_config
751760

761+
if retry_strategy is not None:
762+
train_request["RetryStrategy"] = retry_strategy
763+
752764
return train_request
753765

754766
def update_training_job(

tests/integ/test_tf.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def test_mnist_with_checkpoint_config(
6161
checkpoint_s3_uri=checkpoint_s3_uri,
6262
checkpoint_local_path=checkpoint_local_path,
6363
environment=ENV_INPUT,
64+
max_wait=24 * 60 * 60,
65+
max_retry_attempts=2,
6466
)
6567
inputs = estimator.sagemaker_session.upload_data(
6668
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
@@ -89,8 +91,19 @@ def test_mnist_with_checkpoint_config(
8991
"Environment"
9092
]
9193
)
94+
95+
expected_retry_strategy = {
96+
"MaximumRetryAttempts": 2,
97+
}
98+
actual_retry_strategy = (
99+
sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)[
100+
"RetryStrategy"
101+
]
102+
)
92103
assert actual_training_checkpoint_config == expected_training_checkpoint_config
93104
assert actual_training_environment_variable_config == ENV_INPUT
105+
assert actual_retry_strategy == expected_retry_strategy
106+
94107

95108

96109
def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _create_train_job(version, base_framework_version):
150150
"vpc_config": None,
151151
"metric_definitions": None,
152152
"environment": None,
153+
"retry_strategy": None,
153154
"experiment_config": None,
154155
"debugger_hook_config": {
155156
"CollectionConfigurations": [],

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd
127127
},
128128
"hyperparameters": _hyperparameters(horovod, smdataparallel),
129129
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
130+
"retry_strategy": None,
130131
"tags": None,
131132
"vpc_config": None,
132133
"metric_definitions": None,

tests/unit/test_chainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def _create_train_job(version, py_version):
140140
"sagemaker_region": '"us-west-2"',
141141
},
142142
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
143+
"retry_strategy": None,
143144
"tags": None,
144145
"vpc_config": None,
145146
"metric_definitions": None,

tests/unit/test_estimator.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def test_framework_all_init_args(sagemaker_session):
245245
enable_sagemaker_metrics=True,
246246
enable_network_isolation=True,
247247
environment=ENV_INPUT,
248+
max_retry_attempts=2,
248249
)
249250
_TrainingJob.start_new(f, "s3://mydata", None)
250251
sagemaker_session.train.assert_called_once()
@@ -269,6 +270,7 @@ def test_framework_all_init_args(sagemaker_session):
269270
"output_config": {"KmsKeyId": "outputkms", "S3OutputPath": "outputpath"},
270271
"vpc_config": {"Subnets": ["123", "456"], "SecurityGroupIds": ["789", "012"]},
271272
"stop_condition": {"MaxRuntimeInSeconds": 456},
273+
"retry_strategy": {"MaximumRetryAttempts": 2},
272274
"role": sagemaker_session.expand_role(),
273275
"job_name": None,
274276
"resource_config": {
@@ -1029,7 +1031,6 @@ def test_training_job_with_rule_job_summary(sagemaker_session, training_job_desc
10291031
},
10301032
]
10311033

1032-
10331034
def test_framework_with_spot_and_checkpoints(sagemaker_session):
10341035
f = DummyFramework(
10351036
"my_script.py",
@@ -1092,6 +1093,7 @@ def test_framework_with_spot_and_checkpoints(sagemaker_session):
10921093
"checkpoint_local_path": "/tmp/checkpoints",
10931094
"environment": None,
10941095
"experiment_config": None,
1096+
"retry_strategy": None,
10951097
}
10961098

10971099

@@ -2392,6 +2394,7 @@ def test_unsupported_type_in_dict():
23922394
"VolumeSizeInGB": 30,
23932395
},
23942396
"stop_condition": {"MaxRuntimeInSeconds": 86400},
2397+
"retry_strategy": None,
23952398
"tags": None,
23962399
"vpc_config": None,
23972400
"metric_definitions": None,
@@ -2703,6 +2706,24 @@ def test_add_environment_variables_to_train_args(sagemaker_session):
27032706
assert args["environment"] == ENV_INPUT
27042707

27052708

2709+
def test_add_retry_strategy_to_train_args(sagemaker_session):
2710+
e = Estimator(
2711+
IMAGE_URI,
2712+
ROLE,
2713+
INSTANCE_COUNT,
2714+
INSTANCE_TYPE,
2715+
output_path=OUTPUT_PATH,
2716+
sagemaker_session=sagemaker_session,
2717+
max_retry_attempts=2,
2718+
)
2719+
2720+
e.fit()
2721+
2722+
sagemaker_session.train.assert_called_once()
2723+
args = sagemaker_session.train.call_args[1]
2724+
assert args["retry_strategy"] == {"MaximumRetryAttempts": 2}
2725+
2726+
27062727
def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
27072728
e = Estimator(
27082729
IMAGE_URI,
@@ -3159,6 +3180,27 @@ def test_prepare_init_params_from_job_description_with_spot_training():
31593180
assert init_params["max_wait"] == 87000
31603181

31613182

3183+
def test_prepare_init_params_from_job_description_with_retry_strategy():
3184+
job_description = RETURNED_JOB_DESCRIPTION.copy()
3185+
job_description["RetryStrategy"] = {
3186+
"MaximumRetryAttempts": 2
3187+
}
3188+
job_description["StoppingCondition"] = {
3189+
"MaxRuntimeInSeconds": 86400,
3190+
"MaxWaitTimeInSeconds": 87000,
3191+
}
3192+
3193+
init_params = EstimatorBase._prepare_init_params_from_job_description(
3194+
job_details=job_description
3195+
)
3196+
3197+
assert init_params["role"] == "arn:aws:iam::366:role/SageMakerRole"
3198+
assert init_params["instance_count"] == 1
3199+
assert init_params["max_run"] == 86400
3200+
assert init_params["max_wait"] == 87000
3201+
assert init_params["max_retry_attempts"] == 2
3202+
3203+
31623204
def test_prepare_init_params_from_job_description_with_invalid_training_job():
31633205

31643206
invalid_job_description = RETURNED_JOB_DESCRIPTION.copy()

tests/unit/test_mxnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def _get_train_args(job_name):
147147
"vpc_config": None,
148148
"metric_definitions": None,
149149
"environment": None,
150+
"retry_strategy": None,
150151
"experiment_config": None,
151152
"debugger_hook_config": {
152153
"CollectionConfigurations": [],
@@ -993,7 +994,6 @@ def test_mx_missing_environment_variables(
993994
)
994995
assert not mx.environment
995996

996-
997997
def test_mx_enable_sm_metrics(sagemaker_session, mxnet_training_version, mxnet_training_py_version):
998998
mx = MXNet(
999999
entry_point=SCRIPT_PATH,

tests/unit/test_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def _create_train_job(version, py_version):
149149
"vpc_config": None,
150150
"metric_definitions": None,
151151
"environment": None,
152+
"retry_strategy": None,
152153
"experiment_config": None,
153154
"debugger_hook_config": {
154155
"CollectionConfigurations": [],

tests/unit/test_rl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def _create_train_job(toolkit, toolkit_version, framework):
162162
"profiler_config": {
163163
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
164164
},
165+
"retry_strategy": None,
165166
}
166167

167168

tests/unit/test_session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_process(boto_session):
154154
},
155155
"role_arn": ROLE,
156156
"tags": [{"Name": "my-tag", "Value": "my-tag-value"}],
157-
"experiment_config": {"ExperimentName": "AnExperiment"},
157+
"experiment_config": {"ExperimentName": "AnExperiment"}
158158
}
159159
session.process(**process_request_args)
160160

@@ -1208,6 +1208,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
12081208
}
12091209

12101210
stop_cond = {"MaxRuntimeInSeconds": MAX_TIME}
1211+
RETRY_STRATEGY = {"MaximumRetryAttempts": 2}
12111212
hyperparameters = {"foo": "bar"}
12121213

12131214
sagemaker_session.train(
@@ -1229,6 +1230,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
12291230
checkpoint_local_path="/tmp/checkpoints",
12301231
enable_sagemaker_metrics=True,
12311232
environment=ENV_INPUT,
1233+
retry_strategy=RETRY_STRATEGY,
12321234
)
12331235

12341236
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
@@ -1243,6 +1245,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
12431245
assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/"
12441246
assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints"
12451247
assert actual_train_args["Environment"] == ENV_INPUT
1248+
assert actual_train_args["RetryStrategy"] == RETRY_STRATEGY
12461249

12471250

12481251
def test_transform_pack_to_request(sagemaker_session):

tests/unit/test_sklearn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def _create_train_job(version):
129129
"sagemaker_region": '"us-west-2"',
130130
},
131131
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
132+
"retry_strategy": None,
132133
"metric_definitions": None,
133134
"tags": None,
134135
"vpc_config": None,

tests/unit/test_xgboost.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def _create_train_job(version, instance_count=1, instance_type="ml.c4.4xlarge"):
142142
"sagemaker_region": '"us-west-2"',
143143
},
144144
"stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
145+
"retry_strategy": None,
145146
"metric_definitions": None,
146147
"tags": None,
147148
"vpc_config": None,

0 commit comments

Comments
 (0)