Skip to content

Commit 50e6b01

Browse files
committed
bpo-46994: Accept explicit contextvars.Context in asyncio create_task() API
1 parent 882d809 commit 50e6b01

File tree

10 files changed

+99
-61
lines changed

10 files changed

+99
-61
lines changed

Doc/library/asyncio-eventloop.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ Creating Futures and Tasks
330330

331331
.. versionadded:: 3.5.2
332332

333-
.. method:: loop.create_task(coro, *, name=None)
333+
.. method:: loop.create_task(coro, *, name=None, context=None)
334334

335335
Schedule the execution of a :ref:`coroutine`.
336336
Return a :class:`Task` object.
@@ -342,17 +342,24 @@ Creating Futures and Tasks
342342
If the *name* argument is provided and not ``None``, it is set as
343343
the name of the task using :meth:`Task.set_name`.
344344

345+
An optional keyword-only *context* argument allows specifying a
346+
custom :class:`contextvars.Context` for the *coro* to run in.
347+
The current context copy is created when no *context* is provided.
348+
345349
.. versionchanged:: 3.8
346350
Added the *name* parameter.
347351

352+
.. versionchanged:: 3.11
353+
Added the *context* parameter.
354+
348355
.. method:: loop.set_task_factory(factory)
349356

350357
Set a task factory that will be used by
351358
:meth:`loop.create_task`.
352359

353360
If *factory* is ``None`` the default task factory will be set.
354361
Otherwise, *factory* must be a *callable* with the signature matching
355-
``(loop, coro)``, where *loop* is a reference to the active
362+
``(loop, coro, context=None)``, where *loop* is a reference to the active
356363
event loop, and *coro* is a coroutine object. The callable
357364
must return a :class:`asyncio.Future`-compatible object.
358365

Doc/library/asyncio-task.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,14 +244,18 @@ Running an asyncio Program
244244
Creating Tasks
245245
==============
246246

247-
.. function:: create_task(coro, *, name=None)
247+
.. function:: create_task(coro, *, name=None, context=None)
248248

249249
Wrap the *coro* :ref:`coroutine <coroutine>` into a :class:`Task`
250250
and schedule its execution. Return the Task object.
251251

252252
If *name* is not ``None``, it is set as the name of the task using
253253
:meth:`Task.set_name`.
254254

255+
An optional keyword-only *context* argument allows specifying a
256+
custom :class:`contextvars.Context` for the *coro* to run in.
257+
The current context copy is created when no *context* is provided.
258+
255259
The task is executed in the loop returned by :func:`get_running_loop`,
256260
:exc:`RuntimeError` is raised if there is no running loop in
257261
current thread.
@@ -281,6 +285,9 @@ Creating Tasks
281285
.. versionchanged:: 3.8
282286
Added the *name* parameter.
283287

288+
.. versionchanged:: 3.11
289+
Added the *context* parameter.
290+
284291

285292
Sleeping
286293
========

Lib/asyncio/base_events.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,18 +426,23 @@ def create_future(self):
426426
"""Create a Future object attached to the loop."""
427427
return futures.Future(loop=self)
428428

429-
def create_task(self, coro, *, name=None):
429+
def create_task(self, coro, *, name=None, context=None):
430430
"""Schedule a coroutine object.
431431
432432
Return a task object.
433433
"""
434434
self._check_closed()
435435
if self._task_factory is None:
436-
task = tasks.Task(coro, loop=self, name=name)
436+
task = tasks.Task(coro, loop=self, name=name, context=context)
437437
if task._source_traceback:
438438
del task._source_traceback[-1]
439439
else:
440-
task = self._task_factory(self, coro)
440+
if context is None:
441+
# Use legacy API if context is not needed
442+
task = self._task_factory(self, coro)
443+
else:
444+
task = self._task_factory(self, coro, context=context)
445+
441446
tasks._set_task_name(task, name)
442447

443448
return task

Lib/asyncio/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def create_future(self):
274274

275275
# Method scheduling a coroutine object: create a task.
276276

277-
def create_task(self, coro, *, name=None):
277+
def create_task(self, coro, *, name=None, context=None):
278278
raise NotImplementedError
279279

280280
# Methods for interacting with threads.

Lib/asyncio/tasks.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class Task(futures._PyFuture): # Inherit Python Task implementation
9090
# status is still pending
9191
_log_destroy_pending = True
9292

93-
def __init__(self, coro, *, loop=None, name=None):
93+
def __init__(self, coro, *, loop=None, name=None, context=None):
9494
super().__init__(loop=loop)
9595
if self._source_traceback:
9696
del self._source_traceback[-1]
@@ -109,7 +109,10 @@ def __init__(self, coro, *, loop=None, name=None):
109109
self._must_cancel = False
110110
self._fut_waiter = None
111111
self._coro = coro
112-
self._context = contextvars.copy_context()
112+
if context is None:
113+
self._context = contextvars.copy_context()
114+
else:
115+
self._context = context
113116

114117
self._loop.call_soon(self.__step, context=self._context)
115118
_register_task(self)
@@ -357,13 +360,18 @@ def __wakeup(self, future):
357360
Task = _CTask = _asyncio.Task
358361

359362

360-
def create_task(coro, *, name=None):
363+
def create_task(coro, *, name=None, context=None):
361364
"""Schedule the execution of a coroutine object in a spawn task.
362365
363366
Return a Task object.
364367
"""
365368
loop = events.get_running_loop()
366-
task = loop.create_task(coro)
369+
if context is None:
370+
# Use legacy API if context is not needed
371+
task = loop.create_task(coro)
372+
else:
373+
task = loop.create_task(coro, context=context)
374+
367375
_set_task_name(task, name)
368376
return task
369377

Lib/unittest/async_case.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextvars
23
import inspect
34
import warnings
45

@@ -34,7 +35,7 @@ class IsolatedAsyncioTestCase(TestCase):
3435
def __init__(self, methodName='runTest'):
3536
super().__init__(methodName)
3637
self._asyncioTestLoop = None
37-
self._asyncioCallsQueue = None
38+
self._asyncioTestContext = contextvars.copy_context()
3839

3940
async def asyncSetUp(self):
4041
pass
@@ -58,7 +59,7 @@ def addAsyncCleanup(self, func, /, *args, **kwargs):
5859
self.addCleanup(*(func, *args), **kwargs)
5960

6061
def _callSetUp(self):
61-
self.setUp()
62+
self._asyncioTestContext.run(self.setUp)
6263
self._callAsync(self.asyncSetUp)
6364

6465
def _callTestMethod(self, method):
@@ -68,64 +69,42 @@ def _callTestMethod(self, method):
6869

6970
def _callTearDown(self):
7071
self._callAsync(self.asyncTearDown)
71-
self.tearDown()
72+
self._asyncioTestContext.run(self.tearDown)
7273

7374
def _callCleanup(self, function, *args, **kwargs):
7475
self._callMaybeAsync(function, *args, **kwargs)
7576

7677
def _callAsync(self, func, /, *args, **kwargs):
7778
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
78-
ret = func(*args, **kwargs)
79-
assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable'
80-
fut = self._asyncioTestLoop.create_future()
81-
self._asyncioCallsQueue.put_nowait((fut, ret))
82-
return self._asyncioTestLoop.run_until_complete(fut)
79+
assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function'
80+
task = self._asyncioTestLoop.create_task(
81+
func(*args, **kwargs),
82+
context=self._asyncioTestContext,
83+
)
84+
return self._asyncioTestLoop.run_until_complete(task)
8385

8486
def _callMaybeAsync(self, func, /, *args, **kwargs):
8587
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
86-
ret = func(*args, **kwargs)
87-
if inspect.isawaitable(ret):
88-
fut = self._asyncioTestLoop.create_future()
89-
self._asyncioCallsQueue.put_nowait((fut, ret))
90-
return self._asyncioTestLoop.run_until_complete(fut)
88+
if inspect.iscoroutinefunction(func):
89+
task = self._asyncioTestLoop.create_task(
90+
func(*args, **kwargs),
91+
context=self._asyncioTestContext,
92+
)
93+
return self._asyncioTestLoop.run_until_complete(task)
9194
else:
92-
return ret
93-
94-
async def _asyncioLoopRunner(self, fut):
95-
self._asyncioCallsQueue = queue = asyncio.Queue()
96-
fut.set_result(None)
97-
while True:
98-
query = await queue.get()
99-
queue.task_done()
100-
if query is None:
101-
return
102-
fut, awaitable = query
103-
try:
104-
ret = await awaitable
105-
if not fut.cancelled():
106-
fut.set_result(ret)
107-
except (SystemExit, KeyboardInterrupt):
108-
raise
109-
except (BaseException, asyncio.CancelledError) as ex:
110-
if not fut.cancelled():
111-
fut.set_exception(ex)
95+
return self._asyncioTestContext.run(func, *args, **kwargs)
11296

11397
def _setupAsyncioLoop(self):
11498
assert self._asyncioTestLoop is None, 'asyncio test loop already initialized'
11599
loop = asyncio.new_event_loop()
116100
asyncio.set_event_loop(loop)
117101
loop.set_debug(True)
118102
self._asyncioTestLoop = loop
119-
fut = loop.create_future()
120-
self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
121-
loop.run_until_complete(fut)
122103

123104
def _tearDownAsyncioLoop(self):
124105
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
125106
loop = self._asyncioTestLoop
126107
self._asyncioTestLoop = None
127-
self._asyncioCallsQueue.put_nowait(None)
128-
loop.run_until_complete(self._asyncioCallsQueue.join())
129108

130109
try:
131110
# cancel all tasks

Lib/unittest/test/test_async_case.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextvars
23
import unittest
34
from test import support
45

@@ -11,6 +12,9 @@ def tearDownModule():
1112
asyncio.set_event_loop_policy(None)
1213

1314

15+
VAR = contextvars.ContextVar('VAR', default=())
16+
17+
1418
class TestAsyncCase(unittest.TestCase):
1519
maxDiff = None
1620

@@ -24,22 +28,26 @@ class Test(unittest.IsolatedAsyncioTestCase):
2428
def setUp(self):
2529
self.assertEqual(events, [])
2630
events.append('setUp')
31+
VAR.set(VAR.get() + ('setUp',))
2732

2833
async def asyncSetUp(self):
2934
self.assertEqual(events, ['setUp'])
3035
events.append('asyncSetUp')
36+
VAR.set(VAR.get() + ('asyncSetUp',))
3137
self.addAsyncCleanup(self.on_cleanup1)
3238

3339
async def test_func(self):
3440
self.assertEqual(events, ['setUp',
3541
'asyncSetUp'])
3642
events.append('test')
43+
VAR.set(VAR.get() + ('test',))
3744
self.addAsyncCleanup(self.on_cleanup2)
3845

3946
async def asyncTearDown(self):
4047
self.assertEqual(events, ['setUp',
4148
'asyncSetUp',
4249
'test'])
50+
VAR.set(VAR.get() + ('asyncTearDown',))
4351
events.append('asyncTearDown')
4452

4553
def tearDown(self):
@@ -48,6 +56,7 @@ def tearDown(self):
4856
'test',
4957
'asyncTearDown'])
5058
events.append('tearDown')
59+
VAR.set(VAR.get() + ('tearDown',))
5160

5261
async def on_cleanup1(self):
5362
self.assertEqual(events, ['setUp',
@@ -57,6 +66,9 @@ async def on_cleanup1(self):
5766
'tearDown',
5867
'cleanup2'])
5968
events.append('cleanup1')
69+
VAR.set(VAR.get() + ('cleanup1',))
70+
nonlocal cvar
71+
cvar = VAR.get()
6072

6173
async def on_cleanup2(self):
6274
self.assertEqual(events, ['setUp',
@@ -65,22 +77,28 @@ async def on_cleanup2(self):
6577
'asyncTearDown',
6678
'tearDown'])
6779
events.append('cleanup2')
80+
VAR.set(VAR.get() + ('cleanup2',))
6881

6982
events = []
83+
cvar = ()
7084
test = Test("test_func")
7185
result = test.run()
7286
self.assertEqual(result.errors, [])
7387
self.assertEqual(result.failures, [])
7488
expected = ['setUp', 'asyncSetUp', 'test',
7589
'asyncTearDown', 'tearDown', 'cleanup2', 'cleanup1']
7690
self.assertEqual(events, expected)
91+
self.assertEqual(cvar, tuple(expected))
7792

7893
events = []
94+
cvar = ()
7995
test = Test("test_func")
8096
test.debug()
8197
self.assertEqual(events, expected)
98+
self.assertEqual(cvar, tuple(expected))
8299
test.doCleanups()
83100
self.assertEqual(events, expected)
101+
self.assertEqual(cvar, tuple(expected))
84102

85103
def test_exception_in_setup(self):
86104
class Test(unittest.IsolatedAsyncioTestCase):
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Accept explicit contextvars.Context in :func:`asyncio.create_task` and
2+
:meth:`asyncio.loop.create_task`.

Modules/_asynciomodule.c

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,14 +2003,15 @@ _asyncio.Task.__init__
20032003
*
20042004
loop: object = None
20052005
name: object = None
2006+
context: object = None
20062007
20072008
A coroutine wrapped in a Future.
20082009
[clinic start generated code]*/
20092010

20102011
static int
20112012
_asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
2012-
PyObject *name)
2013-
/*[clinic end generated code: output=88b12b83d570df50 input=352a3137fe60091d]*/
2013+
PyObject *name, PyObject *context)
2014+
/*[clinic end generated code: output=49ac96fe33d0e5c7 input=924522490c8ce825]*/
20142015
{
20152016
if (future_init((FutureObj*)self, loop)) {
20162017
return -1;
@@ -2028,9 +2029,13 @@ _asyncio_Task___init___impl(TaskObj *self, PyObject *coro, PyObject *loop,
20282029
return -1;
20292030
}
20302031

2031-
Py_XSETREF(self->task_context, PyContext_CopyCurrent());
2032-
if (self->task_context == NULL) {
2033-
return -1;
2032+
if (context != NULL) {
2033+
self->task_context = Py_NewRef(context);
2034+
} else {
2035+
Py_XSETREF(self->task_context, PyContext_CopyCurrent());
2036+
if (self->task_context == NULL) {
2037+
return -1;
2038+
}
20342039
}
20352040

20362041
Py_CLEAR(self->task_fut_waiter);

0 commit comments

Comments
 (0)