Skip to content

bpo-32314: Fix asyncio.run() to cancel runinng tasks on shutdown #5262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,9 @@ def __init__(self):
self._coroutine_origin_tracking_enabled = False
self._coroutine_origin_tracking_saved_depth = None

if hasattr(sys, 'get_asyncgen_hooks'):
# Python >= 3.6
# A weak set of all asynchronous generators that are
# being iterated by the loop.
self._asyncgens = weakref.WeakSet()
else:
self._asyncgens = None

# A weak set of all asynchronous generators that are
# being iterated by the loop.
self._asyncgens = weakref.WeakSet()
# Set to True when `loop.shutdown_asyncgens` is called.
self._asyncgens_shutdown_called = False

Expand Down Expand Up @@ -354,7 +349,7 @@ async def shutdown_asyncgens(self):
"""Shutdown all active asynchronous generators."""
self._asyncgens_shutdown_called = True

if self._asyncgens is None or not len(self._asyncgens):
if not len(self._asyncgens):
# If Python version is <3.6 or we don't have any asynchronous
# generators alive.
return
Expand Down Expand Up @@ -386,10 +381,10 @@ def run_forever(self):
'Cannot run the event loop while another loop is running')
self._set_coroutine_origin_tracking(self._debug)
self._thread_id = threading.get_ident()
if self._asyncgens is not None:
old_agen_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
finalizer=self._asyncgen_finalizer_hook)

old_agen_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
finalizer=self._asyncgen_finalizer_hook)
try:
events._set_running_loop(self)
while True:
Expand All @@ -401,8 +396,7 @@ def run_forever(self):
self._thread_id = None
events._set_running_loop(None)
self._set_coroutine_origin_tracking(False)
if self._asyncgens is not None:
sys.set_asyncgen_hooks(*old_agen_hooks)
sys.set_asyncgen_hooks(*old_agen_hooks)

def run_until_complete(self, future):
"""Run until the Future is done.
Expand Down Expand Up @@ -1374,6 +1368,7 @@ def call_exception_handler(self, context):
- 'message': Error message;
- 'exception' (optional): Exception object;
- 'future' (optional): Future instance;
- 'task' (optional): Task instance;
- 'handle' (optional): Handle instance;
- 'protocol' (optional): Protocol instance;
- 'transport' (optional): Transport instance;
Expand Down
25 changes: 25 additions & 0 deletions Lib/asyncio/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from . import coroutines
from . import events
from . import tasks


def run(main, *, debug=False):
Expand Down Expand Up @@ -42,7 +43,31 @@ async def main():
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
events.set_event_loop(None)
loop.close()


def _cancel_all_tasks(loop):
to_cancel = [task for task in tasks.all_tasks(loop)
if not task.done()]
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(
tasks.gather(*to_cancel, loop=loop, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})
79 changes: 79 additions & 0 deletions Lib/test/test_asyncio/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest

from unittest import mock
from . import utils as test_utils


class TestPolicy(asyncio.AbstractEventLoopPolicy):
Expand Down Expand Up @@ -98,3 +99,81 @@ async def main():
with self.assertRaisesRegex(RuntimeError,
'cannot be called from a running'):
asyncio.run(main())

def test_asyncio_run_cancels_hanging_tasks(self):
lo_task = None

async def leftover():
await asyncio.sleep(0.1)

async def main():
nonlocal lo_task
lo_task = asyncio.create_task(leftover())
return 123

self.assertEqual(asyncio.run(main()), 123)
self.assertTrue(lo_task.done())

def test_asyncio_run_reports_hanging_tasks_errors(self):
lo_task = None
call_exc_handler_mock = mock.Mock()

async def leftover():
try:
await asyncio.sleep(0.1)
except asyncio.CancelledError:
1 / 0

async def main():
loop = asyncio.get_running_loop()
loop.call_exception_handler = call_exc_handler_mock

nonlocal lo_task
lo_task = asyncio.create_task(leftover())
return 123

self.assertEqual(asyncio.run(main()), 123)
self.assertTrue(lo_task.done())

call_exc_handler_mock.assert_called_with({
'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
'task': lo_task,
'exception': test_utils.MockInstanceOf(ZeroDivisionError)
})

def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
spinner = None
lazyboy = None

class FancyExit(Exception):
pass

async def fidget():
while True:
yield 1
await asyncio.sleep(1)

async def spin():
nonlocal spinner
spinner = fidget()
try:
async for the_meaning_of_life in spinner: # NoQA
pass
except asyncio.CancelledError:
1 / 0

async def main():
loop = asyncio.get_running_loop()
loop.call_exception_handler = mock.Mock()

nonlocal lazyboy
lazyboy = asyncio.create_task(spin())
raise FancyExit

with self.assertRaises(FancyExit):
asyncio.run(main())

self.assertTrue(lazyboy.done())

self.assertIsNone(spinner.ag_frame)
self.assertFalse(spinner.ag_running)
8 changes: 8 additions & 0 deletions Lib/test/test_asyncio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,14 @@ def __eq__(self, other):
return bool(re.search(str(self), other, re.S))


class MockInstanceOf:
def __init__(self, type):
self._type = type

def __eq__(self, other):
return isinstance(other, self._type)


def get_function_source(func):
source = format_helpers._get_function_source(func)
if source is None:
Expand Down