Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit aab3c4a

Browse files
committed
Issue 24342: Let wrapper set by sys.set_coroutine_wrapper fail gracefully
1 parent 231d906 commit aab3c4a

File tree

6 files changed

+68
-10
lines changed

6 files changed

+68
-10
lines changed

Doc/library/sys.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,20 @@ always available.
10851085
If called twice, the new wrapper replaces the previous one. The function
10861086
is thread-specific.
10871087

1088+
The *wrapper* callable cannot define new coroutines directly or indirectly::
1089+
1090+
def wrapper(coro):
1091+
async def wrap(coro):
1092+
return await coro
1093+
return wrap(coro)
1094+
sys.set_coroutine_wrapper(wrapper)
1095+
1096+
async def foo(): pass
1097+
1098+
# The following line will fail with a RuntimeError, because
1099+
# `wrapper` creates a `wrap(coro)` coroutine:
1100+
foo()
1101+
10881102
See also :func:`get_coroutine_wrapper`.
10891103

10901104
.. versionadded:: 3.5

Include/ceval.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ PyAPI_FUNC(PyObject *) PyEval_CallMethod(PyObject *obj,
2323
#ifndef Py_LIMITED_API
2424
PyAPI_FUNC(void) PyEval_SetProfile(Py_tracefunc, PyObject *);
2525
PyAPI_FUNC(void) PyEval_SetTrace(Py_tracefunc, PyObject *);
26-
PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *wrapper);
26+
PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *);
2727
PyAPI_FUNC(PyObject *) _PyEval_GetCoroutineWrapper(void);
28+
PyAPI_FUNC(PyObject *) _PyEval_ApplyCoroutineWrapper(PyObject *);
2829
#endif
2930

3031
struct _frame; /* Avoid including frameobject.h */

Include/pystate.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ typedef struct _ts {
135135
void *on_delete_data;
136136

137137
PyObject *coroutine_wrapper;
138+
int in_coroutine_wrapper;
138139

139140
/* XXX signal handlers should also be here */
140141

Lib/test/test_coroutines.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,26 @@ def test_set_wrapper_2(self):
995995
sys.set_coroutine_wrapper(1)
996996
self.assertIsNone(sys.get_coroutine_wrapper())
997997

998+
def test_set_wrapper_3(self):
999+
async def foo():
1000+
return 'spam'
1001+
1002+
def wrapper(coro):
1003+
async def wrap(coro):
1004+
return await coro
1005+
return wrap(coro)
1006+
1007+
sys.set_coroutine_wrapper(wrapper)
1008+
try:
1009+
with self.assertRaisesRegex(
1010+
RuntimeError,
1011+
"coroutine wrapper.*\.wrapper at 0x.*attempted to "
1012+
"recursively wrap <coroutine.*\.wrap"):
1013+
1014+
foo()
1015+
finally:
1016+
sys.set_coroutine_wrapper(None)
1017+
9981018

9991019
class CAPITest(unittest.TestCase):
10001020

Python/ceval.c

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3921,7 +3921,6 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
39213921

39223922
if (co->co_flags & CO_GENERATOR) {
39233923
PyObject *gen;
3924-
PyObject *coroutine_wrapper;
39253924

39263925
/* Don't need to keep the reference to f_back, it will be set
39273926
* when the generator is resumed. */
@@ -3935,14 +3934,9 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
39353934
if (gen == NULL)
39363935
return NULL;
39373936

3938-
if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE)) {
3939-
coroutine_wrapper = _PyEval_GetCoroutineWrapper();
3940-
if (coroutine_wrapper != NULL) {
3941-
PyObject *wrapped =
3942-
PyObject_CallFunction(coroutine_wrapper, "N", gen);
3943-
gen = wrapped;
3944-
}
3945-
}
3937+
if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE))
3938+
return _PyEval_ApplyCoroutineWrapper(gen);
3939+
39463940
return gen;
39473941
}
39483942

@@ -4407,6 +4401,33 @@ _PyEval_GetCoroutineWrapper(void)
44074401
return tstate->coroutine_wrapper;
44084402
}
44094403

4404+
PyObject *
4405+
_PyEval_ApplyCoroutineWrapper(PyObject *gen)
4406+
{
4407+
PyObject *wrapped;
4408+
PyThreadState *tstate = PyThreadState_GET();
4409+
PyObject *wrapper = tstate->coroutine_wrapper;
4410+
4411+
if (tstate->in_coroutine_wrapper) {
4412+
assert(wrapper != NULL);
4413+
PyErr_Format(PyExc_RuntimeError,
4414+
"coroutine wrapper %.150R attempted "
4415+
"to recursively wrap %.150R",
4416+
wrapper,
4417+
gen);
4418+
return NULL;
4419+
}
4420+
4421+
if (wrapper == NULL) {
4422+
return gen;
4423+
}
4424+
4425+
tstate->in_coroutine_wrapper = 1;
4426+
wrapped = PyObject_CallFunction(wrapper, "N", gen);
4427+
tstate->in_coroutine_wrapper = 0;
4428+
return wrapped;
4429+
}
4430+
44104431
PyObject *
44114432
PyEval_GetBuiltins(void)
44124433
{

Python/pystate.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ new_threadstate(PyInterpreterState *interp, int init)
213213
tstate->on_delete_data = NULL;
214214

215215
tstate->coroutine_wrapper = NULL;
216+
tstate->in_coroutine_wrapper = 0;
216217

217218
if (init)
218219
_PyThreadState_Init(tstate);

0 commit comments

Comments
 (0)