Skip to content

Commit 0f3f3c8

Browse files
authored
Update hpo with master (aws#17)
1 parent 453b6a8 commit 0f3f3c8

File tree

5 files changed: 40 additions (+40) and 3 deletions (-3).

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ CHANGELOG
1111
* bug-fix: tensorflow-serving-api: SageMaker does not conflict with tensorflow-serving-api module version
1212
* feature: Local Mode: add support for local training data using file://
1313
* feature: Updated TensorFlow Serving api protobuf files
14+
* bug-fix: No longer poll for logs from stopped training jobs
1415

1516
1.2.4
1617
=====

src/sagemaker/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sagemaker.model import Model
2626
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
2727
CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME)
28+
2829
from sagemaker.predictor import RealTimePredictor
2930
from sagemaker.session import Session
3031
from sagemaker.session import s3_input

src/sagemaker/session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def _check_job_status(self, job, desc):
480480
"""
481481
status = desc['TrainingJobStatus']
482482

483-
if status != 'Completed':
483+
if status != 'Completed' and status != 'Stopped':
484484
reason = desc.get('FailureReason', '(No reason provided)')
485485
raise ValueError('Error training {}: {} Reason: {}'.format(job, status, reason))
486486

@@ -666,7 +666,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
666666
client = self.boto_session.client('logs', config=config)
667667
log_group = '/aws/sagemaker/TrainingJobs'
668668

669-
job_already_completed = True if status == 'Completed' or status == 'Failed' else False
669+
job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False
670670

671671
state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
672672
dot = False
@@ -738,7 +738,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
738738

739739
status = description['TrainingJobStatus']
740740

741-
if status == 'Completed' or status == 'Failed':
741+
if status == 'Completed' or status == 'Failed' or status == 'Stopped':
742742
state = LogState.JOB_COMPLETE
743743

744744
if wait:

tests/unit/test_estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,11 @@ def test_wait_with_logs(sagemaker_session):
402402
assert not sagemaker_session.wait_for_job.called
403403

404404

405+
def test_unsupported_type_in_dict():
406+
with pytest.raises(ValueError):
407+
_TrainingJob._format_inputs_to_input_config({'a': 66})
408+
409+
405410
#################################################################################
406411
# Tests for the generic Estimator class
407412

tests/unit/test_session.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ def test_s3_input_all_arguments():
185185
{'TrainingStartTime': datetime.datetime(2018, 2, 17, 7, 15, 0, 103000)})
186186
COMPLETED_DESCRIBE_JOB_RESULT.update(
187187
{'TrainingEndTime': datetime.datetime(2018, 2, 17, 7, 19, 34, 953000)})
188+
189+
STOPPED_DESCRIBE_JOB_RESULT = dict(COMPLETED_DESCRIBE_JOB_RESULT)
190+
STOPPED_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'Stopped'})
191+
188192
IN_PROGRESS_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
189193
IN_PROGRESS_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'InProgress'})
190194

@@ -270,6 +274,16 @@ def sagemaker_session_complete():
270274
return ims
271275

272276

277+
@pytest.fixture()
278+
def sagemaker_session_stopped():
279+
boto_mock = Mock(name='boto_session')
280+
boto_mock.client('logs').describe_log_streams.return_value = DEFAULT_LOG_STREAMS
281+
boto_mock.client('logs').get_log_events.side_effect = DEFAULT_LOG_EVENTS
282+
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=Mock())
283+
ims.sagemaker_client.describe_training_job.return_value = STOPPED_DESCRIBE_JOB_RESULT
284+
return ims
285+
286+
273287
@pytest.fixture()
274288
def sagemaker_session_ready_lifecycle():
275289
boto_mock = Mock(name='boto_session')
@@ -302,6 +316,14 @@ def test_logs_for_job_no_wait(cw, sagemaker_session_complete):
302316
cw().assert_called_with(0, 'hi there #1')
303317

304318

319+
@patch('sagemaker.logs.ColorWrap')
320+
def test_logs_for_job_no_wait_stopped_job(cw, sagemaker_session_stopped):
321+
ims = sagemaker_session_stopped
322+
ims.logs_for_job(JOB_NAME)
323+
ims.sagemaker_client.describe_training_job.assert_called_once_with(TrainingJobName=JOB_NAME)
324+
cw().assert_called_with(0, 'hi there #1')
325+
326+
305327
@patch('sagemaker.logs.ColorWrap')
306328
def test_logs_for_job_wait_on_completed(cw, sagemaker_session_complete):
307329
ims = sagemaker_session_complete
@@ -310,6 +332,14 @@ def test_logs_for_job_wait_on_completed(cw, sagemaker_session_complete):
310332
cw().assert_called_with(0, 'hi there #1')
311333

312334

335+
@patch('sagemaker.logs.ColorWrap')
336+
def test_logs_for_job_wait_on_stopped(cw, sagemaker_session_stopped):
337+
ims = sagemaker_session_stopped
338+
ims.logs_for_job(JOB_NAME, wait=True, poll=0)
339+
assert ims.sagemaker_client.describe_training_job.call_args_list == [call(TrainingJobName=JOB_NAME,)]
340+
cw().assert_called_with(0, 'hi there #1')
341+
342+
313343
@patch('sagemaker.logs.ColorWrap')
314344
def test_logs_for_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle):
315345
ims = sagemaker_session_ready_lifecycle

0 commit comments

Comments (0)