diff --git a/Doc/library/sys.rst b/Doc/library/sys.rst index 3e8fd82aa72..f9733b2ed63 100644 --- a/Doc/library/sys.rst +++ b/Doc/library/sys.rst @@ -1085,6 +1085,20 @@ always available. If called twice, the new wrapper replaces the previous one. The function is thread-specific. + The *wrapper* callable cannot define new coroutines directly or indirectly:: + + def wrapper(coro): + async def wrap(coro): + return await coro + return wrap(coro) + sys.set_coroutine_wrapper(wrapper) + + async def foo(): pass + + # The following line will fail with a RuntimeError, because + # `wrapper` creates a `wrap(coro)` coroutine: + foo() + See also :func:`get_coroutine_wrapper`. .. versionadded:: 3.5 diff --git a/Include/ceval.h b/Include/ceval.h index e5585945ae1..9f4d3f1998c 100644 --- a/Include/ceval.h +++ b/Include/ceval.h @@ -23,8 +23,9 @@ PyAPI_FUNC(PyObject *) PyEval_CallMethod(PyObject *obj, #ifndef Py_LIMITED_API PyAPI_FUNC(void) PyEval_SetProfile(Py_tracefunc, PyObject *); PyAPI_FUNC(void) PyEval_SetTrace(Py_tracefunc, PyObject *); -PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *wrapper); +PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *); PyAPI_FUNC(PyObject *) _PyEval_GetCoroutineWrapper(void); +PyAPI_FUNC(PyObject *) _PyEval_ApplyCoroutineWrapper(PyObject *); #endif struct _frame; /* Avoid including frameobject.h */ diff --git a/Include/pystate.h b/Include/pystate.h index 2ee81df7b58..a2fd8031d04 100644 --- a/Include/pystate.h +++ b/Include/pystate.h @@ -135,6 +135,7 @@ typedef struct _ts { void *on_delete_data; PyObject *coroutine_wrapper; + int in_coroutine_wrapper; /* XXX signal handlers should also be here */ diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py index e79896a9b8e..670852d20c0 100644 --- a/Lib/test/test_coroutines.py +++ b/Lib/test/test_coroutines.py @@ -995,6 +995,26 @@ class SysSetCoroWrapperTest(unittest.TestCase): sys.set_coroutine_wrapper(1) self.assertIsNone(sys.get_coroutine_wrapper()) + def test_set_wrapper_3(self): + async def foo(): + return 'spam' + + def wrapper(coro): + async def wrap(coro): + return await coro + return wrap(coro) + + sys.set_coroutine_wrapper(wrapper) + try: + with self.assertRaisesRegex( + RuntimeError, + "coroutine wrapper.*\.wrapper at 0x.*attempted to " + "recursively wrap co_flags & CO_GENERATOR) { PyObject *gen; - PyObject *coroutine_wrapper; /* Don't need to keep the reference to f_back, it will be set * when the generator is resumed. */ @@ -3935,14 +3934,9 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals, if (gen == NULL) return NULL; - if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE)) { - coroutine_wrapper = _PyEval_GetCoroutineWrapper(); - if (coroutine_wrapper != NULL) { - PyObject *wrapped = - PyObject_CallFunction(coroutine_wrapper, "N", gen); - gen = wrapped; - } - } + if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE)) + return _PyEval_ApplyCoroutineWrapper(gen); + return gen; } @@ -4407,6 +4401,33 @@ _PyEval_GetCoroutineWrapper(void) return tstate->coroutine_wrapper; } +PyObject * +_PyEval_ApplyCoroutineWrapper(PyObject *gen) +{ + PyObject *wrapped; + PyThreadState *tstate = PyThreadState_GET(); + PyObject *wrapper = tstate->coroutine_wrapper; + + if (tstate->in_coroutine_wrapper) { + assert(wrapper != NULL); + PyErr_Format(PyExc_RuntimeError, + "coroutine wrapper %.150R attempted " + "to recursively wrap %.150R", + wrapper, + gen); + return NULL; + } + + if (wrapper == NULL) { + return gen; + } + + tstate->in_coroutine_wrapper = 1; + wrapped = PyObject_CallFunction(wrapper, "N", gen); + tstate->in_coroutine_wrapper = 0; + return wrapped; +} + PyObject * PyEval_GetBuiltins(void) { diff --git a/Python/pystate.c b/Python/pystate.c index 4ac05d66256..7e0267ae1d0 100644 --- a/Python/pystate.c +++ b/Python/pystate.c @@ -213,6 +213,7 @@ new_threadstate(PyInterpreterState *interp, int init) tstate->on_delete_data = NULL; tstate->coroutine_wrapper = NULL; + tstate->in_coroutine_wrapper = 0; if (init) _PyThreadState_Init(tstate);