@@ -185,6 +185,10 @@ def test_s3_input_all_arguments():
185
185
{'TrainingStartTime' : datetime .datetime (2018 , 2 , 17 , 7 , 15 , 0 , 103000 )})
186
186
COMPLETED_DESCRIBE_JOB_RESULT .update (
187
187
{'TrainingEndTime' : datetime .datetime (2018 , 2 , 17 , 7 , 19 , 34 , 953000 )})
188
+
189
+ STOPPED_DESCRIBE_JOB_RESULT = dict (COMPLETED_DESCRIBE_JOB_RESULT )
190
+ STOPPED_DESCRIBE_JOB_RESULT .update ({'TrainingJobStatus' : 'Stopped' })
191
+
188
192
IN_PROGRESS_DESCRIBE_JOB_RESULT = dict (DEFAULT_EXPECTED_TRAIN_JOB_ARGS )
189
193
IN_PROGRESS_DESCRIBE_JOB_RESULT .update ({'TrainingJobStatus' : 'InProgress' })
190
194
@@ -270,6 +274,16 @@ def sagemaker_session_complete():
270
274
return ims
271
275
272
276
277
+ @pytest .fixture ()
278
+ def sagemaker_session_stopped ():
279
+ boto_mock = Mock (name = 'boto_session' )
280
+ boto_mock .client ('logs' ).describe_log_streams .return_value = DEFAULT_LOG_STREAMS
281
+ boto_mock .client ('logs' ).get_log_events .side_effect = DEFAULT_LOG_EVENTS
282
+ ims = sagemaker .Session (boto_session = boto_mock , sagemaker_client = Mock ())
283
+ ims .sagemaker_client .describe_training_job .return_value = STOPPED_DESCRIBE_JOB_RESULT
284
+ return ims
285
+
286
+
273
287
@pytest .fixture ()
274
288
def sagemaker_session_ready_lifecycle ():
275
289
boto_mock = Mock (name = 'boto_session' )
@@ -302,6 +316,14 @@ def test_logs_for_job_no_wait(cw, sagemaker_session_complete):
302
316
cw ().assert_called_with (0 , 'hi there #1' )
303
317
304
318
319
+ @patch ('sagemaker.logs.ColorWrap' )
320
+ def test_logs_for_job_no_wait_stopped_job (cw , sagemaker_session_stopped ):
321
+ ims = sagemaker_session_stopped
322
+ ims .logs_for_job (JOB_NAME )
323
+ ims .sagemaker_client .describe_training_job .assert_called_once_with (TrainingJobName = JOB_NAME )
324
+ cw ().assert_called_with (0 , 'hi there #1' )
325
+
326
+
305
327
@patch ('sagemaker.logs.ColorWrap' )
306
328
def test_logs_for_job_wait_on_completed (cw , sagemaker_session_complete ):
307
329
ims = sagemaker_session_complete
@@ -310,6 +332,14 @@ def test_logs_for_job_wait_on_completed(cw, sagemaker_session_complete):
310
332
cw ().assert_called_with (0 , 'hi there #1' )
311
333
312
334
335
+ @patch ('sagemaker.logs.ColorWrap' )
336
+ def test_logs_for_job_wait_on_stopped (cw , sagemaker_session_stopped ):
337
+ ims = sagemaker_session_stopped
338
+ ims .logs_for_job (JOB_NAME , wait = True , poll = 0 )
339
+ assert ims .sagemaker_client .describe_training_job .call_args_list == [call (TrainingJobName = JOB_NAME ,)]
340
+ cw ().assert_called_with (0 , 'hi there #1' )
341
+
342
+
313
343
@patch ('sagemaker.logs.ColorWrap' )
314
344
def test_logs_for_job_no_wait_on_running (cw , sagemaker_session_ready_lifecycle ):
315
345
ims = sagemaker_session_ready_lifecycle
0 commit comments