From dfb45323ce8a543ca844c311e32c994ec9554c1b Mon Sep 17 00:00:00 2001 From: Dennis Sweeney <36520290+sweeneyde@users.noreply.github.com> Date: Sun, 11 Apr 2021 00:51:35 -0400 Subject: [PATCH] bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238) --- Lib/test/test_asyncgen.py | 140 +++++++++++++++++- .../2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst | 1 + Objects/iterobject.c | 47 +++++- 3 files changed, 182 insertions(+), 6 deletions(-) create mode 100644 Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 99464e3d092..77c15c02bc8 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -372,11 +372,8 @@ class AsyncGenAsyncioTest(unittest.TestCase): self.loop = None asyncio.set_event_loop_policy(None) - def test_async_gen_anext(self): - async def gen(): - yield 1 - yield 2 - g = gen() + def check_async_iterator_anext(self, ait_class): + g = ait_class() async def consume(): results = [] results.append(await anext(g)) @@ -388,6 +385,66 @@ class AsyncGenAsyncioTest(unittest.TestCase): with self.assertRaises(StopAsyncIteration): self.loop.run_until_complete(consume()) + async def test_2(): + g1 = ait_class() + self.assertEqual(await anext(g1), 1) + self.assertEqual(await anext(g1), 2) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + with self.assertRaises(StopAsyncIteration): + await anext(g1) + + g2 = ait_class() + self.assertEqual(await anext(g2, "default"), 1) + self.assertEqual(await anext(g2, "default"), 2) + self.assertEqual(await anext(g2, "default"), "default") + self.assertEqual(await anext(g2, "default"), "default") + + return "completed" + + result = self.loop.run_until_complete(test_2()) + self.assertEqual(result, "completed") + + def test_async_generator_anext(self): + async def agen(): + yield 1 + yield 2 + self.check_async_iterator_anext(agen) + + def test_python_async_iterator_anext(self): + class MyAsyncIter: + """Asynchronously yield 1, then 2.""" + def __init__(self): + self.yielded = 0 + def __aiter__(self): + return self + async def __anext__(self): + if self.yielded >= 2: + raise StopAsyncIteration() + else: + self.yielded += 1 + return self.yielded + self.check_async_iterator_anext(MyAsyncIter) + + def test_python_async_iterator_types_coroutine_anext(self): + import types + class MyAsyncIterWithTypesCoro: + """Asynchronously yield 1, then 2.""" + def __init__(self): + self.yielded = 0 + def __aiter__(self): + return self + @types.coroutine + def __anext__(self): + if False: + yield "this is a generator-based coroutine" + if self.yielded >= 2: + raise StopAsyncIteration() + else: + self.yielded += 1 + return self.yielded + self.check_async_iterator_anext(MyAsyncIterWithTypesCoro) + def test_async_gen_aiter(self): async def gen(): yield 1 @@ -431,12 +488,85 @@ class AsyncGenAsyncioTest(unittest.TestCase): await anext(gen(), 1, 3) async def call_with_wrong_type_args(): await anext(1, gen()) + async def call_with_kwarg(): + await anext(aiterator=gen()) with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_too_few_args()) with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_too_many_args()) with self.assertRaises(TypeError): self.loop.run_until_complete(call_with_wrong_type_args()) + with self.assertRaises(TypeError): + self.loop.run_until_complete(call_with_kwarg()) + + def test_anext_bad_await(self): + async def bad_awaitable(): + class BadAwaitable: + def __await__(self): + return 42 + class MyAsyncIter: + def __aiter__(self): + return self + def __anext__(self): + return BadAwaitable() + regex = r"__await__.*iterator" + awaitable = anext(MyAsyncIter(), "default") + with self.assertRaisesRegex(TypeError, regex): + await awaitable + awaitable = anext(MyAsyncIter()) + with self.assertRaisesRegex(TypeError, regex): + await awaitable + return "completed" + result = self.loop.run_until_complete(bad_awaitable()) + self.assertEqual(result, "completed") + + async def check_anext_returning_iterator(self, aiter_class): + awaitable = anext(aiter_class(), "default") + with self.assertRaises(TypeError): + await awaitable + awaitable = anext(aiter_class()) + with self.assertRaises(TypeError): + await awaitable + return "completed" + + def test_anext_return_iterator(self): + class WithIterAnext: + def __aiter__(self): + return self + def __anext__(self): + return iter("abc") + result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext)) + self.assertEqual(result, "completed") + + def test_anext_return_generator(self): + class WithGenAnext: + def __aiter__(self): + return self + def __anext__(self): + yield + result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext)) + self.assertEqual(result, "completed") + + def test_anext_await_raises(self): + class RaisingAwaitable: + def __await__(self): + raise ZeroDivisionError() + yield + class WithRaisingAwaitableAnext: + def __aiter__(self): + return self + def __anext__(self): + return RaisingAwaitable() + async def do_test(): + awaitable = anext(WithRaisingAwaitableAnext()) + with self.assertRaises(ZeroDivisionError): + await awaitable + awaitable = anext(WithRaisingAwaitableAnext(), "default") + with self.assertRaises(ZeroDivisionError): + await awaitable + return "completed" + result = self.loop.run_until_complete(do_test()) + self.assertEqual(result, "completed") def test_aiter_bad_args(self): async def gen(): diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst b/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst new file mode 100644 index 00000000000..75951ae794d --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2021-04-07-18-00-05.bpo-43751.8fHsqQ.rst @@ -0,0 +1 @@ +Fixed a bug where ``anext(ait, default)`` would erroneously return None. \ No newline at end of file diff --git a/Objects/iterobject.c b/Objects/iterobject.c index 65af18abf79..6961fc3b4a9 100644 --- a/Objects/iterobject.c +++ b/Objects/iterobject.c @@ -316,7 +316,52 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg) static PyObject * anextawaitable_iternext(anextawaitableobject *obj) { - PyObject *result = PyIter_Next(obj->wrapped); + /* Consider the following class: + * + * class A: + * async def __anext__(self): + * ... + * a = A() + * + * Then `await anext(a)` should call + * a.__anext__().__await__().__next__() + * + * On the other hand, given + * + * async def agen(): + * yield 1 + * yield 2 + * gen = agen() + * + * Then `await anext(gen)` can just call + * gen.__anext__().__next__() + */ + assert(obj->wrapped != NULL); + PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped); + if (awaitable == NULL) { + return NULL; + } + if (Py_TYPE(awaitable)->tp_iternext == NULL) { + /* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator, + * or an iterator. Of these, only coroutines lack tp_iternext. + */ + assert(PyCoro_CheckExact(awaitable)); + unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await; + PyObject *new_awaitable = getter(awaitable); + if (new_awaitable == NULL) { + Py_DECREF(awaitable); + return NULL; + } + Py_SETREF(awaitable, new_awaitable); + if (Py_TYPE(awaitable)->tp_iternext == NULL) { + PyErr_SetString(PyExc_TypeError, + "__await__ returned a non-iterable"); + Py_DECREF(awaitable); + return NULL; + } + } + PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable); + Py_DECREF(awaitable); if (result != NULL) { return result; }