bpo-44963: Implement send() and throw() methods for anext_awaitable objects (GH-27955)

Co-authored-by: Yury Selivanov <yury@edgedb.com>
This commit is contained in:
Pablo Galindo Salgado 2021-09-07 11:30:14 +01:00 committed by GitHub
parent 4f88161f07
commit 533e725821
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 270 additions and 22 deletions

View File

@ -1,12 +1,16 @@
import inspect import inspect
import types import types
import unittest import unittest
import contextlib
from test.support.import_helper import import_module from test.support.import_helper import import_module
from test.support import gc_collect from test.support import gc_collect
asyncio = import_module("asyncio") asyncio = import_module("asyncio")
_no_default = object()
class AwaitException(Exception): class AwaitException(Exception):
pass pass
@ -45,6 +49,37 @@ def to_list(gen):
return run_until_complete(iterate()) return run_until_complete(iterate())
def py_anext(iterator, default=_no_default):
"""Pure-Python implementation of anext() for testing purposes.
Closely matches the builtin anext() C implementation.
Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator.
"""
try:
__anext__ = type(iterator).__anext__
except AttributeError:
raise TypeError(f'{iterator!r} is not an async iterator')
if default is _no_default:
return __anext__(iterator)
async def anext_impl():
try:
# The C code is way more low-level than this, as it implements
# all methods of the iterator protocol. In this implementation
# we're relying on higher-level coroutine concepts, but that's
# exactly what we want -- crosstest pure-Python high-level
# implementation and low-level C anext() iterators.
return await __anext__(iterator)
except StopAsyncIteration:
return default
return anext_impl()
class AsyncGenSyntaxTest(unittest.TestCase): class AsyncGenSyntaxTest(unittest.TestCase):
def test_async_gen_syntax_01(self): def test_async_gen_syntax_01(self):
@ -374,6 +409,12 @@ class AsyncGenAsyncioTest(unittest.TestCase):
asyncio.set_event_loop_policy(None) asyncio.set_event_loop_policy(None)
def check_async_iterator_anext(self, ait_class): def check_async_iterator_anext(self, ait_class):
with self.subTest(anext="pure-Python"):
self._check_async_iterator_anext(ait_class, py_anext)
with self.subTest(anext="builtin"):
self._check_async_iterator_anext(ait_class, anext)
def _check_async_iterator_anext(self, ait_class, anext):
g = ait_class() g = ait_class()
async def consume(): async def consume():
results = [] results = []
@ -406,6 +447,24 @@ class AsyncGenAsyncioTest(unittest.TestCase):
result = self.loop.run_until_complete(test_2()) result = self.loop.run_until_complete(test_2())
self.assertEqual(result, "completed") self.assertEqual(result, "completed")
def test_send():
p = ait_class()
obj = anext(p, "completed")
with self.assertRaises(StopIteration):
with contextlib.closing(obj.__await__()) as g:
g.send(None)
test_send()
async def test_throw():
p = ait_class()
obj = anext(p, "completed")
self.assertRaises(SyntaxError, obj.throw, SyntaxError)
return "completed"
result = self.loop.run_until_complete(test_throw())
self.assertEqual(result, "completed")
def test_async_generator_anext(self): def test_async_generator_anext(self):
async def agen(): async def agen():
yield 1 yield 1
@ -569,6 +628,119 @@ class AsyncGenAsyncioTest(unittest.TestCase):
result = self.loop.run_until_complete(do_test()) result = self.loop.run_until_complete(do_test())
self.assertEqual(result, "completed") self.assertEqual(result, "completed")
def test_anext_iter(self):
@types.coroutine
def _async_yield(v):
return (yield v)
class MyError(Exception):
pass
async def agenfn():
try:
await _async_yield(1)
except MyError:
await _async_yield(2)
return
yield
def test1(anext):
agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 1)
self.assertEqual(g.throw(MyError, MyError(), None), 2)
try:
g.send(None)
except StopIteration as e:
err = e
else:
self.fail('StopIteration was not raised')
self.assertEqual(err.value, "default")
def test2(anext):
agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 1)
self.assertEqual(g.throw(MyError, MyError(), None), 2)
with self.assertRaises(MyError):
g.throw(MyError, MyError(), None)
def test3(anext):
agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 1)
g.close()
with self.assertRaisesRegex(RuntimeError, 'cannot reuse'):
self.assertEqual(g.send(None), 1)
def test4(anext):
@types.coroutine
def _async_yield(v):
yield v * 10
return (yield (v * 10 + 1))
async def agenfn():
try:
await _async_yield(1)
except MyError:
await _async_yield(2)
return
yield
agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 10)
self.assertEqual(g.throw(MyError, MyError(), None), 20)
with self.assertRaisesRegex(MyError, 'val'):
g.throw(MyError, MyError('val'), None)
def test5(anext):
@types.coroutine
def _async_yield(v):
yield v * 10
return (yield (v * 10 + 1))
async def agenfn():
try:
await _async_yield(1)
except MyError:
return
yield 'aaa'
agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 10)
with self.assertRaisesRegex(StopIteration, 'default'):
g.throw(MyError, MyError(), None)
def test6(anext):
@types.coroutine
def _async_yield(v):
yield v * 10
return (yield (v * 10 + 1))
async def agenfn():
await _async_yield(1)
yield 'aaa'
agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
with self.assertRaises(MyError):
g.throw(MyError, MyError(), None)
def run_test(test):
with self.subTest('pure-Python anext()'):
test(py_anext)
with self.subTest('builtin anext()'):
test(anext)
run_test(test1)
run_test(test2)
run_test(test3)
run_test(test4)
run_test(test5)
run_test(test6)
def test_aiter_bad_args(self): def test_aiter_bad_args(self):
async def gen(): async def gen():
yield 1 yield 1

View File

@ -0,0 +1,2 @@
Implement ``send()`` and ``throw()`` methods for ``anext_awaitable``
objects. Patch by Pablo Galindo.

View File

@ -313,6 +313,36 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
return 0; return 0;
} }
static PyObject *
anextawaitable_getiter(anextawaitableobject *obj)
{
assert(obj->wrapped != NULL);
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
if (awaitable == NULL) {
return NULL;
}
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
* or an iterator. Of these, only coroutines lack tp_iternext.
*/
assert(PyCoro_CheckExact(awaitable));
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
PyObject *new_awaitable = getter(awaitable);
if (new_awaitable == NULL) {
Py_DECREF(awaitable);
return NULL;
}
Py_SETREF(awaitable, new_awaitable);
if (!PyIter_Check(awaitable)) {
PyErr_SetString(PyExc_TypeError,
"__await__ returned a non-iterable");
Py_DECREF(awaitable);
return NULL;
}
}
return awaitable;
}
static PyObject * static PyObject *
anextawaitable_iternext(anextawaitableobject *obj) anextawaitable_iternext(anextawaitableobject *obj)
{ {
@ -336,30 +366,10 @@ anextawaitable_iternext(anextawaitableobject *obj)
* Then `await anext(gen)` can just call * Then `await anext(gen)` can just call
* gen.__anext__().__next__() * gen.__anext__().__next__()
*/ */
assert(obj->wrapped != NULL); PyObject *awaitable = anextawaitable_getiter(obj);
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
if (awaitable == NULL) { if (awaitable == NULL) {
return NULL; return NULL;
} }
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
* or an iterator. Of these, only coroutines lack tp_iternext.
*/
assert(PyCoro_CheckExact(awaitable));
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
PyObject *new_awaitable = getter(awaitable);
if (new_awaitable == NULL) {
Py_DECREF(awaitable);
return NULL;
}
Py_SETREF(awaitable, new_awaitable);
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
PyErr_SetString(PyExc_TypeError,
"__await__ returned a non-iterable");
Py_DECREF(awaitable);
return NULL;
}
}
PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable); PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
Py_DECREF(awaitable); Py_DECREF(awaitable);
if (result != NULL) { if (result != NULL) {
@ -371,6 +381,70 @@ anextawaitable_iternext(anextawaitableobject *obj)
return NULL; return NULL;
} }
static PyObject *
anextawaitable_proxy(anextawaitableobject *obj, char *meth, PyObject *arg) {
PyObject *awaitable = anextawaitable_getiter(obj);
if (awaitable == NULL) {
return NULL;
}
PyObject *ret = PyObject_CallMethod(awaitable, meth, "O", arg);
Py_DECREF(awaitable);
if (ret != NULL) {
return ret;
}
if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration)) {
/* `anextawaitableobject` is only used by `anext()` when
* a default value is provided. So when we have a StopAsyncIteration
* exception we replace it with a `StopIteration(default)`, as if
* it was the return value of `__anext__()` coroutine.
*/
_PyGen_SetStopIterationValue(obj->default_value);
}
return NULL;
}
static PyObject *
anextawaitable_send(anextawaitableobject *obj, PyObject *arg) {
return anextawaitable_proxy(obj, "send", arg);
}
static PyObject *
anextawaitable_throw(anextawaitableobject *obj, PyObject *arg) {
return anextawaitable_proxy(obj, "throw", arg);
}
static PyObject *
anextawaitable_close(anextawaitableobject *obj, PyObject *arg) {
return anextawaitable_proxy(obj, "close", arg);
}
PyDoc_STRVAR(send_doc,
"send(arg) -> send 'arg' into the wrapped iterator,\n\
return next yielded value or raise StopIteration.");
PyDoc_STRVAR(throw_doc,
"throw(typ[,val[,tb]]) -> raise exception in the wrapped iterator,\n\
return next yielded value or raise StopIteration.");
PyDoc_STRVAR(close_doc,
"close() -> raise GeneratorExit inside generator.");
static PyMethodDef anextawaitable_methods[] = {
{"send",(PyCFunction)anextawaitable_send, METH_O, send_doc},
{"throw",(PyCFunction)anextawaitable_throw, METH_VARARGS, throw_doc},
{"close",(PyCFunction)anextawaitable_close, METH_VARARGS, close_doc},
{NULL, NULL} /* Sentinel */
};
static PyAsyncMethods anextawaitable_as_async = { static PyAsyncMethods anextawaitable_as_async = {
PyObject_SelfIter, /* am_await */ PyObject_SelfIter, /* am_await */
0, /* am_aiter */ 0, /* am_aiter */
@ -407,7 +481,7 @@ PyTypeObject _PyAnextAwaitable_Type = {
0, /* tp_weaklistoffset */ 0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */ PyObject_SelfIter, /* tp_iter */
(unaryfunc)anextawaitable_iternext, /* tp_iternext */ (unaryfunc)anextawaitable_iternext, /* tp_iternext */
0, /* tp_methods */ anextawaitable_methods, /* tp_methods */
}; };
PyObject * PyObject *