Skip to content

Commit ee80650

Browse files
authored
Merge pull request #10 from aws/mvs-poll-every-30-s
Call describe_training_job at most once every 30 seconds in logs_for_job
2 parents f8b90f2 + 6e98ec5 commit ee80650

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/sagemaker/session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def logs_for_job(self, job_name, wait=False, poll=5): # noqa: C901 - suppress c
605605
# Notes:
606606
# - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to Cloudwatch after
607607
# the job was marked complete.
608-
608+
last_describe_job_call = time.time()
609609
while True:
610610
if len(stream_names) < instance_count:
611611
# Log streams are created whenever a container starts writing to stdout/err, so this list
@@ -645,8 +645,10 @@ def logs_for_job(self, job_name, wait=False, poll=5): # noqa: C901 - suppress c
645645

646646
if state == LogState.JOB_COMPLETE:
647647
state = LogState.COMPLETE
648-
else:
648+
elif time.time() - last_describe_job_call >= 30:
649649
description = self.sagemaker_client.describe_training_job(TrainingJobName=job_name)
650+
last_describe_job_call = time.time()
651+
650652
status = description['TrainingJobStatus']
651653

652654
if status == 'Completed' or status == 'Failed':

tests/unit/test_session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,8 @@ def test_logs_for_job_no_wait_on_running(cw, sagemaker_session_ready_lifecycle):
311311

312312

313313
@patch('sagemaker.logs.ColorWrap')
314-
def test_logs_for_job_full_lifecycle(cw, sagemaker_session_full_lifecycle):
314+
@patch('time.time', side_effect=[0, 30, 60, 90, 120, 150, 180])
315+
def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle):
315316
ims = sagemaker_session_full_lifecycle
316317
ims.logs_for_job(JOB_NAME, wait=True, poll=0)
317318
assert ims.sagemaker_client.describe_training_job.call_args_list == [call(TrainingJobName=JOB_NAME,)] * 3

0 commit comments

Comments
 (0)