
Commit 3c96986

breaking: default wait=True for HyperparameterTuner.fit() and Transformer.transform() (#1790)
1 parent 4fb245a commit 3c96986
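
This is a breaking change: both calls now block until the job completes. A minimal before/after sketch, assuming an already-constructed tuner and transformer (setup omitted; the names here are placeholders, not part of this commit):

    # New default: both calls block until the job finishes.
    tuner.fit(inputs)            # waits for the hyperparameter tuning job
    transformer.transform(data)  # waits for the batch transform job

    # The previous non-blocking behavior must now be requested explicitly:
    tuner.fit(inputs, wait=False)
    transformer.transform(data, wait=False)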

File tree

4 files changed: +24 -54 lines


src/sagemaker/transformer.py

Lines changed: 4 additions & 4 deletions

@@ -121,8 +121,8 @@ def transform(
         join_source=None,
         experiment_config=None,
         model_client_config=None,
-        wait=False,
-        logs=False,
+        wait=True,
+        logs=True,
     ):
         """Start a new transform job.
@@ -178,9 +178,9 @@ def transform(
                 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
                 (default: ``None``).
             wait (bool): Whether the call should wait until the job completes
-                (default: False).
+                (default: ``True``).
             logs (bool): Whether to show the logs produced by the job.
-                Only meaningful when wait is True (default: False).
+                Only meaningful when wait is ``True`` (default: ``True``).
         """
         local_mode = self.sagemaker_session.local_mode
         if not local_mode and not data.startswith("s3://"):
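
Since logs is only consulted when wait is True, a caller who wants the new blocking behavior without log streaming can override just that flag. A hypothetical sketch (the S3 URI is a placeholder; outside local mode the input must start with s3://, as the check above shows):

    transformer.transform(
        "s3://example-bucket/batch-input",  # placeholder input location
        content_type="text/csv",
        wait=True,    # block until the transform job completes (new default)
        logs=False,   # wait quietly, without streaming job logs
    )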

src/sagemaker/tuner.py

Lines changed: 5 additions & 3 deletions

@@ -369,6 +369,7 @@ def fit(
         job_name=None,
         include_cls_metadata=False,
         estimator_kwargs=None,
+        wait=True,
         **kwargs
     ):
         """Start a hyperparameter tuning job.
@@ -424,6 +425,7 @@ def fit(
                 The keys are the estimator names for the estimator_dict argument of create()
                 method. Each value is a dictionary for the other arguments needed for training
                 of the corresponding estimator.
+            wait (bool): Whether the call should wait until the job completes (default: ``True``).
             **kwargs: Other arguments needed for training. Please refer to the
                 ``fit()`` method of the associated estimator to see what other
                 arguments are needed.
@@ -433,6 +435,9 @@ def fit(
         else:
             self._fit_with_estimator_dict(inputs, job_name, include_cls_metadata, estimator_kwargs)

+        if wait:
+            self.latest_tuning_job.wait()
+
     def _fit_with_estimator(self, inputs, job_name, include_cls_metadata, **kwargs):
         """Start tuning for tuner instances that have the ``estimator`` field set"""
         self._prepare_estimator_for_tuning(self.estimator, inputs, job_name, **kwargs)
@@ -1447,9 +1452,6 @@ def start_new(cls, tuner, inputs):
             sagemaker.tuner._TuningJob: Constructed object that captures all
             information about the started job.
         """
-
-        logger.info("_TuningJob.start_new!!!")
-
         warm_start_config_req = None
         if tuner.warm_start_config:
             warm_start_config_req = tuner.warm_start_config.to_input_req()
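
With this change, fit() itself calls self.latest_tuning_job.wait() before returning. Callers that pass wait=False can still block later via tuner.wait(), as the updated integration tests do; a minimal sketch:

    tuner.fit(inputs, job_name=job_name, wait=False)  # start without blocking
    # ... do other work while the tuning job runs ...
    tuner.wait()  # block until the tuning job reaches a terminal state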

tests/integ/test_transformer.py

Lines changed: 1 addition & 1 deletion

@@ -325,7 +325,7 @@ def test_single_transformer_multiple_jobs(

 def test_stop_transform_job(mxnet_estimator, mxnet_transform_input, cpu_instance_type):
     transformer = mxnet_estimator.transformer(1, cpu_instance_type)
-    transformer.transform(mxnet_transform_input, content_type="text/csv")
+    transformer.transform(mxnet_transform_input, content_type="text/csv", wait=False)

     time.sleep(15)

tests/integ/test_tuner.py

Lines changed: 14 additions & 46 deletions

@@ -130,7 +130,7 @@ def _tune(
     hyperparameter_ranges=None,
     job_name=None,
     warm_start_config=None,
-    wait_till_terminal=True,
+    wait=True,
     max_jobs=2,
     max_parallel_jobs=2,
     early_stopping_type="Off",
@@ -152,11 +152,8 @@ def _tune(
         records = kmeans_estimator.record_set(kmeans_train_set[0][:100])
         test_record_set = kmeans_estimator.record_set(kmeans_train_set[0][:100], channel="test")

-        tuner.fit([records, test_record_set], job_name=job_name)
-        print("Started hyperparameter tuning job with name:" + tuner.latest_tuning_job.name)
-
-        if wait_till_terminal:
-            tuner.wait()
+        print("Started hyperparameter tuning job with name: {}".format(job_name))
+        tuner.fit([records, test_record_set], job_name=job_name, wait=wait)

     return tuner

@@ -388,7 +385,7 @@ def test_tuning_kmeans_identical_dataset_algorithm_tuner_from_non_terminal_paren
         kmeans_train_set,
         job_name=parent_tuning_job_name,
         hyperparameter_ranges=hyperparameter_ranges,
-        wait_till_terminal=False,
+        wait=False,
         max_parallel_jobs=1,
         max_jobs=1,
     )
@@ -453,15 +450,9 @@ def test_tuning_lda(sagemaker_session, cpu_instance_type):
     )

     tuning_job_name = unique_name_from_base("test-lda", max_length=32)
+    print("Started hyperparameter tuning job with name:" + tuning_job_name)
     tuner.fit([record_set, test_record_set], mini_batch_size=1, job_name=tuning_job_name)

-    latest_tuning_job_name = tuner.latest_tuning_job.name
-
-    print("Started hyperparameter tuning job with name:" + latest_tuning_job_name)
-
-    time.sleep(15)
-    tuner.wait()
-
     attached_tuner = HyperparameterTuner.attach(
         tuning_job_name, sagemaker_session=sagemaker_session
     )
@@ -516,7 +507,7 @@ def test_stop_tuning_job(sagemaker_session, cpu_instance_type):
     )

     tuning_job_name = unique_name_from_base("test-randomcutforest", max_length=32)
-    tuner.fit([records, test_records], tuning_job_name)
+    tuner.fit([records, test_records], tuning_job_name, wait=False)

     time.sleep(15)

@@ -575,12 +566,8 @@ def test_tuning_mxnet(
     )

     tuning_job_name = unique_name_from_base("tune-mxnet", max_length=32)
-    tuner.fit({"train": train_input, "test": test_input}, job_name=tuning_job_name)
-
     print("Started hyperparameter tuning job with name:" + tuning_job_name)
-
-    time.sleep(15)
-    tuner.wait()
+    tuner.fit({"train": train_input, "test": test_input}, job_name=tuning_job_name)

     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
@@ -628,12 +615,8 @@ def test_tuning_tf(
     )

     tuning_job_name = unique_name_from_base("tune-tf", max_length=32)
-    tuner.fit(inputs, job_name=tuning_job_name)
-
     print("Started hyperparameter tuning job with name: " + tuning_job_name)
-
-    time.sleep(15)
-    tuner.wait()
+    tuner.fit(inputs, job_name=tuning_job_name)


 def test_tuning_tf_vpc_multi(
@@ -686,12 +669,8 @@ def test_tuning_tf_vpc_multi(
     )

     tuning_job_name = unique_name_from_base("tune-tf", max_length=32)
-    tuner.fit(inputs, job_name=tuning_job_name)
-
     print(f"Started hyperparameter tuning job with name: {tuning_job_name}")
-
-    time.sleep(15)
-    tuner.wait()
+    tuner.fit(inputs, job_name=tuning_job_name)


 @pytest.mark.canary_quick
@@ -740,13 +719,9 @@ def test_tuning_chainer(
     )

     tuning_job_name = unique_name_from_base("chainer", max_length=32)
+    print("Started hyperparameter tuning job with name: {}".format(tuning_job_name))
     tuner.fit({"train": train_input, "test": test_input}, job_name=tuning_job_name)

-    print("Started hyperparameter tuning job with name:" + tuning_job_name)
-
-    time.sleep(15)
-    tuner.wait()
-
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
         predictor = tuner.deploy(1, cpu_instance_type)
@@ -812,13 +787,9 @@ def test_attach_tuning_pytorch(
     )

     tuning_job_name = unique_name_from_base("pytorch", max_length=32)
+    print("Started hyperparameter tuning job with name: {}".format(tuning_job_name))
     tuner.fit({"training": training_data}, job_name=tuning_job_name)

-    print("Started hyperparameter tuning job with name:" + tuning_job_name)
-
-    time.sleep(15)
-    tuner.wait()
-
     endpoint_name = tuning_job_name
     model_name = "model-name-1"
     attached_tuner = HyperparameterTuner.attach(
@@ -887,17 +858,14 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type):
         max_parallel_jobs=2,
     )

+    tuning_job_name = unique_name_from_base("byo", 32)
+    print("Started hyperparameter tuning job with name {}:".format(tuning_job_name))
     tuner.fit(
         {"train": s3_train_data, "test": s3_train_data},
         include_cls_metadata=False,
-        job_name=unique_name_from_base("byo", 32),
+        job_name=tuning_job_name,
     )

-    print("Started hyperparameter tuning job with name:" + tuner.latest_tuning_job.name)
-
-    time.sleep(15)
-    tuner.wait()
-
     best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
         predictor = tuner.deploy(
