Skip to content

Commit d1dcfaa

Browse files
Zhankuil authored and Namrata Madan committed
pathway: fix executor's incompatibility with Python 3.7
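The Semaphore.release API changed in Python 3.9 (it gained the optional `n` argument), so the executor's `release(1)` calls fail on Python 3.7; they are replaced with `release()`.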
1 parent d710dd7 commit d1dcfaa

File tree

2 files changed

+125
-38
lines changed

2 files changed

+125
-38
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 67 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -20,14 +20,22 @@
2020
from typing import Dict, List, Tuple, Any
2121
import functools
2222
import inspect
23+
import logging
24+
25+
from botocore.exceptions import ClientError
2326

2427
import sagemaker.remote_function.core.serialization as serialization
2528

2629
from sagemaker.session import Session
2730
from sagemaker.s3 import s3_path_join
2831
from sagemaker.remote_function.job import _JobSettings, _Job
2932

30-
_POLLING_INTERVAL_IN_SECS = 10
33+
34+
_API_CALL_LIMIT = {
35+
"SubmittingIntervalInSecs": 1,
36+
"MinBatchPollingIntervalInSecs": 10,
37+
"PollingIntervalInSecs": 0.5,
38+
}
3139

3240
# Possible future states.
3341
_PENDING = "PENDING"
@@ -36,6 +44,9 @@
3644
_CANCELLED = "CANCELLED"
3745
_FINISHED = "FINISHED"
3846

47+
LOGGER = logging.getLogger(__name__)
48+
LOGGER.setLevel(logging.INFO)
49+
3950

4051
def remote(
4152
_func=None,
@@ -154,48 +165,73 @@ def __init__(self, future, job_settings: _JobSettings, func, func_args, func_kwa
154165

155166
def _submit_worker(executor):
156167
"""Background worker that submits job requests."""
157-
while True:
158-
request = executor._pending_request_queue.get(block=True)
168+
try:
169+
while True:
170+
request = executor._pending_request_queue.get(block=True)
159171

160-
if request is None:
161-
return
172+
if request is None:
173+
return
162174

163-
executor._semaphore.acquire(blocking=True)
175+
executor._semaphore.acquire(blocking=True)
164176

165-
# submit a new job
166-
job = request.future._start_and_notify(
167-
request.job_settings, request.func, request.args, request.kwargs
168-
)
177+
time.sleep(_API_CALL_LIMIT["SubmittingIntervalInSecs"])
178+
# submit a new job
179+
job = request.future._start_and_notify(
180+
request.job_settings, request.func, request.args, request.kwargs
181+
)
169182

170-
if job is None:
171-
# job fails to submit
172-
executor._semaphore.release(1)
173-
else:
174-
executor._running_jobs[job.job_name] = job
183+
if job is None:
184+
# job fails to submit
185+
executor._semaphore.release()
186+
else:
187+
executor._running_jobs[job.job_name] = job
188+
except Exception: # pylint: disable=broad-except
189+
LOGGER.exception("Error occurred while submitting CreateTrainingJob requests.")
175190

176191

177192
def _polling_worker(executor):
178193
"""Background worker that polls the status of the running jobs."""
179-
while True:
180-
if executor._shutdown and len(executor._running_jobs) == 0:
181-
return
182-
183-
time.sleep(_POLLING_INTERVAL_IN_SECS)
194+
try:
195+
while True:
196+
if executor._shutdown and len(executor._running_jobs) == 0:
197+
return
198+
199+
time.sleep(
200+
max(
201+
_API_CALL_LIMIT["MinBatchPollingIntervalInSecs"]
202+
- len(executor._running_jobs) * _API_CALL_LIMIT["PollingIntervalInSecs"],
203+
0,
204+
)
205+
)
184206

185-
# check if running jobs are terminated
186-
for job_name in executor._running_jobs.keys():
187-
if executor._running_jobs[job_name].describe()["TrainingJobStatus"] in [
188-
"Completed",
189-
"Failed",
190-
"Stopped",
191-
]:
192-
del executor._running_jobs[job_name]
193-
executor._semaphore.release(1)
207+
# check if running jobs are terminated
208+
for job_name in executor._running_jobs.keys():
209+
try:
210+
time.sleep(_API_CALL_LIMIT["PollingIntervalInSecs"])
211+
if executor._running_jobs[job_name].describe()["TrainingJobStatus"] in [
212+
"Completed",
213+
"Failed",
214+
"Stopped",
215+
]:
216+
del executor._running_jobs[job_name]
217+
executor._semaphore.release()
218+
except Exception as e: # pylint: disable=broad-except
219+
if (
220+
not isinstance(e, ClientError)
221+
or e.response["Error"]["Code"] # pylint: disable=no-member
222+
!= "LimitExceededException"
223+
):
224+
# Couldn't check the job status, move on
225+
LOGGER.exception(
226+
"Error occurred while checking the status of job %s", job_name
227+
)
228+
del executor._running_jobs[job_name]
229+
executor._semaphore.release()
230+
except Exception: # pylint: disable=broad-except
231+
LOGGER.exception("Error occurred while monitoring the job statuses.")
194232

195233

196234
# TODO: 1) add map method.
197-
# 2) in the background workers, limit rate of calls to CreateTrainingJob
198-
# and DescribeTrainingJob APIs
199235
class RemoteExecutor(object):
200236
"""Run Python functions asynchronously as SageMaker jobs"""
201237

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 58 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919
from mock import patch, Mock, ANY, call
2020

21+
from botocore.exceptions import ClientError
2122
from sagemaker.remote_function.client import remote, RemoteExecutor, Future
2223

2324
TRAINING_JOB_ARN = "training-job-arn"
@@ -47,6 +48,12 @@ def describe_training_job_response(job_status):
4748
CANCELLED_TRAINING_JOB = describe_training_job_response("Stopped")
4849
FAILED_TRAINING_JOB = describe_training_job_response("Failed")
4950

51+
API_CALL_LIMIT = {
52+
"SubmittingIntervalInSecs": 0.005,
53+
"MinBatchPollingIntervalInSecs": 0.01,
54+
"PollingIntervalInSecs": 0.01,
55+
}
56+
5057

5158
def job_function(a, b=1, *, c, d=3):
5259
return a * b * c * d
@@ -171,7 +178,7 @@ def test_executor_submit_after_shutdown():
171178
e.submit(job_function, 1, 2, c=3, d=4)
172179

173180

174-
@patch("sagemaker.remote_function.client._POLLING_INTERVAL_IN_SECS", new=0.01)
181+
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
175182
@patch("sagemaker.remote_function.client._Job.start")
176183
def test_executor_submit_happy_case(mock_start):
177184
mock_job = Mock()
@@ -194,8 +201,7 @@ def test_executor_submit_happy_case(mock_start):
194201
mock_job.describe.assert_called()
195202

196203

197-
@pytest.mark.skip("This test hangs forever in py37")
198-
@patch("sagemaker.remote_function.client._POLLING_INTERVAL_IN_SECS", new=0.01)
204+
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
199205
@patch("sagemaker.remote_function.client._Job.start")
200206
def test_executor_submit_enforcing_max_parallel_jobs(mock_start):
201207
mock_job = Mock()
@@ -220,8 +226,7 @@ def test_executor_submit_enforcing_max_parallel_jobs(mock_start):
220226
assert future_2.done()
221227

222228

223-
@pytest.mark.skip("This test fails in py37")
224-
@patch("sagemaker.remote_function.client._POLLING_INTERVAL_IN_SECS", new=0.01)
229+
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
225230
@patch("sagemaker.remote_function.client._Job.start")
226231
def test_executor_fails_to_start_job(mock_start):
227232
mock_job = Mock()
@@ -239,8 +244,7 @@ def test_executor_fails_to_start_job(mock_start):
239244
assert future_2.done()
240245

241246

242-
@pytest.mark.skip("This test hangs forever in py37")
243-
@patch("sagemaker.remote_function.client._POLLING_INTERVAL_IN_SECS", new=0.01)
247+
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
244248
@patch("sagemaker.remote_function.client._Job.start")
245249
def test_executor_submit_and_cancel(mock_start):
246250
mock_job = Mock()
@@ -265,6 +269,53 @@ def test_executor_submit_and_cancel(mock_start):
265269
mock_start.assert_called_once_with(ANY, job_function, (1, 2), {"c": 3, "d": 4})
266270

267271

272+
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
273+
@patch("sagemaker.remote_function.client._Job.start")
274+
def test_executor_describe_job_throttled_temporarily(mock_start):
275+
throttling_error = ClientError(
276+
error_response={"Error": {"Code": "LimitExceededException"}},
277+
operation_name="SomeOperation",
278+
)
279+
mock_job = Mock()
280+
mock_job.describe.side_effect = [
281+
throttling_error,
282+
throttling_error,
283+
COMPLETED_TRAINING_JOB,
284+
COMPLETED_TRAINING_JOB,
285+
COMPLETED_TRAINING_JOB,
286+
COMPLETED_TRAINING_JOB,
287+
]
288+
mock_start.return_value = mock_job
289+
290+
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as e:
291+
# submit first job
292+
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
293+
# submit second job
294+
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
295+
296+
assert future_1.done()
297+
assert future_2.done()
298+
299+
300+
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
301+
@patch("sagemaker.remote_function.client._Job.start")
302+
def test_executor_describe_job_failed_permanently(mock_start):
303+
mock_job = Mock()
304+
mock_job.describe.side_effect = RuntimeError()
305+
mock_start.return_value = mock_job
306+
307+
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as e:
308+
# submit first job
309+
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
310+
# submit second job
311+
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
312+
313+
with pytest.raises(RuntimeError):
314+
future_1.done()
315+
with pytest.raises(RuntimeError):
316+
future_2.done()
317+
318+
268319
@pytest.mark.parametrize(
269320
"args, kwargs, error_message",
270321
[

0 commit comments

Comments
 (0)