Skip to content

Commit 5cd7537

Browse files
Namrata Madanknikure
authored andcommitted
Revert "fix: make RemoteExecutor context manager non-blocking on pending futures (aws#3822)"
This reverts commit 5f40087.
1 parent 54db09b commit 5cd7537

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def map(self, func, *iterables):
745745
futures = map(self.submit, itertools.repeat(func), *iterables)
746746
return [future.result() for future in futures]
747747

748-
def shutdown(self, wait=True):
748+
def shutdown(self):
749749
"""Prevent more function executions to be submitted to this executor."""
750750
with self._state_condition:
751751
self._shutdown = True
@@ -756,15 +756,15 @@ def shutdown(self, wait=True):
756756
self._state_condition.notify_all()
757757

758758
if self._workers is not None:
759-
self._workers.shutdown(wait)
759+
self._workers.shutdown(wait=True)
760760

761761
def __enter__(self):
762762
"""Create an executor instance and return it"""
763763
return self
764764

765765
def __exit__(self, exc_type, exc_val, exc_tb):
766766
"""Make sure the executor instance is shutdown."""
767-
self.shutdown(wait=False)
767+
self.shutdown()
768768
return False
769769

770770
@staticmethod

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -518,11 +518,6 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
518518
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
519519
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
520520

521-
future_1.wait()
522-
future_2.wait()
523-
future_3.wait()
524-
future_4.wait()
525-
526521
mock_start.assert_has_calls(
527522
[
528523
call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, None),
@@ -531,6 +526,10 @@ def test_executor_submit_happy_case(mock_start, mock_job_settings, parallelism):
531526
call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, None),
532527
]
533528
)
529+
mock_job_1.describe.assert_called()
530+
mock_job_2.describe.assert_called()
531+
mock_job_3.describe.assert_called()
532+
mock_job_4.describe.assert_called()
534533

535534
assert future_1.done()
536535
assert future_2.done()
@@ -555,15 +554,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
555554
future_1 = e.submit(job_function, 1, 2, c=3, d=4)
556555
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
557556

558-
future_1.wait()
559-
future_2.wait()
560-
561557
mock_start.assert_has_calls(
562558
[
563559
call(ANY, job_function, (1, 2), {"c": 3, "d": 4}, run_info),
564560
call(ANY, job_function, (5, 6), {"c": 7, "d": 8}, run_info),
565561
]
566562
)
563+
mock_job_1.describe.assert_called()
564+
mock_job_2.describe.assert_called()
567565

568566
assert future_1.done()
569567
assert future_2.done()
@@ -573,15 +571,14 @@ def test_executor_submit_with_run(mock_start, mock_job_settings, run_obj):
573571
future_3 = e.submit(job_function, 9, 10, c=11, d=12)
574572
future_4 = e.submit(job_function, 13, 14, c=15, d=16)
575573

576-
future_3.wait()
577-
future_4.wait()
578-
579574
mock_start.assert_has_calls(
580575
[
581576
call(ANY, job_function, (9, 10), {"c": 11, "d": 12}, run_info),
582577
call(ANY, job_function, (13, 14), {"c": 15, "d": 16}, run_info),
583578
]
584579
)
580+
mock_job_3.describe.assert_called()
581+
mock_job_4.describe.assert_called()
585582

586583
assert future_3.done()
587584
assert future_4.done()
@@ -633,7 +630,7 @@ def test_executor_fails_to_start_job(mock_start, *args):
633630

634631
with pytest.raises(TypeError):
635632
future_1.result()
636-
future_2.wait()
633+
print(future_2._state)
637634
assert future_2.done()
638635

639636

@@ -698,8 +695,6 @@ def test_executor_describe_job_throttled_temporarily(mock_start, *args):
698695
# submit second job
699696
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
700697

701-
future_1.wait()
702-
future_2.wait()
703698
assert future_1.done()
704699
assert future_2.done()
705700

@@ -719,9 +714,9 @@ def test_executor_describe_job_failed_permanently(mock_start, *args):
719714
future_2 = e.submit(job_function, 5, 6, c=7, d=8)
720715

721716
with pytest.raises(RuntimeError):
722-
future_1.result()
717+
future_1.done()
723718
with pytest.raises(RuntimeError):
724-
future_2.result()
719+
future_2.done()
725720

726721

727722
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)