mirror of https://github.com/python/cpython
bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)
This commit is contained in:
parent
9045919bfa
commit
dfb45323ce
|
@ -372,11 +372,8 @@ class AsyncGenAsyncioTest(unittest.TestCase):
|
|||
self.loop = None
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
def test_async_gen_anext(self):
|
||||
async def gen():
|
||||
yield 1
|
||||
yield 2
|
||||
g = gen()
|
||||
def check_async_iterator_anext(self, ait_class):
|
||||
g = ait_class()
|
||||
async def consume():
|
||||
results = []
|
||||
results.append(await anext(g))
|
||||
|
@ -388,6 +385,66 @@ class AsyncGenAsyncioTest(unittest.TestCase):
|
|||
with self.assertRaises(StopAsyncIteration):
|
||||
self.loop.run_until_complete(consume())
|
||||
|
||||
async def test_2():
|
||||
g1 = ait_class()
|
||||
self.assertEqual(await anext(g1), 1)
|
||||
self.assertEqual(await anext(g1), 2)
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await anext(g1)
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await anext(g1)
|
||||
|
||||
g2 = ait_class()
|
||||
self.assertEqual(await anext(g2, "default"), 1)
|
||||
self.assertEqual(await anext(g2, "default"), 2)
|
||||
self.assertEqual(await anext(g2, "default"), "default")
|
||||
self.assertEqual(await anext(g2, "default"), "default")
|
||||
|
||||
return "completed"
|
||||
|
||||
result = self.loop.run_until_complete(test_2())
|
||||
self.assertEqual(result, "completed")
|
||||
|
||||
def test_async_generator_anext(self):
|
||||
async def agen():
|
||||
yield 1
|
||||
yield 2
|
||||
self.check_async_iterator_anext(agen)
|
||||
|
||||
def test_python_async_iterator_anext(self):
|
||||
class MyAsyncIter:
|
||||
"""Asynchronously yield 1, then 2."""
|
||||
def __init__(self):
|
||||
self.yielded = 0
|
||||
def __aiter__(self):
|
||||
return self
|
||||
async def __anext__(self):
|
||||
if self.yielded >= 2:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
self.yielded += 1
|
||||
return self.yielded
|
||||
self.check_async_iterator_anext(MyAsyncIter)
|
||||
|
||||
def test_python_async_iterator_types_coroutine_anext(self):
|
||||
import types
|
||||
class MyAsyncIterWithTypesCoro:
|
||||
"""Asynchronously yield 1, then 2."""
|
||||
def __init__(self):
|
||||
self.yielded = 0
|
||||
def __aiter__(self):
|
||||
return self
|
||||
@types.coroutine
|
||||
def __anext__(self):
|
||||
if False:
|
||||
yield "this is a generator-based coroutine"
|
||||
if self.yielded >= 2:
|
||||
raise StopAsyncIteration()
|
||||
else:
|
||||
self.yielded += 1
|
||||
return self.yielded
|
||||
self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
|
||||
|
||||
def test_async_gen_aiter(self):
|
||||
async def gen():
|
||||
yield 1
|
||||
|
@ -431,12 +488,85 @@ class AsyncGenAsyncioTest(unittest.TestCase):
|
|||
await anext(gen(), 1, 3)
|
||||
async def call_with_wrong_type_args():
|
||||
await anext(1, gen())
|
||||
async def call_with_kwarg():
|
||||
await anext(aiterator=gen())
|
||||
with self.assertRaises(TypeError):
|
||||
self.loop.run_until_complete(call_with_too_few_args())
|
||||
with self.assertRaises(TypeError):
|
||||
self.loop.run_until_complete(call_with_too_many_args())
|
||||
with self.assertRaises(TypeError):
|
||||
self.loop.run_until_complete(call_with_wrong_type_args())
|
||||
with self.assertRaises(TypeError):
|
||||
self.loop.run_until_complete(call_with_kwarg())
|
||||
|
||||
def test_anext_bad_await(self):
|
||||
async def bad_awaitable():
|
||||
class BadAwaitable:
|
||||
def __await__(self):
|
||||
return 42
|
||||
class MyAsyncIter:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
def __anext__(self):
|
||||
return BadAwaitable()
|
||||
regex = r"__await__.*iterator"
|
||||
awaitable = anext(MyAsyncIter(), "default")
|
||||
with self.assertRaisesRegex(TypeError, regex):
|
||||
await awaitable
|
||||
awaitable = anext(MyAsyncIter())
|
||||
with self.assertRaisesRegex(TypeError, regex):
|
||||
await awaitable
|
||||
return "completed"
|
||||
result = self.loop.run_until_complete(bad_awaitable())
|
||||
self.assertEqual(result, "completed")
|
||||
|
||||
async def check_anext_returning_iterator(self, aiter_class):
|
||||
awaitable = anext(aiter_class(), "default")
|
||||
with self.assertRaises(TypeError):
|
||||
await awaitable
|
||||
awaitable = anext(aiter_class())
|
||||
with self.assertRaises(TypeError):
|
||||
await awaitable
|
||||
return "completed"
|
||||
|
||||
def test_anext_return_iterator(self):
|
||||
class WithIterAnext:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
def __anext__(self):
|
||||
return iter("abc")
|
||||
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
|
||||
self.assertEqual(result, "completed")
|
||||
|
||||
def test_anext_return_generator(self):
|
||||
class WithGenAnext:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
def __anext__(self):
|
||||
yield
|
||||
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
|
||||
self.assertEqual(result, "completed")
|
||||
|
||||
def test_anext_await_raises(self):
|
||||
class RaisingAwaitable:
|
||||
def __await__(self):
|
||||
raise ZeroDivisionError()
|
||||
yield
|
||||
class WithRaisingAwaitableAnext:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
def __anext__(self):
|
||||
return RaisingAwaitable()
|
||||
async def do_test():
|
||||
awaitable = anext(WithRaisingAwaitableAnext())
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
await awaitable
|
||||
awaitable = anext(WithRaisingAwaitableAnext(), "default")
|
||||
with self.assertRaises(ZeroDivisionError):
|
||||
await awaitable
|
||||
return "completed"
|
||||
result = self.loop.run_until_complete(do_test())
|
||||
self.assertEqual(result, "completed")
|
||||
|
||||
def test_aiter_bad_args(self):
|
||||
async def gen():
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Fixed a bug where ``anext(ait, default)`` would erroneously return None.
|
|
@ -316,7 +316,52 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
|
|||
static PyObject *
|
||||
anextawaitable_iternext(anextawaitableobject *obj)
|
||||
{
|
||||
PyObject *result = PyIter_Next(obj->wrapped);
|
||||
/* Consider the following class:
|
||||
*
|
||||
* class A:
|
||||
* async def __anext__(self):
|
||||
* ...
|
||||
* a = A()
|
||||
*
|
||||
* Then `await anext(a)` should call
|
||||
* a.__anext__().__await__().__next__()
|
||||
*
|
||||
* On the other hand, given
|
||||
*
|
||||
* async def agen():
|
||||
* yield 1
|
||||
* yield 2
|
||||
* gen = agen()
|
||||
*
|
||||
* Then `await anext(gen)` can just call
|
||||
* gen.__anext__().__next__()
|
||||
*/
|
||||
assert(obj->wrapped != NULL);
|
||||
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
|
||||
if (awaitable == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
|
||||
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
|
||||
* or an iterator. Of these, only coroutines lack tp_iternext.
|
||||
*/
|
||||
assert(PyCoro_CheckExact(awaitable));
|
||||
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
|
||||
PyObject *new_awaitable = getter(awaitable);
|
||||
if (new_awaitable == NULL) {
|
||||
Py_DECREF(awaitable);
|
||||
return NULL;
|
||||
}
|
||||
Py_SETREF(awaitable, new_awaitable);
|
||||
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"__await__ returned a non-iterable");
|
||||
Py_DECREF(awaitable);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
|
||||
Py_DECREF(awaitable);
|
||||
if (result != NULL) {
|
||||
return result;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue