Commit d351a42

breaking: default wait=True for HyperparameterTuner.fit() and Transformer.transform()
Parent commit: abd873e

3 files changed: 21 additions, 45 deletions

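Callers that depended on the previous non-blocking defaults can keep that behavior by opting out explicitly. A minimal migration sketch, assuming an already-configured `tuner` and `transformer`; the input values below are placeholders, not part of this commit:

    # Hypothetical migration for existing callers: restore the old
    # fire-and-forget behavior by passing wait=False explicitly.
    tuner.fit({"train": train_input}, wait=False)
    transformer.transform("s3://my-bucket/batch-input", wait=False, logs=False)

    # Block later, once the results are actually needed.
    tuner.wait()
    transformer.wait()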

src/sagemaker/transformer.py (4 additions, 4 deletions)

@@ -121,8 +121,8 @@ def transform(
        join_source=None,
        experiment_config=None,
        model_client_config=None,
-       wait=False,
-       logs=False,
+       wait=True,
+       logs=True,
    ):
        """Start a new transform job.

@@ -178,9 +178,9 @@ def transform(
                'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
                (default: ``None``).
            wait (bool): Whether the call should wait until the job completes
-               (default: False).
+               (default: ``True``).
            logs (bool): Whether to show the logs produced by the job.
-               Only meaningful when wait is True (default: False).
+               Only meaningful when wait is ``True`` (default: ``True``).
        """
        local_mode = self.sagemaker_session.local_mode
        if not local_mode and not data.startswith("s3://"):
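With the new defaults, a transform call like the sketch below blocks until the batch transform job finishes and shows its logs. The model name, instance type, and S3 URIs are illustrative placeholders, not values from this commit.

    from sagemaker.transformer import Transformer

    # Placeholder model name and S3 URIs, for illustration only.
    transformer = Transformer(
        model_name="my-model",
        instance_count=1,
        instance_type="ml.m5.xlarge",
        output_path="s3://my-bucket/batch-output",
    )

    # As of this commit, transform() waits for completion and shows logs by default;
    # pass wait=False (and logs=False) to get the previous non-blocking behavior.
    transformer.transform("s3://my-bucket/batch-input", content_type="text/csv")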

src/sagemaker/tuner.py (5 additions, 0 deletions)

@@ -369,6 +369,7 @@ def fit(
        job_name=None,
        include_cls_metadata=False,
        estimator_kwargs=None,
+       wait=True,
        **kwargs
    ):
        """Start a hyperparameter tuning job.

@@ -424,6 +425,7 @@ def fit(
                The keys are the estimator names for the estimator_dict argument of create()
                method. Each value is a dictionary for the other arguments needed for training
                of the corresponding estimator.
+           wait (bool): Whether the call should wait until the job completes (default: ``True``).
            **kwargs: Other arguments needed for training. Please refer to the
                ``fit()`` method of the associated estimator to see what other
                arguments are needed.

@@ -433,6 +435,9 @@ def fit(
        else:
            self._fit_with_estimator_dict(inputs, job_name, include_cls_metadata, estimator_kwargs)

+       if wait:
+           self.latest_tuning_job.wait()
+
    def _fit_with_estimator(self, inputs, job_name, include_cls_metadata, **kwargs):
        """Start tuning for tuner instances that have the ``estimator`` field set"""
        self._prepare_estimator_for_tuning(self.estimator, inputs, job_name, **kwargs)
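Because `fit()` now waits on `self.latest_tuning_job` by default, results can be read as soon as it returns. A short sketch, assuming a configured `estimator` and `train_input`; the metric name and hyperparameter ranges are placeholders:

    from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name="validation:accuracy",
        hyperparameter_ranges={"learning_rate": ContinuousParameter(0.01, 0.2)},
        max_jobs=4,
        max_parallel_jobs=2,
    )

    # Blocks until the tuning job completes (the new default), so the best
    # training job can be queried right away; pass wait=False to opt out.
    tuner.fit({"train": train_input})
    best = tuner.best_training_job()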

tests/integ/test_tuner.py (12 additions, 41 deletions)

@@ -130,7 +130,7 @@ def _tune(
    hyperparameter_ranges=None,
    job_name=None,
    warm_start_config=None,
-   wait_till_terminal=True,
+   wait=True,
    max_jobs=2,
    max_parallel_jobs=2,
    early_stopping_type="Off",

@@ -155,7 +155,7 @@ def _tune(
        tuner.fit([records, test_record_set], job_name=job_name)
        print("Started hyperparameter tuning job with name:" + tuner.latest_tuning_job.name)

-   if wait_till_terminal:
+   if wait:
        tuner.wait()

    return tuner

@@ -388,7 +388,7 @@ def test_tuning_kmeans_identical_dataset_algorithm_tuner_from_non_terminal_paren
        kmeans_train_set,
        job_name=parent_tuning_job_name,
        hyperparameter_ranges=hyperparameter_ranges,
-       wait_till_terminal=False,
+       wait=False,
        max_parallel_jobs=1,
        max_jobs=1,
    )

@@ -453,15 +453,9 @@ def test_tuning_lda(sagemaker_session, cpu_instance_type):
    )

    tuning_job_name = unique_name_from_base("test-lda", max_length=32)
+   print("Started hyperparameter tuning job with name:" + tuning_job_name)
    tuner.fit([record_set, test_record_set], mini_batch_size=1, job_name=tuning_job_name)

-   latest_tuning_job_name = tuner.latest_tuning_job.name
-
-   print("Started hyperparameter tuning job with name:" + latest_tuning_job_name)
-
-   time.sleep(15)
-   tuner.wait()
-
    attached_tuner = HyperparameterTuner.attach(
        tuning_job_name, sagemaker_session=sagemaker_session
    )

@@ -575,12 +569,8 @@ def test_tuning_mxnet(
    )

    tuning_job_name = unique_name_from_base("tune-mxnet", max_length=32)
-   tuner.fit({"train": train_input, "test": test_input}, job_name=tuning_job_name)
-
    print("Started hyperparameter tuning job with name:" + tuning_job_name)
-
-   time.sleep(15)
-   tuner.wait()
+   tuner.fit({"train": train_input, "test": test_input}, job_name=tuning_job_name)

    best_training_job = tuner.best_training_job()
    with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):

@@ -628,12 +618,8 @@ def test_tuning_tf(
    )

    tuning_job_name = unique_name_from_base("tune-tf", max_length=32)
-   tuner.fit(inputs, job_name=tuning_job_name)
-
    print("Started hyperparameter tuning job with name: " + tuning_job_name)
-
-   time.sleep(15)
-   tuner.wait()
+   tuner.fit(inputs, job_name=tuning_job_name)


def test_tuning_tf_vpc_multi(

@@ -686,12 +672,8 @@ def test_tuning_tf_vpc_multi(
    )

    tuning_job_name = unique_name_from_base("tune-tf", max_length=32)
-   tuner.fit(inputs, job_name=tuning_job_name)
-
    print(f"Started hyperparameter tuning job with name: {tuning_job_name}")
-
-   time.sleep(15)
-   tuner.wait()
+   tuner.fit(inputs, job_name=tuning_job_name)


@pytest.mark.canary_quick

@@ -740,13 +722,9 @@ def test_tuning_chainer(
    )

    tuning_job_name = unique_name_from_base("chainer", max_length=32)
+   print("Started hyperparameter tuning job with name: {}".format(tuning_job_name))
    tuner.fit({"train": train_input, "test": test_input}, job_name=tuning_job_name)

-   print("Started hyperparameter tuning job with name:" + tuning_job_name)
-
-   time.sleep(15)
-   tuner.wait()
-
    best_training_job = tuner.best_training_job()
    with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
        predictor = tuner.deploy(1, cpu_instance_type)

@@ -812,13 +790,9 @@ def test_attach_tuning_pytorch(
    )

    tuning_job_name = unique_name_from_base("pytorch", max_length=32)
+   print("Started hyperparameter tuning job with name: {}".format(tuning_job_name))
    tuner.fit({"training": training_data}, job_name=tuning_job_name)

-   print("Started hyperparameter tuning job with name:" + tuning_job_name)
-
-   time.sleep(15)
-   tuner.wait()
-
    endpoint_name = tuning_job_name
    model_name = "model-name-1"
    attached_tuner = HyperparameterTuner.attach(

@@ -887,17 +861,14 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):
        max_parallel_jobs=2,
    )

+   tuning_job_name = unique_name_from_base("byo", 32)
+   print("Started hyperparameter tuning job with name {}:".format(tuning_job_name))
    tuner.fit(
        {"train": s3_train_data, "test": s3_train_data},
        include_cls_metadata=False,
-       job_name=unique_name_from_base("byo", 32),
+       job_name=tuning_job_name,
    )

-   print("Started hyperparameter tuning job with name:" + tuner.latest_tuning_job.name)
-
-   time.sleep(15)
-   tuner.wait()
-
    best_training_job = tuner.best_training_job()
    with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
        predictor = tuner.deploy(1, cpu_instance_type, endpoint_name=best_training_job)
