Skip to content

Commit f2209de

Browse files
Add stop tuning job (aws#29)
1 parent 6dc7905 commit f2209de

File tree

4 files changed

+91
-3
lines changed

4 files changed

+91
-3
lines changed

src/sagemaker/session.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,25 @@ def tune(self, job_name, strategy, objective, metric_name,
333333
LOGGER.debug('tune request: {}'.format(json.dumps(tune_request, indent=4)))
334334
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
335335

336+
def stop_tuning_job(self, name):
    """Ask Amazon SageMaker to stop the tuning job with the given name.

    A ``ClientError`` whose error code is ``ValidationException`` is taken to
    mean the job is already stopped or not running; it is logged at info
    level and swallowed. Any other ``ClientError`` is logged at error level
    and re-raised unchanged.

    Args:
        name: Name of Amazon SageMaker tuning job.

    Raises:
        ClientError: If the stop request fails with an error code other than
            ``ValidationException``.
    """
    try:
        LOGGER.info('Stopping tuning job: {}'.format(name))
        self.sagemaker_client.stop_hyper_parameter_tuning_job(HyperParameterTuningJobName=name)
    except ClientError as e:
        if e.response['Error']['Code'] != 'ValidationException':
            LOGGER.error('Error occurred while attempting to stop tuning job: {}. Please try again.'.format(name))
            raise
        # Allowed to pass: the job has already stopped, so there is nothing left to do.
        LOGGER.info('Tuning job: {} is already stopped or not running.'.format(name))
354+
336355
def create_model(self, name, role, primary_container):
337356
"""Create an Amazon SageMaker ``Model``.
338357

src/sagemaker/tuner.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,28 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
7676
self.latest_tuning_job = None
7777

7878
def fit(self, inputs):
    """Start a new tuning job and remember it as ``latest_tuning_job``.

    Args:
        inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
    """
    started_job = _TuningJob.start_new(self, inputs)
    self.latest_tuning_job = started_job
8585

86+
def stop_tuning_job(self):
    """Stop the tuning job held in ``latest_tuning_job``.

    ``_ensure_last_tuning_job`` is called first, so the stop is only
    attempted when a tuning job is actually attached to this tuner.
    """
    self._ensure_last_tuning_job()
    current_job = self.latest_tuning_job
    current_job.stop()
91+
92+
def _ensure_last_tuning_job(self):
    """Verify that this tuner has a tuning job to operate on.

    Raises:
        ValueError: If ``latest_tuning_job`` is missing or ``None``.
    """
    # getattr with a default covers both "attribute never set" and "set to None"
    # in one lookup, without building the full dir() listing of the instance.
    if getattr(self, 'latest_tuning_job', None) is None:
        raise ValueError('No tuning job available')
95+
8696
def hyperparameter_ranges(self):
8797
"""Return collections of ``ParameterRanges``
8898
8999
Returns:
90-
dict: ParameterRanges suitable for HPO tuning job.
100+
dict: ParameterRanges suitable for tuning job.
91101
"""
92102
hyperparameter_ranges = dict()
93103
for range_type in _ParameterRange.__all_types__:
@@ -108,7 +118,7 @@ def __init__(self, sagemaker_session, tuning_job_name):
108118

109119
@classmethod
110120
def start_new(cls, tuner, inputs):
111-
"""Create a new Amazon SageMaker HPO tuning job from the HyperparameterTuner.
121+
"""Create a new Amazon SageMaker tuning job from the HyperparameterTuner.
112122
113123
Args:
114124
tuner (sagemaker.tuner.HyperparameterTuner): Tuner object created by the user.
@@ -144,5 +154,8 @@ def start_new(cls, tuner, inputs):
144154

145155
return cls(tuner.estimator.sagemaker_session, tuning_job_name)
146156

157+
def stop(self):
    """Stop this tuning job through its SageMaker session.

    ``Session.stop_tuning_job`` declares its parameter as ``name``; the
    ``HyperParameterTuningJobName`` keyword belongs to the underlying client
    call, not to the session method, and passing it here would raise
    ``TypeError`` against a real session.
    """
    self.sagemaker_session.stop_tuning_job(name=self.name)
159+
147160
def wait(self):
148161
pass

tests/unit/test_session.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,40 @@ def test_train_pack_to_request(sagemaker_session):
230230
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
231231

232232

233+
def test_stop_tuning_job(sagemaker_session):
    """A stop request is forwarded exactly once to the client with the job name."""
    stop_mock = Mock(name='stop_hyper_parameter_tuning_job')
    sagemaker_session.sagemaker_client.stop_hyper_parameter_tuning_job = stop_mock

    sagemaker_session.stop_tuning_job(JOB_NAME)

    stop_mock.assert_called_once_with(HyperParameterTuningJobName=JOB_NAME)
239+
240+
241+
def test_stop_tuning_job_client_error_already_stopped(sagemaker_session):
    """A ``ValidationException`` from the client is swallowed rather than raised."""
    already_stopped = ClientError({'Error': {'Code': 'ValidationException'}}, 'Operation')
    stop_mock = Mock(name='stop_hyper_parameter_tuning_job', side_effect=already_stopped)
    sagemaker_session.sagemaker_client.stop_hyper_parameter_tuning_job = stop_mock

    # Must return normally even though the client call raised.
    sagemaker_session.stop_tuning_job(JOB_NAME)

    stop_mock.assert_called_once_with(HyperParameterTuningJobName=JOB_NAME)
249+
250+
251+
def test_stop_tuning_job_client_error(sagemaker_session):
    """A ``ClientError`` with any other error code is re-raised to the caller."""
    error_response = {'Error': {'Code': 'MockException', 'Message': 'MockMessage'}}
    operation = 'Operation'
    exception = ClientError(error_response, operation)

    sms = sagemaker_session
    sms.sagemaker_client.stop_hyper_parameter_tuning_job = Mock(name='stop_hyper_parameter_tuning_job',
                                                                side_effect=exception)

    with pytest.raises(ClientError) as e:
        sagemaker_session.stop_tuning_job(JOB_NAME)

    sms.sagemaker_client.stop_hyper_parameter_tuning_job.assert_called_once_with(HyperParameterTuningJobName=JOB_NAME)
    # ``e`` is pytest's ExceptionInfo wrapper; the raised exception (and its message) is ``e.value``.
    assert 'An error occurred (MockException) when calling the Operation operation: MockMessage' in str(e.value)
265+
266+
233267
@patch('sys.stdout', new_callable=io.BytesIO if six.PY2 else io.StringIO)
234268
def test_color_wrap(bio):
235269
color_wrap = sagemaker.logs.ColorWrap()

tests/unit/test_tuner.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,21 @@ def test_serialize_nonexistent_parameter_ranges(tuner):
102102
assert not ranges[parameter_type + 'ParameterRanges']
103103

104104

105+
def test_stop_tuning_job(sagemaker_session, tuner):
    """Stopping through the tuner delegates to the session with the latest job's name."""
    sagemaker_session.stop_tuning_job = Mock(name='stop_hyper_parameter_tuning_job')
    tuner.latest_tuning_job = _TuningJob(sagemaker_session, JOB_NAME)

    tuner.stop_tuning_job()

    # Session.stop_tuning_job is declared as (self, name); HyperParameterTuningJobName is the
    # client-level keyword and is not accepted by the real session method.
    sagemaker_session.stop_tuning_job.assert_called_once_with(name=JOB_NAME)
112+
113+
114+
def test_stop_tuning_job_no_tuning_job(tuner):
    """Stopping a tuner that has no tuning job raises ``ValueError``."""
    with pytest.raises(ValueError) as e:
        tuner.stop_tuning_job()

    # Checked outside the ``with`` block so the assertion actually runs, and against
    # ``e.value`` (the raised exception) rather than pytest's ExceptionInfo wrapper.
    assert 'No tuning job available' in str(e.value)
118+
119+
105120
#################################################################################
106121
# _ParameterRange Tests
107122

@@ -171,3 +186,10 @@ def test_start_new(tuner, sagemaker_session):
171186

172187
assert started_tuning_job.sagemaker_session == sagemaker_session
173188
sagemaker_session.tune.assert_called_once()
189+
190+
191+
def test_stop(sagemaker_session):
    """``_TuningJob.stop`` asks the session to stop the job by its name."""
    tuning_job = _TuningJob(sagemaker_session, JOB_NAME)
    tuning_job.stop()

    # Session.stop_tuning_job is declared as (self, name); HyperParameterTuningJobName is the
    # client-level keyword and is not accepted by the real session method.
    sagemaker_session.stop_tuning_job.assert_called_once_with(name=JOB_NAME)

0 commit comments

Comments
 (0)