Skip to content

Commit cb91c9a

Browse files
Zhankui and Namrata Madan
authored and committed
pathway: Fix race conditions in the RemoteExecutor
There is a race condition which may cause the polling worker to abort early: it checks the number of running jobs as the exit criterion; however, it is possible that the running job list is empty because the submit worker is slow to move the pending job to the running job list.
1 parent df0c0f1 commit cb91c9a

File tree

4 files changed

+108
-47
lines changed

4 files changed

+108
-47
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from concurrent.futures import ThreadPoolExecutor
17-
import queue
17+
from collections import deque
1818
import time
1919
import threading
2020
from typing import Dict, List, Tuple, Any
@@ -171,26 +171,36 @@ def __init__(self, future, job_settings: _JobSettings, func, func_args, func_kwa
171171

172172
def _submit_worker(executor):
173173
"""Background worker that submits job requests."""
174+
175+
def has_work_to_do():
176+
return (
177+
len(executor._pending_request_queue) > 0
178+
and len(executor._running_jobs) < executor.max_parallel_job
179+
)
180+
174181
try:
175182
while True:
176-
request = executor._pending_request_queue.get(block=True)
183+
with executor._state_condition:
184+
executor._state_condition.wait_for(has_work_to_do)
185+
request = executor._pending_request_queue[0]
177186

178187
if request is None:
188+
with executor._state_condition:
189+
# remove the anchor from the pending queue
190+
executor._pending_request_queue.popleft()
179191
return
180192

181-
executor._semaphore.acquire(blocking=True)
182-
183193
time.sleep(_API_CALL_LIMIT["SubmittingIntervalInSecs"])
184194
# submit a new job
185195
job = request.future._start_and_notify(
186196
request.job_settings, request.func, request.args, request.kwargs
187197
)
188198

189-
if job is None:
190-
# job fails to submit
191-
executor._semaphore.release()
192-
else:
193-
executor._running_jobs[job.job_name] = job
199+
with executor._state_condition:
200+
if job:
201+
executor._running_jobs[job.job_name] = job
202+
# remove the request from the pending queue
203+
executor._pending_request_queue.popleft()
194204
except Exception: # pylint: disable=broad-except
195205
logger.exception("Error occurred while submitting CreateTrainingJob requests.")
196206

@@ -199,8 +209,12 @@ def _polling_worker(executor):
199209
"""Background worker that polls the status of the running jobs."""
200210
try:
201211
while True:
202-
if executor._shutdown and len(executor._running_jobs) == 0:
203-
return
212+
with executor._state_condition:
213+
if (
214+
executor._shutdown
215+
and len(executor._running_jobs) + len(executor._pending_request_queue) == 0
216+
):
217+
return
204218

205219
time.sleep(
206220
max(
@@ -211,16 +225,17 @@ def _polling_worker(executor):
211225
)
212226

213227
# check if running jobs are terminated
214-
for job_name in executor._running_jobs.keys():
228+
for job_name in list(executor._running_jobs.keys()):
215229
try:
216230
time.sleep(_API_CALL_LIMIT["PollingIntervalInSecs"])
217231
if executor._running_jobs[job_name].describe()["TrainingJobStatus"] in [
218232
"Completed",
219233
"Failed",
220234
"Stopped",
221235
]:
222-
del executor._running_jobs[job_name]
223-
executor._semaphore.release()
236+
with executor._state_condition:
237+
del executor._running_jobs[job_name]
238+
executor._state_condition.notify_all()
224239
except Exception as e: # pylint: disable=broad-except
225240
if (
226241
not isinstance(e, ClientError)
@@ -231,8 +246,9 @@ def _polling_worker(executor):
231246
logger.exception(
232247
"Error occurred while checking the status of job %s", job_name
233248
)
234-
del executor._running_jobs[job_name]
235-
executor._semaphore.release()
249+
with executor._state_condition:
250+
del executor._running_jobs[job_name]
251+
executor._state_condition.notify_all()
236252
except Exception: # pylint: disable=broad-except
237253
logger.exception("Error occurred while monitoring the job statuses.")
238254

@@ -331,14 +347,14 @@ def __init__(
331347
volume_size=volume_size,
332348
)
333349

334-
self._pending_request_queue = queue.SimpleQueue()
335-
self._semaphore = threading.BoundedSemaphore(self.max_parallel_job)
350+
self._state_condition = threading.Condition()
351+
self._pending_request_queue = deque()
336352
# For thread safety, see
337353
# https://web.archive.org/web/20201108091210/http://effbot.org/pyfaq/what-kinds-of-global-value-mutation-are-thread-safe.htm
338354
self._running_jobs = dict()
355+
self._shutdown = False
339356

340357
self._workers: ThreadPoolExecutor = None
341-
self._shutdown = False
342358

343359
def submit(self, func, *args, **kwargs):
344360
"""Execute the input function as a SageMaker job asynchronously.
@@ -353,24 +369,30 @@ def submit(self, func, *args, **kwargs):
353369

354370
self._validate_submit_args(func, *args, **kwargs)
355371

356-
future = Future()
357-
self._pending_request_queue.put(
358-
_SubmitRequest(future, self.job_settings, func, args, kwargs)
359-
)
372+
with self._state_condition:
373+
future = Future()
374+
self._pending_request_queue.append(
375+
_SubmitRequest(future, self.job_settings, func, args, kwargs)
376+
)
377+
378+
if self._workers is None:
379+
self._workers = ThreadPoolExecutor(2)
380+
self._workers.submit(_submit_worker, self)
381+
self._workers.submit(_polling_worker, self)
360382

361-
if self._workers is None:
362-
self._workers = ThreadPoolExecutor(2)
363-
self._workers.submit(_submit_worker, self)
364-
self._workers.submit(_polling_worker, self)
383+
self._state_condition.notify_all()
365384

366385
return future
367386

368387
def shutdown(self):
369388
"""Prevent more function executions to be submitted to this executor."""
370-
self._shutdown = True
389+
with self._state_condition:
390+
self._shutdown = True
391+
392+
# give a signal to the submitting worker so that it doesn't block on empty queue forever
393+
self._pending_request_queue.append(None)
371394

372-
# give a signal to the submitting worker so that it doesn't block on empty queue forever
373-
self._pending_request_queue.put(None)
395+
self._state_condition.notify_all()
374396

375397
if self._workers is not None:
376398
self._workers.shutdown(wait=True)

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,20 @@ def save(self, func, source_dir=None, *args, **kwargs):
5656
"""
5757
if source_dir:
5858
self._zip_and_upload_source_dir(source_dir)
59+
60+
logger.info(
61+
f"Serializing function code to {s3_path_join(self.s3_base_uri, 'function.pkl')}"
62+
)
5963
serialization.serialize_func_to_s3(
6064
func,
6165
self.sagemaker_session,
6266
s3_path_join(self.s3_base_uri, "function.pkl"),
6367
self.s3_kms_key,
6468
)
69+
70+
logger.info(
71+
f"Serializing function arguments to {s3_path_join(self.s3_base_uri, 'arguments.pkl')}"
72+
)
6573
serialization.serialize_obj_to_s3(
6674
(args, kwargs),
6775
self.sagemaker_session,

src/sagemaker/remote_function/job.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@
2222
from sagemaker import vpc_utils
2323
from sagemaker.remote_function.core.stored_function import StoredFunction
2424
from sagemaker.remote_function.core.runtime_environment import RuntimeEnvironmentManager
25+
from sagemaker.remote_function import logging_config
2526

2627

2728
JOBS_CONTAINER_ENTRYPOINT = ["invoke-remote-function"]
2829

2930

31+
logger = logging_config.get_logger()
32+
33+
3034
# TODO: extend this class to load job settings from the configuration files.
3135
class _JobSettings:
3236
"""Helper class that processes the job settings.
@@ -191,6 +195,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs):
191195
if job_settings.environment_variables:
192196
request_dict["Environment"] = job_settings.environment_variables
193197

198+
logger.info("Creating job: %s", job_name)
194199
job_settings.sagemaker_session.sagemaker_client.create_training_job(**request_dict)
195200

196201
return _Job(job_name, job_settings)

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ def describe_training_job_response(job_status):
5555
}
5656

5757

58+
def create_mock_job(job_name, describe_return):
59+
mock_job = Mock()
60+
mock_job.describe.return_value = describe_return
61+
mock_job.job_name = job_name
62+
return mock_job
63+
64+
5865
def job_function(a, b=1, *, c, d=3):
5966
return a * b * c * d
6067

@@ -179,37 +186,49 @@ def test_executor_submit_after_shutdown(*args):
179186
e.submit(job_function, 1, 2, c=3, d=4)
180187

181188

189+
@pytest.mark.parametrize("parallelism", [1, 2, 3, 4])
182190
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
183191
@patch("sagemaker.remote_function.client._JobSettings")
184192
@patch("sagemaker.remote_function.client._Job.start")
185-
def test_executor_submit_happy_case(mock_start, *args):
186-
mock_job = Mock()
187-
mock_job.describe.return_value = COMPLETED_TRAINING_JOB
188-
mock_job.job_name = TRAINING_JOB_NAME
189-
mock_start.return_value = mock_job
190-
191-
with RemoteExecutor(max_parallel_job=2, s3_root_uri="s3://bucket/") as e:
193+
def test_executor_submit_happy_case(mock_start, MockJobSetting, parallelism):
194+
mock_job_1 = create_mock_job("job_1", COMPLETED_TRAINING_JOB)
195+
mock_job_2 = create_mock_job("job_2", COMPLETED_TRAINING_JOB)
196+
mock_job_3 = create_mock_job("job_3", COMPLETED_TRAINING_JOB)
197+
mock_job_4 = create_mock_job("job_4", COMPLETED_TRAINING_JOB)
198+
mock_start.side_effect = [mock_job_1, mock_job_2, mock_job_3, mock_job_4]
199+
200+
with RemoteExecutor(max_parallel_job=parallelism, s3_root_uri="s3://bucket/") as e:
192201
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
193202
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
203+
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
204+
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
194205

195206
mock_start.assert_has_calls(
196207
[
197208
call(ANY, job_function, (1, 2), {"c": 3, "d": 4}),
198209
call(ANY, job_function, (5, 6), {"c": 7, "d": 8}),
210+
call(ANY, job_function, (9, 10), {"c": 11, "d": 12}),
211+
call(ANY, job_function, (13, 14), {"c": 15, "d": 16}),
199212
]
200213
)
214+
mock_job_1.describe.assert_called()
215+
mock_job_2.describe.assert_called()
216+
mock_job_3.describe.assert_called()
217+
mock_job_4.describe.assert_called()
218+
201219
assert future_1.done()
202220
assert future_2.done()
203-
mock_job.describe.assert_called()
221+
assert future_3.done()
222+
assert future_4.done()
204223

205224

206225
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
207226
@patch("sagemaker.remote_function.client._JobSettings")
208227
@patch("sagemaker.remote_function.client._Job.start")
209228
def test_executor_submit_enforcing_max_parallel_jobs(mock_start, *args):
210-
mock_job = Mock()
211-
mock_job.describe.return_value = INPROGRESS_TRAINING_JOB
212-
mock_start.return_value = mock_job
229+
mock_job_1 = create_mock_job("job_1", INPROGRESS_TRAINING_JOB)
230+
mock_job_2 = create_mock_job("job_2", INPROGRESS_TRAINING_JOB)
231+
mock_start.side_effect = [mock_job_1, mock_job_2]
213232

214233
e = RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/")
215234
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
@@ -221,10 +240,15 @@ def test_executor_submit_enforcing_max_parallel_jobs(mock_start, *args):
221240
assert not future_2.running()
222241
mock_start.assert_called_with(ANY, job_function, (1, 2), {"c": 3, "d": 4})
223242

224-
mock_job.describe.return_value = COMPLETED_TRAINING_JOB
243+
mock_job_1.describe.return_value = COMPLETED_TRAINING_JOB
244+
mock_job_2.describe.return_value = COMPLETED_TRAINING_JOB
245+
225246
e.shutdown()
226247

227248
mock_start.assert_called_with(ANY, job_function, (5, 6), {"c": 7, "d": 8})
249+
mock_job_1.describe.assert_called()
250+
mock_job_2.describe.assert_called()
251+
228252
assert future_1.done()
229253
assert future_2.done()
230254

@@ -252,9 +276,9 @@ def test_executor_fails_to_start_job(mock_start, *args):
252276
@patch("sagemaker.remote_function.client._JobSettings")
253277
@patch("sagemaker.remote_function.client._Job.start")
254278
def test_executor_submit_and_cancel(mock_start, *args):
255-
mock_job = Mock()
256-
mock_job.describe.return_value = INPROGRESS_TRAINING_JOB
257-
mock_start.return_value = mock_job
279+
mock_job_1 = create_mock_job("job_1", INPROGRESS_TRAINING_JOB)
280+
mock_job_2 = create_mock_job("job_2", INPROGRESS_TRAINING_JOB)
281+
mock_start.side_effect = [mock_job_1, mock_job_2]
258282

259283
e = RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/")
260284

@@ -266,12 +290,14 @@ def test_executor_submit_and_cancel(mock_start, *args):
266290
future_2.cancel()
267291

268292
# let the first job complete
269-
mock_job.describe.return_value = COMPLETED_TRAINING_JOB
293+
mock_job_1.describe.return_value = COMPLETED_TRAINING_JOB
270294
e.shutdown()
271295

296+
mock_start.assert_called_once_with(ANY, job_function, (1, 2), {"c": 3, "d": 4})
297+
mock_job_1.describe.assert_called()
298+
272299
assert future_1.done()
273300
assert future_2.cancelled()
274-
mock_start.assert_called_once_with(ANY, job_function, (1, 2), {"c": 3, "d": 4})
275301

276302

277303
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)

0 commit comments

Comments
 (0)