Skip to content

Commit de929f3

Browse files
ZeroIntensitygraingertJelleZijlstrawillingckumaraditya303
authored
gh-124309: Modernize the staggered_race implementation to support eager task factories (#124390)
Co-authored-by: Thomas Grainger <[email protected]> Co-authored-by: Jelle Zijlstra <[email protected]> Co-authored-by: Carol Willing <[email protected]> Co-authored-by: Kumar Aditya <[email protected]>
1 parent d929652 commit de929f3

File tree

5 files changed

+100
-66
lines changed

5 files changed

+100
-66
lines changed

Lib/asyncio/base_events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1144,7 +1144,7 @@ async def create_connection(
11441144
(functools.partial(self._connect_sock,
11451145
exceptions, addrinfo, laddr_infos)
11461146
for addrinfo in infos),
1147-
happy_eyeballs_delay, loop=self)
1147+
happy_eyeballs_delay)
11481148

11491149
if sock is None:
11501150
exceptions = [exc for sub in exceptions for exc in sub]

Lib/asyncio/staggered.py

Lines changed: 18 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55
import contextlib
66

7-
from . import events
8-
from . import exceptions as exceptions_mod
97
from . import locks
108
from . import tasks
9+
from . import taskgroups
1110

11+
class _Done(Exception):
12+
pass
1213

13-
async def staggered_race(coro_fns, delay, *, loop=None):
14+
async def staggered_race(coro_fns, delay):
1415
"""Run coroutines with staggered start times and take the first to finish.
1516
1617
This method takes an iterable of coroutine functions. The first one is
@@ -42,8 +43,6 @@ async def staggered_race(coro_fns, delay, *, loop=None):
4243
delay: amount of time, in seconds, between starting coroutines. If
4344
``None``, the coroutines will run sequentially.
4445
45-
loop: the event loop to use.
46-
4746
Returns:
4847
tuple *(winner_result, winner_index, exceptions)* where
4948
@@ -62,36 +61,11 @@ async def staggered_race(coro_fns, delay, *, loop=None):
6261
6362
"""
6463
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
65-
loop = loop or events.get_running_loop()
66-
enum_coro_fns = enumerate(coro_fns)
6764
winner_result = None
6865
winner_index = None
6966
exceptions = []
70-
running_tasks = []
71-
72-
async def run_one_coro(previous_failed) -> None:
73-
# Wait for the previous task to finish, or for delay seconds
74-
if previous_failed is not None:
75-
with contextlib.suppress(exceptions_mod.TimeoutError):
76-
# Use asyncio.wait_for() instead of asyncio.wait() here, so
77-
# that if we get cancelled at this point, Event.wait() is also
78-
# cancelled, otherwise there will be a "Task destroyed but it is
79-
# pending" later.
80-
await tasks.wait_for(previous_failed.wait(), delay)
81-
# Get the next coroutine to run
82-
try:
83-
this_index, coro_fn = next(enum_coro_fns)
84-
except StopIteration:
85-
return
86-
# Start task that will run the next coroutine
87-
this_failed = locks.Event()
88-
next_task = loop.create_task(run_one_coro(this_failed))
89-
running_tasks.append(next_task)
90-
assert len(running_tasks) == this_index + 2
91-
# Prepare place to put this coroutine's exceptions if not won
92-
exceptions.append(None)
93-
assert len(exceptions) == this_index + 1
9467

68+
async def run_one_coro(this_index, coro_fn, this_failed):
9569
try:
9670
result = await coro_fn()
9771
except (SystemExit, KeyboardInterrupt):
@@ -105,34 +79,17 @@ async def run_one_coro(previous_failed) -> None:
10579
assert winner_index is None
10680
winner_index = this_index
10781
winner_result = result
108-
# Cancel all other tasks. We take care to not cancel the current
109-
# task as well. If we do so, then since there is no `await` after
110-
# here and CancelledError are usually thrown at one, we will
111-
# encounter a curious corner case where the current task will end
112-
# up as done() == True, cancelled() == False, exception() ==
113-
# asyncio.CancelledError. This behavior is specified in
114-
# https://bugs.python.org/issue30048
115-
for i, t in enumerate(running_tasks):
116-
if i != this_index:
117-
t.cancel()
118-
119-
first_task = loop.create_task(run_one_coro(None))
120-
running_tasks.append(first_task)
82+
raise _Done
83+
12184
try:
122-
# Wait for a growing list of tasks to all finish: poor man's version of
123-
# curio's TaskGroup or trio's nursery
124-
done_count = 0
125-
while done_count != len(running_tasks):
126-
done, _ = await tasks.wait(running_tasks)
127-
done_count = len(done)
128-
# If run_one_coro raises an unhandled exception, it's probably a
129-
# programming error, and I want to see it.
130-
if __debug__:
131-
for d in done:
132-
if d.done() and not d.cancelled() and d.exception():
133-
raise d.exception()
134-
return winner_result, winner_index, exceptions
135-
finally:
136-
# Make sure no tasks are left running if we leave this function
137-
for t in running_tasks:
138-
t.cancel()
85+
async with taskgroups.TaskGroup() as tg:
86+
for this_index, coro_fn in enumerate(coro_fns):
87+
this_failed = locks.Event()
88+
exceptions.append(None)
89+
tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
90+
with contextlib.suppress(TimeoutError):
91+
await tasks.wait_for(this_failed.wait(), delay)
92+
except* _Done:
93+
pass
94+
95+
return winner_result, winner_index, exceptions

Lib/test/test_asyncio/test_eager_task_factory.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,53 @@ async def run():
213213

214214
self.run_coro(run())
215215

216+
def test_staggered_race_with_eager_tasks(self):
217+
# See https://github.com/python/cpython/issues/124309
218+
219+
async def fail():
220+
await asyncio.sleep(0)
221+
raise ValueError("no good")
222+
223+
async def run():
224+
winner, index, excs = await asyncio.staggered.staggered_race(
225+
[
226+
lambda: asyncio.sleep(2, result="sleep2"),
227+
lambda: asyncio.sleep(1, result="sleep1"),
228+
lambda: fail()
229+
],
230+
delay=0.25
231+
)
232+
self.assertEqual(winner, 'sleep1')
233+
self.assertEqual(index, 1)
234+
self.assertIsNone(excs[index])
235+
self.assertIsInstance(excs[0], asyncio.CancelledError)
236+
self.assertIsInstance(excs[2], ValueError)
237+
238+
self.run_coro(run())
239+
240+
def test_staggered_race_with_eager_tasks_no_delay(self):
241+
# See https://github.com/python/cpython/issues/124309
242+
async def fail():
243+
raise ValueError("no good")
244+
245+
async def run():
246+
winner, index, excs = await asyncio.staggered.staggered_race(
247+
[
248+
lambda: fail(),
249+
lambda: asyncio.sleep(1, result="sleep1"),
250+
lambda: asyncio.sleep(0, result="sleep0"),
251+
],
252+
delay=None
253+
)
254+
self.assertEqual(winner, 'sleep1')
255+
self.assertEqual(index, 1)
256+
self.assertIsNone(excs[index])
257+
self.assertIsInstance(excs[0], ValueError)
258+
self.assertEqual(len(excs), 2)
259+
260+
self.run_coro(run())
261+
262+
216263

217264
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
218265
Task = tasks._PyTask

Lib/test/test_asyncio/test_staggered.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,45 @@ async def test_none_successful(self):
8282
async def coro(index):
8383
raise ValueError(index)
8484

85+
for delay in [None, 0, 0.1, 1]:
86+
with self.subTest(delay=delay):
87+
winner, index, excs = await staggered_race(
88+
[
89+
lambda: coro(0),
90+
lambda: coro(1),
91+
],
92+
delay=delay,
93+
)
94+
95+
self.assertIs(winner, None)
96+
self.assertIs(index, None)
97+
self.assertEqual(len(excs), 2)
98+
self.assertIsInstance(excs[0], ValueError)
99+
self.assertIsInstance(excs[1], ValueError)
100+
101+
async def test_long_delay_early_failure(self):
102+
async def coro(index):
103+
await asyncio.sleep(0) # Dummy coroutine for the 1 case
104+
if index == 0:
105+
await asyncio.sleep(0.1) # Dummy coroutine
106+
raise ValueError(index)
107+
108+
return f'Res: {index}'
109+
85110
winner, index, excs = await staggered_race(
86111
[
87112
lambda: coro(0),
88113
lambda: coro(1),
89114
],
90-
delay=None,
115+
delay=10,
91116
)
92117

93-
self.assertIs(winner, None)
94-
self.assertIs(index, None)
118+
self.assertEqual(winner, 'Res: 1')
119+
self.assertEqual(index, 1)
95120
self.assertEqual(len(excs), 2)
96121
self.assertIsInstance(excs[0], ValueError)
97-
self.assertIsInstance(excs[1], ValueError)
122+
self.assertIsNone(excs[1])
123+
124+
125+
if __name__ == "__main__":
126+
unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.

0 commit comments

Comments
 (0)