18
18
import pytest
19
19
from mock import patch , Mock , ANY , call
20
20
21
+ from botocore .exceptions import ClientError
21
22
from sagemaker .remote_function .client import remote , RemoteExecutor , Future
22
23
23
24
TRAINING_JOB_ARN = "training-job-arn"
@@ -47,6 +48,12 @@ def describe_training_job_response(job_status):
47
48
CANCELLED_TRAINING_JOB = describe_training_job_response ("Stopped" )
48
49
FAILED_TRAINING_JOB = describe_training_job_response ("Failed" )
49
50
51
# Sub-second interval settings (values are in seconds, per the key names).
# Tests in this module patch this dict over
# ``sagemaker.remote_function.client._API_CALL_LIMIT``.
API_CALL_LIMIT = {
    "SubmittingIntervalInSecs": 0.005,
    "MinBatchPollingIntervalInSecs": 0.01,
    "PollingIntervalInSecs": 0.01,
}
56
+
50
57
51
58
def job_function(a, b=1, *, c, d=3):
    """Return the product ``a * b * c * d``.

    ``b`` defaults to 1 and ``d`` defaults to 3; ``c`` and ``d`` are
    keyword-only. The tests in this module hand this callable to
    ``RemoteExecutor.submit``.
    """
    return a * b * c * d
@@ -171,7 +178,7 @@ def test_executor_submit_after_shutdown():
171
178
e .submit (job_function , 1 , 2 , c = 3 , d = 4 )
172
179
173
180
174
- @patch ("sagemaker.remote_function.client._POLLING_INTERVAL_IN_SECS " , new = 0.01 )
181
+ @patch ("sagemaker.remote_function.client._API_CALL_LIMIT " , new = API_CALL_LIMIT )
175
182
@patch ("sagemaker.remote_function.client._Job.start" )
176
183
def test_executor_submit_happy_case (mock_start ):
177
184
mock_job = Mock ()
@@ -194,8 +201,7 @@ def test_executor_submit_happy_case(mock_start):
194
201
mock_job .describe .assert_called ()
195
202
196
203
197
- @pytest .mark .skip ("This test hangs forever in py37" )
198
- @patch ("sagemaker.remote_function.client._POLLING_INTERVAL_IN_SECS" , new = 0.01 )
204
+ @patch ("sagemaker.remote_function.client._API_CALL_LIMIT" , new = API_CALL_LIMIT )
199
205
@patch ("sagemaker.remote_function.client._Job.start" )
200
206
def test_executor_submit_enforcing_max_parallel_jobs (mock_start ):
201
207
mock_job = Mock ()
@@ -220,8 +226,7 @@ def test_executor_submit_enforcing_max_parallel_jobs(mock_start):
220
226
assert future_2 .done ()
221
227
222
228
223
- @pytest .mark .skip ("This test fails in py37" )
224
- @patch ("sagemaker.remote_function.client._POLLING_INTERVAL_IN_SECS" , new = 0.01 )
229
+ @patch ("sagemaker.remote_function.client._API_CALL_LIMIT" , new = API_CALL_LIMIT )
225
230
@patch ("sagemaker.remote_function.client._Job.start" )
226
231
def test_executor_fails_to_start_job (mock_start ):
227
232
mock_job = Mock ()
@@ -239,8 +244,7 @@ def test_executor_fails_to_start_job(mock_start):
239
244
assert future_2 .done ()
240
245
241
246
242
- @pytest .mark .skip ("This test hangs forever in py37" )
243
- @patch ("sagemaker.remote_function.client._POLLING_INTERVAL_IN_SECS" , new = 0.01 )
247
+ @patch ("sagemaker.remote_function.client._API_CALL_LIMIT" , new = API_CALL_LIMIT )
244
248
@patch ("sagemaker.remote_function.client._Job.start" )
245
249
def test_executor_submit_and_cancel (mock_start ):
246
250
mock_job = Mock ()
@@ -265,6 +269,53 @@ def test_executor_submit_and_cancel(mock_start):
265
269
mock_start .assert_called_once_with (ANY , job_function , (1 , 2 ), {"c" : 3 , "d" : 4 })
266
270
267
271
272
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._Job.start")
def test_executor_describe_job_throttled_temporarily(mock_start):
    """Both futures report done when ``describe`` is throttled twice first.

    ``_Job.start`` is patched to return one shared ``Mock`` job whose
    ``describe`` raises a ``ClientError`` with code
    ``LimitExceededException`` on its first two calls and returns
    ``COMPLETED_TRAINING_JOB`` on the next four.
    """
    throttling_error = ClientError(
        error_response={"Error": {"Code": "LimitExceededException"}},
        operation_name="SomeOperation",
    )
    mock_job = Mock()
    # A list side_effect is consumed one item per call: two failures, then
    # four successful responses shared by both submitted jobs.
    mock_job.describe.side_effect = [
        throttling_error,
        throttling_error,
        COMPLETED_TRAINING_JOB,
        COMPLETED_TRAINING_JOB,
        COMPLETED_TRAINING_JOB,
        COMPLETED_TRAINING_JOB,
    ]
    mock_start.return_value = mock_job

    with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as e:
        # submit first job
        future_1 = e.submit(job_function, 1, 2, c=3, d=4)
        # submit second job
        future_2 = e.submit(job_function, 5, 6, c=7, d=8)

    # NOTE(review): indentation was lost in the source view; these asserts
    # are assumed to run after the executor context exits -- confirm.
    assert future_1.done()
    assert future_2.done()
298
+
299
+
300
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
@patch("sagemaker.remote_function.client._Job.start")
def test_executor_describe_job_failed_permanently(mock_start):
    """``done()`` raises ``RuntimeError`` when ``describe`` always fails.

    ``_Job.start`` is patched to return one shared ``Mock`` job whose
    ``describe`` raises ``RuntimeError`` on every call; both submitted
    futures are then expected to raise it from ``done()``.
    """
    mock_job = Mock()
    # An exception instance as side_effect is raised on every call.
    mock_job.describe.side_effect = RuntimeError()
    mock_start.return_value = mock_job

    with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as e:
        # submit first job
        future_1 = e.submit(job_function, 1, 2, c=3, d=4)
        # submit second job
        future_2 = e.submit(job_function, 5, 6, c=7, d=8)

    # NOTE(review): indentation was lost in the source view; these checks
    # are assumed to run after the executor context exits -- confirm.
    with pytest.raises(RuntimeError):
        future_1.done()
    with pytest.raises(RuntimeError):
        future_2.done()
317
+
318
+
268
319
@pytest .mark .parametrize (
269
320
"args, kwargs, error_message" ,
270
321
[
0 commit comments