Skip to content

Commit c701100

Browse files
speg03 authored and laurenyu committed
fix: hyperparameter tuning with spot instances and checkpoints (#1015)
1 parent f81164a commit c701100

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

src/sagemaker/session.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def compile_model(
425425
LOGGER.info("Creating compilation-job with name: %s", job_name)
426426
self.sagemaker_client.create_compilation_job(**compilation_job_request)
427427

428-
def tune(
428+
def tune( # noqa: C901
429429
self,
430430
job_name,
431431
strategy,
@@ -450,6 +450,9 @@ def tune(
450450
early_stopping_type="Off",
451451
encrypt_inter_container_traffic=False,
452452
vpc_config=None,
453+
train_use_spot_instances=False,
454+
checkpoint_s3_uri=None,
455+
checkpoint_local_path=None,
453456
):
454457
"""Create an Amazon SageMaker hyperparameter tuning job
455458
@@ -512,6 +515,18 @@ def tune(
512515
The key in vpc_config is 'Subnets'.
513516
* security_group_ids (list[str]): List of security group ids.
514517
The key in vpc_config is 'SecurityGroupIds'.
518+
train_use_spot_instances (bool): whether to use spot instances for training.
519+
checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
520+
that the algorithm persists (if any) during training. (default:
521+
``None``).
522+
checkpoint_local_path (str): The local path that the algorithm
523+
writes its checkpoints to. SageMaker will persist all files
524+
under this path to `checkpoint_s3_uri` continually during
525+
training. On job startup the reverse happens - data from the
526+
s3 location is downloaded to this path before the algorithm is
527+
started. If the path is unset then SageMaker assumes the
528+
checkpoints will be provided under `/opt/ml/checkpoints/`.
529+
(default: ``None``).
515530
516531
"""
517532
tune_request = {
@@ -569,6 +584,15 @@ def tune(
569584
if encrypt_inter_container_traffic:
570585
tune_request["TrainingJobDefinition"]["EnableInterContainerTrafficEncryption"] = True
571586

587+
if train_use_spot_instances:
588+
tune_request["TrainingJobDefinition"]["EnableManagedSpotTraining"] = True
589+
590+
if checkpoint_s3_uri:
591+
checkpoint_config = {"S3Uri": checkpoint_s3_uri}
592+
if checkpoint_local_path:
593+
checkpoint_config["LocalPath"] = checkpoint_local_path
594+
tune_request["TrainingJobDefinition"]["CheckpointConfig"] = checkpoint_config
595+
572596
LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name)
573597
LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4))
574598
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)

src/sagemaker/tuner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,10 @@ def start_new(cls, tuner, inputs):
890890
"encrypt_inter_container_traffic"
891891
] = tuner.estimator.encrypt_inter_container_traffic
892892

893+
tuner_args["train_use_spot_instances"] = tuner.estimator.train_use_spot_instances
894+
tuner_args["checkpoint_s3_uri"] = tuner.estimator.checkpoint_s3_uri
895+
tuner_args["checkpoint_local_path"] = tuner.estimator.checkpoint_local_path
896+
893897
tuner.estimator.sagemaker_session.tune(**tuner_args)
894898

895899
return cls(tuner.sagemaker_session, tuner._current_job_name)

tests/unit/test_session.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,51 @@ def assert_create_tuning_job_request(**kwrags):
565565
)
566566

567567

568+
def test_tune_with_spot_and_checkpoints(sagemaker_session):
569+
def assert_create_tuning_job_request(**kwrags):
570+
assert (
571+
kwrags["HyperParameterTuningJobConfig"]
572+
== SAMPLE_TUNING_JOB_REQUEST["HyperParameterTuningJobConfig"]
573+
)
574+
assert kwrags["HyperParameterTuningJobName"] == "dummy-tuning-1"
575+
assert kwrags["TrainingJobDefinition"]["EnableManagedSpotTraining"] is True
576+
assert (
577+
kwrags["TrainingJobDefinition"]["CheckpointConfig"]["S3Uri"]
578+
== "s3://mybucket/checkpoints/"
579+
)
580+
assert (
581+
kwrags["TrainingJobDefinition"]["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints"
582+
)
583+
assert kwrags.get("WarmStartConfig", None) is None
584+
585+
sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = (
586+
assert_create_tuning_job_request
587+
)
588+
sagemaker_session.tune(
589+
job_name="dummy-tuning-1",
590+
strategy="Bayesian",
591+
objective_type="Maximize",
592+
objective_metric_name="val-score",
593+
max_jobs=100,
594+
max_parallel_jobs=5,
595+
parameter_ranges=SAMPLE_PARAM_RANGES,
596+
static_hyperparameters=STATIC_HPs,
597+
image="dummy-image-1",
598+
input_mode="File",
599+
metric_definitions=SAMPLE_METRIC_DEF,
600+
role=EXPANDED_ROLE,
601+
input_config=SAMPLE_INPUT,
602+
output_config=SAMPLE_OUTPUT,
603+
resource_config=RESOURCE_CONFIG,
604+
stop_condition=SAMPLE_STOPPING_CONDITION,
605+
tags=None,
606+
warm_start_config=None,
607+
train_use_spot_instances=True,
608+
checkpoint_s3_uri="s3://mybucket/checkpoints/",
609+
checkpoint_local_path="/tmp/checkpoints",
610+
)
611+
612+
568613
def test_stop_tuning_job(sagemaker_session):
569614
sms = sagemaker_session
570615
sms.sagemaker_client.stop_hyper_parameter_tuning_job = Mock(

0 commit comments

Comments (0)