bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)

This commit is contained in:
Dennis Sweeney 2021-04-11 00:51:35 -04:00 committed by GitHub
parent 9045919bfa
commit dfb45323ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 182 additions and 6 deletions

View File

@ -372,11 +372,8 @@ class AsyncGenAsyncioTest(unittest.TestCase):
self.loop = None
asyncio.set_event_loop_policy(None)
def test_async_gen_anext(self):
async def gen():
yield 1
yield 2
g = gen()
def check_async_iterator_anext(self, ait_class):
g = ait_class()
async def consume():
results = []
results.append(await anext(g))
@ -388,6 +385,66 @@ class AsyncGenAsyncioTest(unittest.TestCase):
with self.assertRaises(StopAsyncIteration):
self.loop.run_until_complete(consume())
async def test_2():
g1 = ait_class()
self.assertEqual(await anext(g1), 1)
self.assertEqual(await anext(g1), 2)
with self.assertRaises(StopAsyncIteration):
await anext(g1)
with self.assertRaises(StopAsyncIteration):
await anext(g1)
g2 = ait_class()
self.assertEqual(await anext(g2, "default"), 1)
self.assertEqual(await anext(g2, "default"), 2)
self.assertEqual(await anext(g2, "default"), "default")
self.assertEqual(await anext(g2, "default"), "default")
return "completed"
result = self.loop.run_until_complete(test_2())
self.assertEqual(result, "completed")
def test_async_generator_anext(self):
async def agen():
yield 1
yield 2
self.check_async_iterator_anext(agen)
def test_python_async_iterator_anext(self):
class MyAsyncIter:
"""Asynchronously yield 1, then 2."""
def __init__(self):
self.yielded = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.yielded >= 2:
raise StopAsyncIteration()
else:
self.yielded += 1
return self.yielded
self.check_async_iterator_anext(MyAsyncIter)
def test_python_async_iterator_types_coroutine_anext(self):
import types
class MyAsyncIterWithTypesCoro:
"""Asynchronously yield 1, then 2."""
def __init__(self):
self.yielded = 0
def __aiter__(self):
return self
@types.coroutine
def __anext__(self):
if False:
yield "this is a generator-based coroutine"
if self.yielded >= 2:
raise StopAsyncIteration()
else:
self.yielded += 1
return self.yielded
self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
def test_async_gen_aiter(self):
async def gen():
yield 1
@ -431,12 +488,85 @@ class AsyncGenAsyncioTest(unittest.TestCase):
await anext(gen(), 1, 3)
async def call_with_wrong_type_args():
await anext(1, gen())
async def call_with_kwarg():
await anext(aiterator=gen())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_few_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_many_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_wrong_type_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_kwarg())
def test_anext_bad_await(self):
async def bad_awaitable():
class BadAwaitable:
def __await__(self):
return 42
class MyAsyncIter:
def __aiter__(self):
return self
def __anext__(self):
return BadAwaitable()
regex = r"__await__.*iterator"
awaitable = anext(MyAsyncIter(), "default")
with self.assertRaisesRegex(TypeError, regex):
await awaitable
awaitable = anext(MyAsyncIter())
with self.assertRaisesRegex(TypeError, regex):
await awaitable
return "completed"
result = self.loop.run_until_complete(bad_awaitable())
self.assertEqual(result, "completed")
async def check_anext_returning_iterator(self, aiter_class):
awaitable = anext(aiter_class(), "default")
with self.assertRaises(TypeError):
await awaitable
awaitable = anext(aiter_class())
with self.assertRaises(TypeError):
await awaitable
return "completed"
def test_anext_return_iterator(self):
class WithIterAnext:
def __aiter__(self):
return self
def __anext__(self):
return iter("abc")
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
self.assertEqual(result, "completed")
def test_anext_return_generator(self):
class WithGenAnext:
def __aiter__(self):
return self
def __anext__(self):
yield
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
self.assertEqual(result, "completed")
def test_anext_await_raises(self):
class RaisingAwaitable:
def __await__(self):
raise ZeroDivisionError()
yield
class WithRaisingAwaitableAnext:
def __aiter__(self):
return self
def __anext__(self):
return RaisingAwaitable()
async def do_test():
awaitable = anext(WithRaisingAwaitableAnext())
with self.assertRaises(ZeroDivisionError):
await awaitable
awaitable = anext(WithRaisingAwaitableAnext(), "default")
with self.assertRaises(ZeroDivisionError):
await awaitable
return "completed"
result = self.loop.run_until_complete(do_test())
self.assertEqual(result, "completed")
def test_aiter_bad_args(self):
async def gen():

View File

@ -0,0 +1 @@
Fixed a bug where ``anext(ait, default)`` would erroneously return None.

View File

@ -316,7 +316,52 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
static PyObject *
anextawaitable_iternext(anextawaitableobject *obj)
{
PyObject *result = PyIter_Next(obj->wrapped);
/* Consider the following class:
*
* class A:
* async def __anext__(self):
* ...
* a = A()
*
* Then `await anext(a)` should call
* a.__anext__().__await__().__next__()
*
* On the other hand, given
*
* async def agen():
* yield 1
* yield 2
* gen = agen()
*
* Then `await anext(gen)` can just call
* gen.__anext__().__next__()
*/
assert(obj->wrapped != NULL);
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
if (awaitable == NULL) {
return NULL;
}
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
* or an iterator. Of these, only coroutines lack tp_iternext.
*/
assert(PyCoro_CheckExact(awaitable));
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
PyObject *new_awaitable = getter(awaitable);
if (new_awaitable == NULL) {
Py_DECREF(awaitable);
return NULL;
}
Py_SETREF(awaitable, new_awaitable);
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
PyErr_SetString(PyExc_TypeError,
"__await__ returned a non-iterable");
Py_DECREF(awaitable);
return NULL;
}
}
PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
Py_DECREF(awaitable);
if (result != NULL) {
return result;
}