Skip to content

Fix: checking stopped status when waiting for logs #185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ CHANGELOG
* bug-fix: tensorflow-serving-api: SageMaker does not conflict with tensorflow-serving-api module version
* feature: Local Mode: add support for local training data using file://
* feature: Updated TensorFlow Serving api protobuf files
* bug-fix: No longer poll for logs from stopped training jobs

1.2.4
=====
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def _check_job_status(self, job, desc):
"""
status = desc['TrainingJobStatus']

if status != 'Completed':
if status != 'Completed' and status != 'Stopped':
reason = desc.get('FailureReason', '(No reason provided)')
raise ValueError('Error training {}: {} Reason: {}'.format(job, status, reason))

Expand Down Expand Up @@ -590,7 +590,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress
client = self.boto_session.client('logs', config=config)
log_group = '/aws/sagemaker/TrainingJobs'

job_already_completed = True if status == 'Completed' or status == 'Failed' else False
job_already_completed = True if status == 'Completed' or status == 'Failed' or status == 'Stopped' else False

state = LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
dot = False
Expand Down Expand Up @@ -662,7 +662,7 @@ def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress

status = description['TrainingJobStatus']

if status == 'Completed' or status == 'Failed':
if status == 'Completed' or status == 'Failed' or status == 'Stopped':
state = LogState.JOB_COMPLETE

if wait:
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ def test_s3_input_all_arguments():
{'TrainingStartTime': datetime.datetime(2018, 2, 17, 7, 15, 0, 103000)})
COMPLETED_DESCRIBE_JOB_RESULT.update(
{'TrainingEndTime': datetime.datetime(2018, 2, 17, 7, 19, 34, 953000)})

# Same payload as the completed-job description, but reporting status 'Stopped'.
STOPPED_DESCRIBE_JOB_RESULT = dict(COMPLETED_DESCRIBE_JOB_RESULT, TrainingJobStatus='Stopped')

# Copy of the default expected training-job arguments with status 'InProgress'.
IN_PROGRESS_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS, TrainingJobStatus='InProgress')

Expand Down Expand Up @@ -270,6 +274,16 @@ def sagemaker_session_complete():
return ims


@pytest.fixture()
def sagemaker_session_stopped():
    """Return a Session whose describe_training_job reports a 'Stopped' job.

    The mocked boto session's 'logs' client serves DEFAULT_LOG_STREAMS from
    describe_log_streams and DEFAULT_LOG_EVENTS (as a side effect) from
    get_log_events.
    """
    boto_session = Mock(name='boto_session')
    boto_session.client('logs').describe_log_streams.return_value = DEFAULT_LOG_STREAMS
    boto_session.client('logs').get_log_events.side_effect = DEFAULT_LOG_EVENTS

    session = sagemaker.Session(boto_session=boto_session, sagemaker_client=Mock())
    describe_job = session.sagemaker_client.describe_training_job
    describe_job.return_value = STOPPED_DESCRIBE_JOB_RESULT
    return session


@pytest.fixture()
def sagemaker_session_ready_lifecycle():
boto_mock = Mock(name='boto_session')
Expand Down Expand Up @@ -302,6 +316,14 @@ def test_logs_for_job_no_wait(cw, sagemaker_session_complete):
cw().assert_called_with(0, 'hi there #1')


@patch('sagemaker.logs.ColorWrap')
def test_logs_for_job_no_wait_stopped_job(cw, sagemaker_session_stopped):
    """logs_for_job without wait describes a stopped job exactly once and prints its logs."""
    session = sagemaker_session_stopped
    session.logs_for_job(JOB_NAME)

    describe_job = session.sagemaker_client.describe_training_job
    describe_job.assert_called_once_with(TrainingJobName=JOB_NAME)
    # The most recent call on the patched ColorWrap instance carries this log line.
    cw().assert_called_with(0, 'hi there #1')


@patch('sagemaker.logs.ColorWrap')
def test_logs_for_job_wait_on_completed(cw, sagemaker_session_complete):
ims = sagemaker_session_complete
Expand All @@ -310,6 +332,14 @@ def test_logs_for_job_wait_on_completed(cw, sagemaker_session_complete):
cw().assert_called_with(0, 'hi there #1')


@patch('sagemaker.logs.ColorWrap')
def test_logs_for_job_wait_on_stopped(cw, sagemaker_session_stopped):
    """With wait=True, an already stopped job is described only once (no further polling)."""
    session = sagemaker_session_stopped
    session.logs_for_job(JOB_NAME, wait=True, poll=0)

    expected_describe_calls = [call(TrainingJobName=JOB_NAME)]
    recorded_calls = session.sagemaker_client.describe_training_job.call_args_list
    assert recorded_calls == expected_describe_calls
    # The most recent call on the patched ColorWrap instance carries this log line.
    cw().assert_called_with(0, 'hi there #1')


@patch('sagemaker.logs.ColorWrap')
def test_logs_for_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle):
ims = sagemaker_session_ready_lifecycle
Expand Down