diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index bc0ae8f5321..473bce484b4 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -1,12 +1,16 @@ import inspect import types import unittest +import contextlib from test.support.import_helper import import_module from test.support import gc_collect asyncio = import_module("asyncio") +_no_default = object() + + class AwaitException(Exception): pass @@ -45,6 +49,37 @@ def to_list(gen): return run_until_complete(iterate()) +def py_anext(iterator, default=_no_default): + """Pure-Python implementation of anext() for testing purposes. + + Closely matches the builtin anext() C implementation. + Can be used to compare the built-in implementation of the inner + coroutines machinery to C-implementation of __anext__() and send() + or throw() on the returned generator. + """ + + try: + __anext__ = type(iterator).__anext__ + except AttributeError: + raise TypeError(f'{iterator!r} is not an async iterator') + + if default is _no_default: + return __anext__(iterator) + + async def anext_impl(): + try: + # The C code is way more low-level than this, as it implements + # all methods of the iterator protocol. In this implementation + # we're relying on higher-level coroutine concepts, but that's + # exactly what we want -- crosstest pure-Python high-level + # implementation and low-level C anext() iterators. + return await __anext__(iterator) + except StopAsyncIteration: + return default + + return anext_impl() + + class AsyncGenSyntaxTest(unittest.TestCase): def test_async_gen_syntax_01(self): @@ -374,6 +409,12 @@ class AsyncGenAsyncioTest(unittest.TestCase): asyncio.set_event_loop_policy(None) def check_async_iterator_anext(self, ait_class): + with self.subTest(anext="pure-Python"): + self._check_async_iterator_anext(ait_class, py_anext) + with self.subTest(anext="builtin"): + self._check_async_iterator_anext(ait_class, anext) + + def _check_async_iterator_anext(self, ait_class, anext): g = ait_class() async def consume(): results = [] @@ -406,6 +447,24 @@ class AsyncGenAsyncioTest(unittest.TestCase): result = self.loop.run_until_complete(test_2()) self.assertEqual(result, "completed") + def test_send(): + p = ait_class() + obj = anext(p, "completed") + with self.assertRaises(StopIteration): + with contextlib.closing(obj.__await__()) as g: + g.send(None) + + test_send() + + async def test_throw(): + p = ait_class() + obj = anext(p, "completed") + self.assertRaises(SyntaxError, obj.throw, SyntaxError) + return "completed" + + result = self.loop.run_until_complete(test_throw()) + self.assertEqual(result, "completed") + def test_async_generator_anext(self): async def agen(): yield 1 @@ -569,6 +628,119 @@ class AsyncGenAsyncioTest(unittest.TestCase): result = self.loop.run_until_complete(do_test()) self.assertEqual(result, "completed") + def test_anext_iter(self): + @types.coroutine + def _async_yield(v): + return (yield v) + + class MyError(Exception): + pass + + async def agenfn(): + try: + await _async_yield(1) + except MyError: + await _async_yield(2) + return + yield + + def test1(anext): + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 1) + self.assertEqual(g.throw(MyError, MyError(), None), 2) + try: + g.send(None) + except StopIteration as e: + err = e + else: + self.fail('StopIteration was not raised') + self.assertEqual(err.value, "default") + + def test2(anext): + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 1) + self.assertEqual(g.throw(MyError, MyError(), None), 2) + with self.assertRaises(MyError): + g.throw(MyError, MyError(), None) + + def test3(anext): + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 1) + g.close() + with self.assertRaisesRegex(RuntimeError, 'cannot reuse'): + self.assertEqual(g.send(None), 1) + + def test4(anext): + @types.coroutine + def _async_yield(v): + yield v * 10 + return (yield (v * 10 + 1)) + + async def agenfn(): + try: + await _async_yield(1) + except MyError: + await _async_yield(2) + return + yield + + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 10) + self.assertEqual(g.throw(MyError, MyError(), None), 20) + with self.assertRaisesRegex(MyError, 'val'): + g.throw(MyError, MyError('val'), None) + + def test5(anext): + @types.coroutine + def _async_yield(v): + yield v * 10 + return (yield (v * 10 + 1)) + + async def agenfn(): + try: + await _async_yield(1) + except MyError: + return + yield 'aaa' + + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + self.assertEqual(g.send(None), 10) + with self.assertRaisesRegex(StopIteration, 'default'): + g.throw(MyError, MyError(), None) + + def test6(anext): + @types.coroutine + def _async_yield(v): + yield v * 10 + return (yield (v * 10 + 1)) + + async def agenfn(): + await _async_yield(1) + yield 'aaa' + + agen = agenfn() + with contextlib.closing(anext(agen, "default").__await__()) as g: + with self.assertRaises(MyError): + g.throw(MyError, MyError(), None) + + def run_test(test): + with self.subTest('pure-Python anext()'): + test(py_anext) + with self.subTest('builtin anext()'): + test(anext) + + run_test(test1) + run_test(test2) + run_test(test3) + run_test(test4) + run_test(test5) + run_test(test6) + def test_aiter_bad_args(self): async def gen(): yield 1 diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-08-25-23-07-10.bpo-44963.5EET8y.rst b/Misc/NEWS.d/next/Core and Builtins/2021-08-25-23-07-10.bpo-44963.5EET8y.rst new file mode 100644 index 00000000000..9a54bda118e --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2021-08-25-23-07-10.bpo-44963.5EET8y.rst @@ -0,0 +1,2 @@ +Implement ``send()`` and ``throw()`` methods for ``anext_awaitable`` +objects. Patch by Pablo Galindo. diff --git a/Objects/iterobject.c b/Objects/iterobject.c index 6961fc3b4a9..e493e41131b 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -313,6 +313,36 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg) return 0; } +static PyObject * +anextawaitable_getiter(anextawaitableobject *obj) +{ + assert(obj->wrapped != NULL); + PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped); + if (awaitable == NULL) { + return NULL; + } + if (Py_TYPE(awaitable)->tp_iternext == NULL) { + /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator, + * or an iterator. Of these, only coroutines lack tp_iternext. + */ + assert(PyCoro_CheckExact(awaitable)); + unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await; + PyObject *new_awaitable = getter(awaitable); + if (new_awaitable == NULL) { + Py_DECREF(awaitable); + return NULL; + } + Py_SETREF(awaitable, new_awaitable); + if (!PyIter_Check(awaitable)) { + PyErr_SetString(PyExc_TypeError, + "__await__ returned a non-iterable"); + Py_DECREF(awaitable); + return NULL; + } + } + return awaitable; +} + static PyObject * anextawaitable_iternext(anextawaitableobject *obj) { @@ -336,30 +366,10 @@ anextawaitable_iternext(anextawaitableobject *obj) * Then `await anext(gen)` can just call * gen.__anext__().__next__() */ - assert(obj->wrapped != NULL); - PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped); + PyObject *awaitable = anextawaitable_getiter(obj); if (awaitable == NULL) { return NULL; } - if (Py_TYPE(awaitable)->tp_iternext == NULL) { - /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator, - * or an iterator. Of these, only coroutines lack tp_iternext. - */ - assert(PyCoro_CheckExact(awaitable)); - unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await; - PyObject *new_awaitable = getter(awaitable); - if (new_awaitable == NULL) { - Py_DECREF(awaitable); - return NULL; - } - Py_SETREF(awaitable, new_awaitable); - if (Py_TYPE(awaitable)->tp_iternext == NULL) { - PyErr_SetString(PyExc_TypeError, - "__await__ returned a non-iterable"); - Py_DECREF(awaitable); - return NULL; - } - } PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable); Py_DECREF(awaitable); if (result != NULL) { @@ -371,6 +381,70 @@ anextawaitable_iternext(anextawaitableobject *obj) return NULL; } + +static PyObject * +anextawaitable_proxy(anextawaitableobject *obj, char *meth, PyObject *arg) { + PyObject *awaitable = anextawaitable_getiter(obj); + if (awaitable == NULL) { + return NULL; + } + PyObject *ret = PyObject_CallMethod(awaitable, meth, "O", arg); + Py_DECREF(awaitable); + if (ret != NULL) { + return ret; + } + if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration)) { + /* `anextawaitableobject` is only used by `anext()` when + * a default value is provided. So when we have a StopAsyncIteration + * exception we replace it with a `StopIteration(default)`, as if + * it was the return value of `__anext__()` coroutine. + */ + _PyGen_SetStopIterationValue(obj->default_value); + } + return NULL; +} + + +static PyObject * +anextawaitable_send(anextawaitableobject *obj, PyObject *arg) { + return anextawaitable_proxy(obj, "send", arg); +} + + +static PyObject * +anextawaitable_throw(anextawaitableobject *obj, PyObject *arg) { + return anextawaitable_proxy(obj, "throw", arg); +} + + +static PyObject * +anextawaitable_close(anextawaitableobject *obj, PyObject *arg) { + return anextawaitable_proxy(obj, "close", arg); +} + + +PyDoc_STRVAR(send_doc, +"send(arg) -> send 'arg' into the wrapped iterator,\n\ +return next yielded value or raise StopIteration."); + + +PyDoc_STRVAR(throw_doc, +"throw(typ[,val[,tb]]) -> raise exception in the wrapped iterator,\n\ +return next yielded value or raise StopIteration."); + + +PyDoc_STRVAR(close_doc, +"close() -> raise GeneratorExit inside generator."); + + +static PyMethodDef anextawaitable_methods[] = { + {"send",(PyCFunction)anextawaitable_send, METH_O, send_doc}, + {"throw",(PyCFunction)anextawaitable_throw, METH_VARARGS, throw_doc}, + {"close",(PyCFunction)anextawaitable_close, METH_VARARGS, close_doc}, + {NULL, NULL} /* Sentinel */ +}; + + static PyAsyncMethods anextawaitable_as_async = { PyObject_SelfIter, /* am_await */ 0, /* am_aiter */ @@ -407,7 +481,7 @@ PyTypeObject _PyAnextAwaitable_Type = { 0, /* tp_weaklistoffset */ PyObject_SelfIter, /* tp_iter */ (unaryfunc)anextawaitable_iternext, /* tp_iternext */ - 0, /* tp_methods */ + anextawaitable_methods, /* tp_methods */ }; PyObject *