Skip to content

Commit ef5ee5d

Browse files
committed
bpo-32314: Fix asyncio.run() to cancel runinng tasks on shutdown
1 parent fc2f407 commit ef5ee5d

File tree

4 files changed

+122
-15
lines changed

4 files changed

+122
-15
lines changed

Lib/asyncio/base_events.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -228,14 +228,9 @@ def __init__(self):
228228
self._coroutine_origin_tracking_enabled = False
229229
self._coroutine_origin_tracking_saved_depth = None
230230

231-
if hasattr(sys, 'get_asyncgen_hooks'):
232-
# Python >= 3.6
233-
# A weak set of all asynchronous generators that are
234-
# being iterated by the loop.
235-
self._asyncgens = weakref.WeakSet()
236-
else:
237-
self._asyncgens = None
238-
231+
# A weak set of all asynchronous generators that are
232+
# being iterated by the loop.
233+
self._asyncgens = weakref.WeakSet()
239234
# Set to True when `loop.shutdown_asyncgens` is called.
240235
self._asyncgens_shutdown_called = False
241236

@@ -354,7 +349,7 @@ async def shutdown_asyncgens(self):
354349
"""Shutdown all active asynchronous generators."""
355350
self._asyncgens_shutdown_called = True
356351

357-
if self._asyncgens is None or not len(self._asyncgens):
352+
if not len(self._asyncgens):
358353
# If Python version is <3.6 or we don't have any asynchronous
359354
# generators alive.
360355
return
@@ -386,10 +381,10 @@ def run_forever(self):
386381
'Cannot run the event loop while another loop is running')
387382
self._set_coroutine_origin_tracking(self._debug)
388383
self._thread_id = threading.get_ident()
389-
if self._asyncgens is not None:
390-
old_agen_hooks = sys.get_asyncgen_hooks()
391-
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
392-
finalizer=self._asyncgen_finalizer_hook)
384+
385+
old_agen_hooks = sys.get_asyncgen_hooks()
386+
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
387+
finalizer=self._asyncgen_finalizer_hook)
393388
try:
394389
events._set_running_loop(self)
395390
while True:
@@ -401,8 +396,7 @@ def run_forever(self):
401396
self._thread_id = None
402397
events._set_running_loop(None)
403398
self._set_coroutine_origin_tracking(False)
404-
if self._asyncgens is not None:
405-
sys.set_asyncgen_hooks(*old_agen_hooks)
399+
sys.set_asyncgen_hooks(*old_agen_hooks)
406400

407401
def run_until_complete(self, future):
408402
"""Run until the Future is done.
@@ -1374,6 +1368,7 @@ def call_exception_handler(self, context):
13741368
- 'message': Error message;
13751369
- 'exception' (optional): Exception object;
13761370
- 'future' (optional): Future instance;
1371+
- 'task' (optional): Task instance;
13771372
- 'handle' (optional): Handle instance;
13781373
- 'protocol' (optional): Protocol instance;
13791374
- 'transport' (optional): Transport instance;

Lib/asyncio/runners.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from . import coroutines
44
from . import events
5+
from . import tasks
56

67

78
def run(main, *, debug=False):
@@ -42,7 +43,31 @@ async def main():
4243
return loop.run_until_complete(main)
4344
finally:
4445
try:
46+
_cancel_all_tasks(loop)
4547
loop.run_until_complete(loop.shutdown_asyncgens())
4648
finally:
4749
events.set_event_loop(None)
4850
loop.close()
51+
52+
53+
def _cancel_all_tasks(loop):
54+
to_cancel = [task for task in tasks.all_tasks(loop)
55+
if not task.done()]
56+
if not to_cancel:
57+
return
58+
59+
for task in to_cancel:
60+
task.cancel()
61+
62+
loop.run_until_complete(
63+
tasks.gather(*to_cancel, loop=loop, return_exceptions=True))
64+
65+
for task in to_cancel:
66+
if task.cancelled():
67+
continue
68+
if task.exception() is not None:
69+
loop.call_exception_handler({
70+
'message': 'unhandled exception during asyncio.run() shutdown',
71+
'exception': task.exception(),
72+
'task': task,
73+
})

Lib/test/test_asyncio/test_runners.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import unittest
33

44
from unittest import mock
5+
from . import utils as test_utils
56

67

78
class TestPolicy(asyncio.AbstractEventLoopPolicy):
@@ -98,3 +99,81 @@ async def main():
9899
with self.assertRaisesRegex(RuntimeError,
99100
'cannot be called from a running'):
100101
asyncio.run(main())
102+
103+
def test_asyncio_run_cancels_hanging_tasks(self):
104+
lo_task = None
105+
106+
async def leftover():
107+
await asyncio.sleep(0.1)
108+
109+
async def main():
110+
nonlocal lo_task
111+
lo_task = asyncio.create_task(leftover())
112+
return 123
113+
114+
self.assertEqual(asyncio.run(main()), 123)
115+
self.assertTrue(lo_task.done())
116+
117+
def test_asyncio_run_reports_hanging_tasks_errors(self):
118+
lo_task = None
119+
call_exc_handler_mock = mock.Mock()
120+
121+
async def leftover():
122+
try:
123+
await asyncio.sleep(0.1)
124+
except asyncio.CancelledError:
125+
1 / 0
126+
127+
async def main():
128+
loop = asyncio.get_running_loop()
129+
loop.call_exception_handler = call_exc_handler_mock
130+
131+
nonlocal lo_task
132+
lo_task = asyncio.create_task(leftover())
133+
return 123
134+
135+
self.assertEqual(asyncio.run(main()), 123)
136+
self.assertTrue(lo_task.done())
137+
138+
call_exc_handler_mock.assert_called_with({
139+
'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
140+
'task': lo_task,
141+
'exception': test_utils.MockInstanceOf(ZeroDivisionError)
142+
})
143+
144+
def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
145+
spinner = None
146+
lazyboy = None
147+
148+
class FancyExit(Exception):
149+
pass
150+
151+
async def fidget():
152+
while True:
153+
yield 1
154+
await asyncio.sleep(1)
155+
156+
async def spin():
157+
nonlocal spinner
158+
spinner = fidget()
159+
try:
160+
async for the_meaning_of_life in spinner: # NoQA
161+
pass
162+
except asyncio.CancelledError:
163+
1 / 0
164+
165+
async def main():
166+
loop = asyncio.get_running_loop()
167+
loop.call_exception_handler = mock.Mock()
168+
169+
nonlocal lazyboy
170+
lazyboy = asyncio.create_task(spin())
171+
raise FancyExit
172+
173+
with self.assertRaises(FancyExit):
174+
asyncio.run(main())
175+
176+
self.assertTrue(lazyboy.done())
177+
178+
self.assertIsNone(spinner.ag_frame)
179+
self.assertFalse(spinner.ag_running)

Lib/test/test_asyncio/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,14 @@ def __eq__(self, other):
485485
return bool(re.search(str(self), other, re.S))
486486

487487

488+
class MockInstanceOf:
489+
def __init__(self, type):
490+
self._type = type
491+
492+
def __eq__(self, other):
493+
return isinstance(other, self._type)
494+
495+
488496
def get_function_source(func):
489497
source = format_helpers._get_function_source(func)
490498
if source is None:

0 commit comments

Comments
 (0)