mirror of https://github.com/python/cpython
bpo-44963: Implement send() and throw() methods for anext_awaitable objects (GH-27955)
Co-authored-by: Yury Selivanov <yury@edgedb.com>
This commit is contained in:
parent
4f88161f07
commit
533e725821
|
@ -1,12 +1,16 @@
|
|||
import inspect
|
||||
import types
|
||||
import unittest
|
||||
import contextlib
|
||||
|
||||
from test.support.import_helper import import_module
|
||||
from test.support import gc_collect
|
||||
asyncio = import_module("asyncio")
|
||||
|
||||
|
||||
_no_default = object()
|
||||
|
||||
|
||||
class AwaitException(Exception):
|
||||
pass
|
||||
|
||||
|
@ -45,6 +49,37 @@ def to_list(gen):
|
|||
return run_until_complete(iterate())
|
||||
|
||||
|
||||
def py_anext(iterator, default=_no_default):
|
||||
"""Pure-Python implementation of anext() for testing purposes.
|
||||
|
||||
Closely matches the builtin anext() C implementation.
|
||||
Can be used to compare the built-in implementation of the inner
|
||||
coroutines machinery to C-implementation of __anext__() and send()
|
||||
or throw() on the returned generator.
|
||||
"""
|
||||
|
||||
try:
|
||||
__anext__ = type(iterator).__anext__
|
||||
except AttributeError:
|
||||
raise TypeError(f'{iterator!r} is not an async iterator')
|
||||
|
||||
if default is _no_default:
|
||||
return __anext__(iterator)
|
||||
|
||||
async def anext_impl():
|
||||
try:
|
||||
# The C code is way more low-level than this, as it implements
|
||||
# all methods of the iterator protocol. In this implementation
|
||||
# we're relying on higher-level coroutine concepts, but that's
|
||||
# exactly what we want -- crosstest pure-Python high-level
|
||||
# implementation and low-level C anext() iterators.
|
||||
return await __anext__(iterator)
|
||||
except StopAsyncIteration:
|
||||
return default
|
||||
|
||||
return anext_impl()
|
||||
|
||||
|
||||
class AsyncGenSyntaxTest(unittest.TestCase):
|
||||
|
||||
def test_async_gen_syntax_01(self):
|
||||
|
@ -374,6 +409,12 @@ class AsyncGenAsyncioTest(unittest.TestCase):
|
|||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
def check_async_iterator_anext(self, ait_class):
|
||||
with self.subTest(anext="pure-Python"):
|
||||
self._check_async_iterator_anext(ait_class, py_anext)
|
||||
with self.subTest(anext="builtin"):
|
||||
self._check_async_iterator_anext(ait_class, anext)
|
||||
|
||||
def _check_async_iterator_anext(self, ait_class, anext):
|
||||
g = ait_class()
|
||||
async def consume():
|
||||
results = []
|
||||
|
@ -406,6 +447,24 @@ class AsyncGenAsyncioTest(unittest.TestCase):
|
|||
result = self.loop.run_until_complete(test_2())
|
||||
self.assertEqual(result, "completed")
|
||||
|
||||
def test_send():
|
||||
p = ait_class()
|
||||
obj = anext(p, "completed")
|
||||
with self.assertRaises(StopIteration):
|
||||
with contextlib.closing(obj.__await__()) as g:
|
||||
g.send(None)
|
||||
|
||||
test_send()
|
||||
|
||||
async def test_throw():
|
||||
p = ait_class()
|
||||
obj = anext(p, "completed")
|
||||
self.assertRaises(SyntaxError, obj.throw, SyntaxError)
|
||||
return "completed"
|
||||
|
||||
result = self.loop.run_until_complete(test_throw())
|
||||
self.assertEqual(result, "completed")
|
||||
|
||||
def test_async_generator_anext(self):
|
||||
async def agen():
|
||||
yield 1
|
||||
|
@ -569,6 +628,119 @@ class AsyncGenAsyncioTest(unittest.TestCase):
|
|||
result = self.loop.run_until_complete(do_test())
|
||||
self.assertEqual(result, "completed")
|
||||
|
||||
def test_anext_iter(self):
|
||||
@types.coroutine
|
||||
def _async_yield(v):
|
||||
return (yield v)
|
||||
|
||||
class MyError(Exception):
|
||||
pass
|
||||
|
||||
async def agenfn():
|
||||
try:
|
||||
await _async_yield(1)
|
||||
except MyError:
|
||||
await _async_yield(2)
|
||||
return
|
||||
yield
|
||||
|
||||
def test1(anext):
|
||||
agen = agenfn()
|
||||
with contextlib.closing(anext(agen, "default").__await__()) as g:
|
||||
self.assertEqual(g.send(None), 1)
|
||||
self.assertEqual(g.throw(MyError, MyError(), None), 2)
|
||||
try:
|
||||
g.send(None)
|
||||
except StopIteration as e:
|
||||
err = e
|
||||
else:
|
||||
self.fail('StopIteration was not raised')
|
||||
self.assertEqual(err.value, "default")
|
||||
|
||||
def test2(anext):
|
||||
agen = agenfn()
|
||||
with contextlib.closing(anext(agen, "default").__await__()) as g:
|
||||
self.assertEqual(g.send(None), 1)
|
||||
self.assertEqual(g.throw(MyError, MyError(), None), 2)
|
||||
with self.assertRaises(MyError):
|
||||
g.throw(MyError, MyError(), None)
|
||||
|
||||
def test3(anext):
|
||||
agen = agenfn()
|
||||
with contextlib.closing(anext(agen, "default").__await__()) as g:
|
||||
self.assertEqual(g.send(None), 1)
|
||||
g.close()
|
||||
with self.assertRaisesRegex(RuntimeError, 'cannot reuse'):
|
||||
self.assertEqual(g.send(None), 1)
|
||||
|
||||
def test4(anext):
|
||||
@types.coroutine
|
||||
def _async_yield(v):
|
||||
yield v * 10
|
||||
return (yield (v * 10 + 1))
|
||||
|
||||
async def agenfn():
|
||||
try:
|
||||
await _async_yield(1)
|
||||
except MyError:
|
||||
await _async_yield(2)
|
||||
return
|
||||
yield
|
||||
|
||||
agen = agenfn()
|
||||
with contextlib.closing(anext(agen, "default").__await__()) as g:
|
||||
self.assertEqual(g.send(None), 10)
|
||||
self.assertEqual(g.throw(MyError, MyError(), None), 20)
|
||||
with self.assertRaisesRegex(MyError, 'val'):
|
||||
g.throw(MyError, MyError('val'), None)
|
||||
|
||||
def test5(anext):
|
||||
@types.coroutine
|
||||
def _async_yield(v):
|
||||
yield v * 10
|
||||
return (yield (v * 10 + 1))
|
||||
|
||||
async def agenfn():
|
||||
try:
|
||||
await _async_yield(1)
|
||||
except MyError:
|
||||
return
|
||||
yield 'aaa'
|
||||
|
||||
agen = agenfn()
|
||||
with contextlib.closing(anext(agen, "default").__await__()) as g:
|
||||
self.assertEqual(g.send(None), 10)
|
||||
with self.assertRaisesRegex(StopIteration, 'default'):
|
||||
g.throw(MyError, MyError(), None)
|
||||
|
||||
def test6(anext):
|
||||
@types.coroutine
|
||||
def _async_yield(v):
|
||||
yield v * 10
|
||||
return (yield (v * 10 + 1))
|
||||
|
||||
async def agenfn():
|
||||
await _async_yield(1)
|
||||
yield 'aaa'
|
||||
|
||||
agen = agenfn()
|
||||
with contextlib.closing(anext(agen, "default").__await__()) as g:
|
||||
with self.assertRaises(MyError):
|
||||
g.throw(MyError, MyError(), None)
|
||||
|
||||
def run_test(test):
|
||||
with self.subTest('pure-Python anext()'):
|
||||
test(py_anext)
|
||||
with self.subTest('builtin anext()'):
|
||||
test(anext)
|
||||
|
||||
run_test(test1)
|
||||
run_test(test2)
|
||||
run_test(test3)
|
||||
run_test(test4)
|
||||
run_test(test5)
|
||||
run_test(test6)
|
||||
|
||||
def test_aiter_bad_args(self):
|
||||
async def gen():
|
||||
yield 1
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Implement ``send()`` and ``throw()`` methods for ``anext_awaitable``
|
||||
objects. Patch by Pablo Galindo.
|
|
@ -313,6 +313,36 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
|
|||
return 0;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
anextawaitable_getiter(anextawaitableobject *obj)
|
||||
{
|
||||
assert(obj->wrapped != NULL);
|
||||
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
|
||||
if (awaitable == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
|
||||
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
|
||||
* or an iterator. Of these, only coroutines lack tp_iternext.
|
||||
*/
|
||||
assert(PyCoro_CheckExact(awaitable));
|
||||
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
|
||||
PyObject *new_awaitable = getter(awaitable);
|
||||
if (new_awaitable == NULL) {
|
||||
Py_DECREF(awaitable);
|
||||
return NULL;
|
||||
}
|
||||
Py_SETREF(awaitable, new_awaitable);
|
||||
if (!PyIter_Check(awaitable)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"__await__ returned a non-iterable");
|
||||
Py_DECREF(awaitable);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
return awaitable;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
anextawaitable_iternext(anextawaitableobject *obj)
|
||||
{
|
||||
|
@ -336,30 +366,10 @@ anextawaitable_iternext(anextawaitableobject *obj)
|
|||
* Then `await anext(gen)` can just call
|
||||
* gen.__anext__().__next__()
|
||||
*/
|
||||
assert(obj->wrapped != NULL);
|
||||
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
|
||||
PyObject *awaitable = anextawaitable_getiter(obj);
|
||||
if (awaitable == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
|
||||
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
|
||||
* or an iterator. Of these, only coroutines lack tp_iternext.
|
||||
*/
|
||||
assert(PyCoro_CheckExact(awaitable));
|
||||
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
|
||||
PyObject *new_awaitable = getter(awaitable);
|
||||
if (new_awaitable == NULL) {
|
||||
Py_DECREF(awaitable);
|
||||
return NULL;
|
||||
}
|
||||
Py_SETREF(awaitable, new_awaitable);
|
||||
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
"__await__ returned a non-iterable");
|
||||
Py_DECREF(awaitable);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
|
||||
Py_DECREF(awaitable);
|
||||
if (result != NULL) {
|
||||
|
@ -371,6 +381,70 @@ anextawaitable_iternext(anextawaitableobject *obj)
|
|||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
anextawaitable_proxy(anextawaitableobject *obj, char *meth, PyObject *arg) {
|
||||
PyObject *awaitable = anextawaitable_getiter(obj);
|
||||
if (awaitable == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
PyObject *ret = PyObject_CallMethod(awaitable, meth, "O", arg);
|
||||
Py_DECREF(awaitable);
|
||||
if (ret != NULL) {
|
||||
return ret;
|
||||
}
|
||||
if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration)) {
|
||||
/* `anextawaitableobject` is only used by `anext()` when
|
||||
* a default value is provided. So when we have a StopAsyncIteration
|
||||
* exception we replace it with a `StopIteration(default)`, as if
|
||||
* it was the return value of `__anext__()` coroutine.
|
||||
*/
|
||||
_PyGen_SetStopIterationValue(obj->default_value);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
anextawaitable_send(anextawaitableobject *obj, PyObject *arg) {
|
||||
return anextawaitable_proxy(obj, "send", arg);
|
||||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
anextawaitable_throw(anextawaitableobject *obj, PyObject *arg) {
|
||||
return anextawaitable_proxy(obj, "throw", arg);
|
||||
}
|
||||
|
||||
|
||||
static PyObject *
|
||||
anextawaitable_close(anextawaitableobject *obj, PyObject *arg) {
|
||||
return anextawaitable_proxy(obj, "close", arg);
|
||||
}
|
||||
|
||||
|
||||
PyDoc_STRVAR(send_doc,
|
||||
"send(arg) -> send 'arg' into the wrapped iterator,\n\
|
||||
return next yielded value or raise StopIteration.");
|
||||
|
||||
|
||||
PyDoc_STRVAR(throw_doc,
|
||||
"throw(typ[,val[,tb]]) -> raise exception in the wrapped iterator,\n\
|
||||
return next yielded value or raise StopIteration.");
|
||||
|
||||
|
||||
PyDoc_STRVAR(close_doc,
|
||||
"close() -> raise GeneratorExit inside generator.");
|
||||
|
||||
|
||||
static PyMethodDef anextawaitable_methods[] = {
|
||||
{"send",(PyCFunction)anextawaitable_send, METH_O, send_doc},
|
||||
{"throw",(PyCFunction)anextawaitable_throw, METH_VARARGS, throw_doc},
|
||||
{"close",(PyCFunction)anextawaitable_close, METH_VARARGS, close_doc},
|
||||
{NULL, NULL} /* Sentinel */
|
||||
};
|
||||
|
||||
|
||||
static PyAsyncMethods anextawaitable_as_async = {
|
||||
PyObject_SelfIter, /* am_await */
|
||||
0, /* am_aiter */
|
||||
|
@ -407,7 +481,7 @@ PyTypeObject _PyAnextAwaitable_Type = {
|
|||
0, /* tp_weaklistoffset */
|
||||
PyObject_SelfIter, /* tp_iter */
|
||||
(unaryfunc)anextawaitable_iternext, /* tp_iternext */
|
||||
0, /* tp_methods */
|
||||
anextawaitable_methods, /* tp_methods */
|
||||
};
|
||||
|
||||
PyObject *
|
||||
|
|
Loading…
Reference in New Issue