Skip to content

Commit aa14097

Browse files
Add wait for tuner (aws#37)
1 parent abbdbd5 commit aa14097

File tree

4 files changed

+79
-6
lines changed

4 files changed

+79
-6
lines changed

src/sagemaker/session.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -479,21 +479,39 @@ def wait_for_job(self, job, poll=5):
479479
ValueError: If the training job fails.
480480
"""
481481
desc = _wait_until(lambda: _train_done(self.sagemaker_client, job), poll)
482-
self._check_job_status(job, desc)
482+
self._check_job_status(job, desc, 'TrainingJobStatus')
483483
return desc
484484

485-
def _check_job_status(self, job, desc):
485+
def wait_for_tuning_job(self, job, poll=5):
486+
"""Wait for an Amazon SageMaker tuning job to complete.
487+
488+
Args:
489+
job (str): Name of the tuning job to wait for.
490+
poll (int): Polling interval in seconds (default: 5).
491+
492+
Returns:
493+
(dict): Return value from the ``DescribeHyperParameterTuningJob`` API.
494+
495+
Raises:
496+
ValueError: If the hyperparameter tuning job fails.
497+
"""
498+
desc = _wait_until(lambda: _tuning_job_status(self.sagemaker_client, job), poll)
499+
self._check_job_status(job, desc, 'HyperParameterTuningJobStatus')
500+
return desc
501+
502+
def _check_job_status(self, job, desc, status_key_name):
486503
"""Check to see if the job completed successfully and, if not, construct and
487504
raise a ValueError.
488505
489506
Args:
490507
job (str): The name of the job to check.
491508
desc (dict[str, str]): The result of the job's describe call, e.g. ``describe_training_job()`` or ``describe_hyper_parameter_tuning_job()``.
509+
status_key_name (str): Status key name to check for.
492510
493511
Raises:
494512
ValueError: If the job's status is neither ``Completed`` nor ``Stopped``.
495513
"""
496-
status = desc['TrainingJobStatus']
514+
status = desc[status_key_name]
497515

498516
if status != 'Completed' and status != 'Stopped':
499517
reason = desc.get('FailureReason', '(No reason provided)')
@@ -757,7 +775,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
757775
state = LogState.JOB_COMPLETE
758776

759777
if wait:
760-
self._check_job_status(job_name, description)
778+
self._check_job_status(job_name, description, 'TrainingJobStatus')
761779
if dot:
762780
print()
763781
print('===== Job Complete =====')
@@ -909,7 +927,7 @@ def _train_done(sagemaker_client, job_name):
909927
return desc
910928

911929

912-
def _tune_done(sagemaker_client, job_name):
930+
def _tuning_job_status(sagemaker_client, job_name):
913931
tuning_status_codes = {
914932
'Completed': '!',
915933
'InProgress': '.',

src/sagemaker/tuner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ def stop_tuning_job(self):
153153
self._ensure_last_tuning_job()
154154
self.latest_tuning_job.stop()
155155

156+
def wait(self):
157+
"""Wait for latest tuning job to finish.
158+
"""
159+
self._ensure_last_tuning_job()
160+
self.latest_tuning_job.wait()
161+
156162
def best_training_job(self):
157163
"""Return name of the best training job for the latest tuning job.
158164
"""
@@ -258,4 +264,4 @@ def stop(self):
258264
self.sagemaker_session.stop_tuning_job(HyperParameterTuningJobName=self.name)
259265

260266
def wait(self):
261-
pass
267+
self.sagemaker_session.wait_for_tuning_job(self.name)

tests/unit/test_session.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
from botocore.exceptions import ClientError
2424

25+
from sagemaker.session import _tuning_job_status
26+
2527
REGION = 'us-west-2'
2628

2729

@@ -451,3 +453,32 @@ def test_endpoint_from_production_variants(sagemaker_session):
451453
'InitialVariantWeight': 1,
452454
'InitialInstanceCount': 1,
453455
'VariantName': 'AllTraffic'}])
456+
457+
458+
def test_wait_for_tuning_job(sagemaker_session):
459+
hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'Completed'}
460+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
461+
name='describe_hyper_parameter_tuning_job', return_value=hyperparameter_tuning_job_desc)
462+
463+
result = sagemaker_session.wait_for_tuning_job(JOB_NAME)
464+
assert result['HyperParameterTuningJobStatus'] == 'Completed'
465+
466+
467+
def test_tune_job_status(sagemaker_session):
468+
hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'Completed'}
469+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
470+
name='describe_hyper_parameter_tuning_job', return_value=hyperparameter_tuning_job_desc)
471+
472+
result = _tuning_job_status(sagemaker_session.sagemaker_client, JOB_NAME)
473+
474+
assert result['HyperParameterTuningJobStatus'] == 'Completed'
475+
476+
477+
def test_tune_job_status_none(sagemaker_session):
478+
hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'InProgress'}
479+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
480+
name='describe_hyper_parameter_tuning_job', return_value=hyperparameter_tuning_job_desc)
481+
482+
result = _tuning_job_status(sagemaker_session.sagemaker_client, JOB_NAME)
483+
484+
assert result is None

tests/unit/test_tuner.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,15 @@ def test_deploy_default(tuner):
275275
assert predictor.sagemaker_session == tuner.estimator.sagemaker_session
276276

277277

278+
def test_wait(tuner):
279+
tuner.latest_tuning_job = _TuningJob(tuner.estimator.sagemaker_session, JOB_NAME)
280+
tuner.estimator.sagemaker_session.wait_for_tuning_job = Mock(name='wait_for_tuning_job')
281+
282+
tuner.wait()
283+
284+
tuner.estimator.sagemaker_session.wait_for_tuning_job.assert_called_once_with(JOB_NAME)
285+
286+
278287
#################################################################################
279288
# _ParameterRange Tests
280289

@@ -353,3 +362,12 @@ def test_stop(sagemaker_session):
353362
tuning_job.stop()
354363

355364
sagemaker_session.stop_tuning_job.assert_called_once_with(HyperParameterTuningJobName=JOB_NAME)
365+
366+
367+
def test_tuning_job_wait(sagemaker_session):
368+
sagemaker_session.wait_for_tuning_job = Mock(name='wait_for_tuning_job')
369+
370+
tuning_job = _TuningJob(sagemaker_session, JOB_NAME)
371+
tuning_job.wait()
372+
373+
sagemaker_session.wait_for_tuning_job.assert_called_once_with(JOB_NAME)

0 commit comments

Comments
 (0)