
Commit d0c6764

ishaaq authored and knakad committed
Add support for Managed Spot Training and Checkpoint support (#990)
1 parent 3cf4f9b commit d0c6764

6 files changed (+207 -7 lines)


src/sagemaker/estimator.py

Lines changed: 85 additions & 0 deletions
@@ -82,6 +82,10 @@ def __init__(
         model_channel_name="model",
         metric_definitions=None,
         encrypt_inter_container_traffic=False,
+        train_use_spot_instances=False,
+        train_max_wait=None,
+        checkpoint_s3_uri=None,
+        checkpoint_local_path=None,
     ):
         """Initialize an ``EstimatorBase`` instance.
@@ -157,6 +161,28 @@ def __init__(
             encrypt_inter_container_traffic (bool): Specifies whether traffic
                 between training containers is encrypted for the training job
                 (default: ``False``).
+            train_use_spot_instances (bool): Specifies whether to use SageMaker
+                Managed Spot instances for training. If enabled, the
+                ``train_max_wait`` arg should also be set.
+
+                More information:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
+                (default: ``False``).
+            train_max_wait (int): Timeout in seconds waiting for Spot training
+                instances. After this amount of time Amazon SageMaker stops
+                waiting for Spot instances to become available
+                (default: ``None``).
+            checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
+                that the algorithm persists (if any) during training
+                (default: ``None``).
+            checkpoint_local_path (str): The local path that the algorithm
+                writes its checkpoints to. SageMaker persists all files
+                under this path to ``checkpoint_s3_uri`` continually during
+                training. On job startup the reverse happens: data from the
+                S3 location is downloaded to this path before the algorithm
+                starts. If the path is unset, SageMaker assumes the
+                checkpoints will be provided under ``/opt/ml/checkpoints/``
+                (default: ``None``).
         """
         self.role = role
         self.train_instance_count = train_instance_count
@@ -199,6 +225,10 @@ def __init__(
         self.security_group_ids = security_group_ids

         self.encrypt_inter_container_traffic = encrypt_inter_container_traffic
+        self.train_use_spot_instances = train_use_spot_instances
+        self.train_max_wait = train_max_wait
+        self.checkpoint_s3_uri = checkpoint_s3_uri
+        self.checkpoint_local_path = checkpoint_local_path

     @abstractmethod
     def train_image(self):
@@ -795,10 +825,35 @@ def start_new(cls, estimator, inputs):
         else:
             train_args["image"] = estimator.train_image()

+        cls._add_spot_checkpoint_args(local_mode, estimator, train_args)
+
         estimator.sagemaker_session.train(**train_args)

         return cls(estimator.sagemaker_session, estimator._current_job_name)

+    @classmethod
+    def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
+        """
+        Args:
+            local_mode:
+            estimator:
+            train_args:
+        """
+        if estimator.train_use_spot_instances:
+            if local_mode:
+                raise ValueError("Spot training is not supported in local mode.")
+            train_args["train_use_spot_instances"] = True
+
+        if estimator.checkpoint_s3_uri:
+            if local_mode:
+                raise ValueError("Setting checkpoint_s3_uri is not supported in local mode.")
+            train_args["checkpoint_s3_uri"] = estimator.checkpoint_s3_uri
+
+        if estimator.checkpoint_local_path:
+            if local_mode:
+                raise ValueError("Setting checkpoint_local_path is not supported in local mode.")
+            train_args["checkpoint_local_path"] = estimator.checkpoint_local_path
+
     @classmethod
     def _is_local_channel(cls, input_uri):
         """
@@ -845,6 +900,10 @@ def __init__(
         model_channel_name="model",
         metric_definitions=None,
         encrypt_inter_container_traffic=False,
+        train_use_spot_instances=False,
+        train_max_wait=None,
+        checkpoint_s3_uri=None,
+        checkpoint_local_path=None,
     ):
         """Initialize an ``Estimator`` instance.
@@ -926,6 +985,28 @@ def __init__(
             encrypt_inter_container_traffic (bool): Specifies whether traffic
                 between training containers is encrypted for the training job
                 (default: ``False``).
+            train_use_spot_instances (bool): Specifies whether to use SageMaker
+                Managed Spot instances for training. If enabled, the
+                ``train_max_wait`` arg should also be set.
+
+                More information:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html
+                (default: ``False``).
+            train_max_wait (int): Timeout in seconds waiting for Spot training
+                instances. After this amount of time Amazon SageMaker stops
+                waiting for Spot instances to become available
+                (default: ``None``).
+            checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
+                that the algorithm persists (if any) during training
+                (default: ``None``).
+            checkpoint_local_path (str): The local path that the algorithm
+                writes its checkpoints to. SageMaker persists all files
+                under this path to ``checkpoint_s3_uri`` continually during
+                training. On job startup the reverse happens: data from the
+                S3 location is downloaded to this path before the algorithm
+                starts. If the path is unset, SageMaker assumes the
+                checkpoints will be provided under ``/opt/ml/checkpoints/``
+                (default: ``None``).
         """
         self.image_name = image_name
         self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
@@ -948,6 +1029,10 @@ def __init__(
             model_channel_name=model_channel_name,
             metric_definitions=metric_definitions,
             encrypt_inter_container_traffic=encrypt_inter_container_traffic,
+            train_use_spot_instances=train_use_spot_instances,
+            train_max_wait=train_max_wait,
+            checkpoint_s3_uri=checkpoint_s3_uri,
+            checkpoint_local_path=checkpoint_local_path,
         )

     def train_image(self):
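
Taken together, these estimator changes add four constructor arguments. A minimal usage sketch follows; the image URI, role, instance settings, and bucket paths are placeholders, not values taken from this change:

    from sagemaker.estimator import Estimator

    # Placeholder values for illustration only.
    estimator = Estimator(
        image_name="<training-image-uri>",
        role="<execution-role-arn>",
        train_instance_count=1,
        train_instance_type="ml.m4.xlarge",
        train_max_run=3600,             # MaxRuntimeInSeconds
        train_use_spot_instances=True,  # enable Managed Spot Training
        train_max_wait=7200,            # stop waiting for Spot capacity after this many seconds
        checkpoint_s3_uri="s3://<bucket>/checkpoints/",
        checkpoint_local_path="/opt/ml/checkpoints/",  # the documented default location
    )
    # estimator.fit("s3://<bucket>/training-data/")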

src/sagemaker/job.py

Lines changed: 7 additions & 2 deletions
@@ -80,7 +80,9 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
             estimator.train_volume_size,
             estimator.train_volume_kms_key,
         )
-        stop_condition = _Job._prepare_stop_condition(estimator.train_max_run)
+        stop_condition = _Job._prepare_stop_condition(
+            estimator.train_max_run, estimator.train_max_wait
+        )
         vpc_config = estimator.get_vpc_config()

         model_channel = _Job._prepare_channel(
@@ -312,11 +314,14 @@ def _prepare_resource_config(instance_count, instance_type, volume_size, train_volume_kms_key):
         return resource_config

     @staticmethod
-    def _prepare_stop_condition(max_run):
+    def _prepare_stop_condition(max_run, max_wait):
         """
         Args:
             max_run:
+            max_wait:
         """
+        if max_wait:
+            return {"MaxRuntimeInSeconds": max_run, "MaxWaitTimeInSeconds": max_wait}
         return {"MaxRuntimeInSeconds": max_run}

     @property
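
To make the new stop-condition shape concrete, a quick sketch against the private helper changed above (arbitrary values, for illustration only):

    from sagemaker.job import _Job

    # With a wait time, as set when Managed Spot Training is used:
    _Job._prepare_stop_condition(3600, 7200)
    # -> {"MaxRuntimeInSeconds": 3600, "MaxWaitTimeInSeconds": 7200}

    # Without a wait time, behaviour is unchanged:
    _Job._prepare_stop_condition(3600, None)
    # -> {"MaxRuntimeInSeconds": 3600}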

src/sagemaker/session.py

Lines changed: 33 additions & 4 deletions
@@ -257,6 +257,9 @@ def train( # noqa: C901
         image=None,
         algorithm_arn=None,
         encrypt_inter_container_traffic=False,
+        train_use_spot_instances=False,
+        checkpoint_s3_uri=None,
+        checkpoint_local_path=None,
     ):
         """Create an Amazon SageMaker training job.
@@ -307,6 +310,18 @@ def train( # noqa: C901
             algorithm_arn (str): Algorithm Arn from Marketplace.
             encrypt_inter_container_traffic (bool): Specifies whether traffic between training
                 containers is encrypted for the training job (default: ``False``).
+            train_use_spot_instances (bool): Whether to use Spot instances for training
+                (default: ``False``).
+            checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
+                that the algorithm persists (if any) during training
+                (default: ``None``).
+            checkpoint_local_path (str): The local path that the algorithm
+                writes its checkpoints to. SageMaker persists all files
+                under this path to ``checkpoint_s3_uri`` continually during
+                training. On job startup the reverse happens: data from the
+                S3 location is downloaded to this path before the algorithm
+                starts. If the path is unset, SageMaker assumes the
+                checkpoints will be provided under ``/opt/ml/checkpoints/``
+                (default: ``None``).

         Returns:
             str: ARN of the training job, if it is created.
@@ -357,6 +372,15 @@ def train( # noqa: C901
         if encrypt_inter_container_traffic:
             train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic

+        if train_use_spot_instances:
+            train_request["EnableManagedSpotTraining"] = train_use_spot_instances
+
+        if checkpoint_s3_uri:
+            checkpoint_config = {"S3Uri": checkpoint_s3_uri}
+            if checkpoint_local_path:
+                checkpoint_config["LocalPath"] = checkpoint_local_path
+            train_request["CheckpointConfig"] = checkpoint_config
+
         LOGGER.info("Creating training-job with name: %s", job_name)
         LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
         self.sagemaker_client.create_training_job(**train_request)
@@ -1468,10 +1492,15 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
                print()
            # Customers are not billed for hardware provisioning, so billable time is less than
            # total time
-           billable_time = (
-               description["TrainingEndTime"] - description["TrainingStartTime"]
-           ) * instance_count
-           print("Billable seconds:", int(billable_time.total_seconds()) + 1)
+           training_time = description.get("TrainingTimeInSeconds")
+           billable_time = description.get("BillableTimeInSeconds")
+           if training_time is not None:
+               print("Training seconds:", training_time * instance_count)
+           if billable_time is not None:
+               print("Billable seconds:", billable_time * instance_count)
+               if description.get("EnableManagedSpotTraining"):
+                   saving = (1 - float(billable_time) / training_time) * 100
+                   print("Managed Spot Training savings: {:.1f}%".format(saving))


 def container_def(image, model_data_url=None, env=None):
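
For reference, the new arguments map onto the CreateTrainingJob request as sketched below, and the savings line printed by logs_for_job is derived from the two duration fields; all values here are made up for illustration:

    # Fields added to the create_training_job request when the new options are set
    # (placeholder values):
    train_request_additions = {
        "EnableManagedSpotTraining": True,
        "CheckpointConfig": {
            "S3Uri": "s3://mybucket/checkpoints/",
            "LocalPath": "/tmp/checkpoints",
        },
    }

    # Savings figure for a hypothetical spot job that reports
    # TrainingTimeInSeconds=2000 and BillableTimeInSeconds=500:
    saving = (1 - 500.0 / 2000) * 100
    print("Managed Spot Training savings: {:.1f}%".format(saving))  # prints 75.0%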

tests/unit/test_estimator.py

Lines changed: 63 additions & 0 deletions
@@ -227,6 +227,69 @@ def test_framework_all_init_args(sagemaker_session):
     }


+def test_framework_with_spot_and_checkpoints(sagemaker_session):
+    f = DummyFramework(
+        "my_script.py",
+        role="DummyRole",
+        train_instance_count=3,
+        train_instance_type="ml.m4.xlarge",
+        sagemaker_session=sagemaker_session,
+        train_volume_size=123,
+        train_volume_kms_key="volumekms",
+        train_max_run=456,
+        input_mode="inputmode",
+        output_path="outputpath",
+        output_kms_key="outputkms",
+        base_job_name="basejobname",
+        tags=[{"foo": "bar"}],
+        subnets=["123", "456"],
+        security_group_ids=["789", "012"],
+        metric_definitions=[{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
+        encrypt_inter_container_traffic=True,
+        train_use_spot_instances=True,
+        train_max_wait=500,
+        checkpoint_s3_uri="s3://mybucket/checkpoints/",
+        checkpoint_local_path="/tmp/checkpoints",
+    )
+    _TrainingJob.start_new(f, "s3://mydata")
+    sagemaker_session.train.assert_called_once()
+    _, args = sagemaker_session.train.call_args
+    assert args == {
+        "input_mode": "inputmode",
+        "tags": [{"foo": "bar"}],
+        "hyperparameters": {},
+        "image": "fakeimage",
+        "input_config": [
+            {
+                "ChannelName": "training",
+                "DataSource": {
+                    "S3DataSource": {
+                        "S3DataType": "S3Prefix",
+                        "S3DataDistributionType": "FullyReplicated",
+                        "S3Uri": "s3://mydata",
+                    }
+                },
+            }
+        ],
+        "output_config": {"KmsKeyId": "outputkms", "S3OutputPath": "outputpath"},
+        "vpc_config": {"Subnets": ["123", "456"], "SecurityGroupIds": ["789", "012"]},
+        "stop_condition": {"MaxRuntimeInSeconds": 456, "MaxWaitTimeInSeconds": 500},
+        "role": sagemaker_session.expand_role(),
+        "job_name": None,
+        "resource_config": {
+            "VolumeSizeInGB": 123,
+            "InstanceCount": 3,
+            "VolumeKmsKeyId": "volumekms",
+            "InstanceType": "ml.m4.xlarge",
+        },
+        "metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
+        "encrypt_inter_container_traffic": True,
+        "train_use_spot_instances": True,
+        "checkpoint_s3_uri": "s3://mybucket/checkpoints/",
+        "checkpoint_local_path": "/tmp/checkpoints",
+    }
+
+
 def test_framework_init_s3_entry_point_invalid(sagemaker_session):
     with pytest.raises(ValueError) as error:
         DummyFramework(

tests/unit/test_job.py

Lines changed: 13 additions & 1 deletion
@@ -563,10 +563,22 @@ def test_prepare_resource_config_with_volume_kms():

 def test_prepare_stop_condition():
     max_run = 1
+    max_wait = 2

-    stop_condition = _Job._prepare_stop_condition(max_run)
+    stop_condition = _Job._prepare_stop_condition(max_run, max_wait)

     assert stop_condition["MaxRuntimeInSeconds"] == max_run
+    assert stop_condition["MaxWaitTimeInSeconds"] == max_wait
+
+
+def test_prepare_stop_condition_no_wait():
+    max_run = 1
+    max_wait = None
+
+    stop_condition = _Job._prepare_stop_condition(max_run, max_wait)
+
+    assert stop_condition["MaxRuntimeInSeconds"] == max_run
+    assert "MaxWaitTimeInSeconds" not in stop_condition


 def test_name(sagemaker_session):

tests/unit/test_session.py

Lines changed: 6 additions & 0 deletions
@@ -651,6 +651,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
         tags=TAGS,
         metric_definitions=METRIC_DEFINITONS,
         encrypt_inter_container_traffic=True,
+        train_use_spot_instances=True,
+        checkpoint_s3_uri="s3://mybucket/checkpoints/",
+        checkpoint_local_path="/tmp/checkpoints",
     )

     _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
@@ -660,6 +663,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
     assert actual_train_args["Tags"] == TAGS
     assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS
     assert actual_train_args["EnableInterContainerTrafficEncryption"] is True
+    assert actual_train_args["EnableManagedSpotTraining"] is True
+    assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/"
+    assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints"


 def test_transform_pack_to_request(sagemaker_session):
