Skip to content

Commit 5f40087

Browse files
nmadanNamrata Madan
andauthored
fix: make RemoteExecutor context manager non-blocking on pending futures (#3822)
Co-authored-by: Namrata Madan <[email protected]>
1 parent 2ce5d91 commit 5f40087

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ def map(self, func, *iterables):
731731
futures = map(self.submit, itertools.repeat(func), *iterables)
732732
return [future.result() for future in futures]
733733

734-
def shutdown(self):
734+
def shutdown(self, wait=True):
735735
"""Prevent more function executions to be submitted to this executor."""
736736
with self._state_condition:
737737
self._shutdown = True
@@ -742,15 +742,15 @@ def shutdown(self):
742742
self._state_condition.notify_all()
743743

744744
if self._workers is not None:
745-
self._workers.shutdown(wait=True)
745+
self._workers.shutdown(wait)
746746

747747
def __enter__(self):
748748
"""Create an executor instance and return it"""
749749
return self
750750

751751
def __exit__(self, exc_type, exc_val, exc_tb):
752752
"""Make sure the executor instance is shutdown."""
753-
self.shutdown()
753+
self.shutdown(wait=False)
754754
return False
755755

756756
@staticmethod

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,11 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
509509
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
510510
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
511511

512+
future_1.wait()
513+
future_2.wait()
514+
future_3.wait()
515+
future_4.wait()
516+
512517
mock_start.assert_has_calls(
513518
[
514519
call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None),
@@ -517,10 +522,6 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
517522
call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, None),
518523
]
519524
)
520-
mock_job_1.describe.assert_called()
521-
mock_job_2.describe.assert_called()
522-
mock_job_3.describe.assert_called()
523-
mock_job_4.describe.assert_called()
524525

525526
assert future_1.done()
526527
assert future_2.done()
@@ -545,14 +546,15 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
545546
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
546547
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
547548

549+
future_1.wait()
550+
future_2.wait()
551+
548552
mock_start.assert_has_calls(
549553
[
550554
call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, run_info),
551555
call(ANY, job_function, (5, 6), {"c": 7, "d": 8}, run_info),
552556
]
553557
)
554-
mock_job_1.describe.assert_called()
555-
mock_job_2.describe.assert_called()
556558

557559
assert future_1.done()
558560
assert future_2.done()
@@ -562,14 +564,15 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
562564
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
563565
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
564566

567+
future_3.wait()
568+
future_4.wait()
569+
565570
mock_start.assert_has_calls(
566571
[
567572
call(ANY, job_function, (9, 10), {"c": 11, "d": 12}, run_info),
568573
call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, run_info),
569574
]
570575
)
571-
mock_job_3.describe.assert_called()
572-
mock_job_4.describe.assert_called()
573576

574577
assert future_3.done()
575578
assert future_4.done()
@@ -621,7 +624,7 @@ def test_executor_fails_to_start_job(mock_start, *args):
621624

622625
with pytest.raises(TypeError):
623626
future_1.result()
624-
print(future_2._state)
627+
future_2.wait()
625628
assert future_2.done()
626629

627630

@@ -678,6 +681,8 @@ def test_executor_describe_job_throttled_temporarily(mock_start, *args):
678681
# submit second job
679682
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
680683

684+
future_1.wait()
685+
future_2.wait()
681686
assert future_1.done()
682687
assert future_2.done()
683688

@@ -697,9 +702,9 @@ def test_executor_describe_job_failed_permanently(mock_start, *args):
697702
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
698703

699704
with pytest.raises(RuntimeError):
700-
future_1.done()
705+
future_1.result()
701706
with pytest.raises(RuntimeError):
702-
future_2.done()
707+
future_2.result()
703708

704709

705710
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)