Issue #23996: Added _PyGen_SetStopIterationValue for safe raising

StopIteration with value. More safely handle non-normalized exceptions
in -_PyGen_FetchStopIterationValue.
This commit is contained in:
Serhiy Storchaka 2016-11-06 18:47:03 +02:00
commit 60e49aa756
7 changed files with 281 additions and 68 deletions

View File

@ -41,6 +41,7 @@ PyAPI_FUNC(PyObject *) PyGen_New(struct _frame *);
PyAPI_FUNC(PyObject *) PyGen_NewWithQualName(struct _frame *, PyAPI_FUNC(PyObject *) PyGen_NewWithQualName(struct _frame *,
PyObject *name, PyObject *qualname); PyObject *name, PyObject *qualname);
PyAPI_FUNC(int) PyGen_NeedsFinalizing(PyGenObject *); PyAPI_FUNC(int) PyGen_NeedsFinalizing(PyGenObject *);
PyAPI_FUNC(int) _PyGen_SetStopIterationValue(PyObject *);
PyAPI_FUNC(int) _PyGen_FetchStopIterationValue(PyObject **); PyAPI_FUNC(int) _PyGen_FetchStopIterationValue(PyObject **);
PyAPI_FUNC(PyObject *) _PyGen_Send(PyGenObject *, PyObject *); PyAPI_FUNC(PyObject *) _PyGen_Send(PyGenObject *, PyObject *);
PyObject *_PyGen_yf(PyGenObject *); PyObject *_PyGen_yf(PyGenObject *);

View File

@ -450,6 +450,48 @@ class AsyncGenAsyncioTest(unittest.TestCase):
self.loop.run_until_complete(run()) self.loop.run_until_complete(run())
def test_async_gen_asyncio_anext_tuple(self):
async def foo():
try:
yield (1,)
except ZeroDivisionError:
yield (2,)
async def run():
it = foo().__aiter__()
self.assertEqual(await it.__anext__(), (1,))
with self.assertRaises(StopIteration) as cm:
it.__anext__().throw(ZeroDivisionError)
self.assertEqual(cm.exception.args[0], (2,))
with self.assertRaises(StopAsyncIteration):
await it.__anext__()
self.loop.run_until_complete(run())
def test_async_gen_asyncio_anext_stopiteration(self):
async def foo():
try:
yield StopIteration(1)
except ZeroDivisionError:
yield StopIteration(3)
async def run():
it = foo().__aiter__()
v = await it.__anext__()
self.assertIsInstance(v, StopIteration)
self.assertEqual(v.value, 1)
with self.assertRaises(StopIteration) as cm:
it.__anext__().throw(ZeroDivisionError)
v = cm.exception.args[0]
self.assertIsInstance(v, StopIteration)
self.assertEqual(v.value, 3)
with self.assertRaises(StopAsyncIteration):
await it.__anext__()
self.loop.run_until_complete(run())
def test_async_gen_asyncio_aclose_06(self): def test_async_gen_asyncio_aclose_06(self):
async def foo(): async def foo():
try: try:
@ -759,6 +801,43 @@ class AsyncGenAsyncioTest(unittest.TestCase):
self.loop.run_until_complete(run()) self.loop.run_until_complete(run())
self.assertEqual(DONE, 1) self.assertEqual(DONE, 1)
def test_async_gen_asyncio_athrow_tuple(self):
async def gen():
try:
yield 1
except ZeroDivisionError:
yield (2,)
async def run():
g = gen()
v = await g.asend(None)
self.assertEqual(v, 1)
v = await g.athrow(ZeroDivisionError)
self.assertEqual(v, (2,))
with self.assertRaises(StopAsyncIteration):
await g.asend(None)
self.loop.run_until_complete(run())
def test_async_gen_asyncio_athrow_stopiteration(self):
async def gen():
try:
yield 1
except ZeroDivisionError:
yield StopIteration(2)
async def run():
g = gen()
v = await g.asend(None)
self.assertEqual(v, 1)
v = await g.athrow(ZeroDivisionError)
self.assertIsInstance(v, StopIteration)
self.assertEqual(v.value, 2)
with self.assertRaises(StopAsyncIteration):
await g.asend(None)
self.loop.run_until_complete(run())
def test_async_gen_asyncio_shutdown_01(self): def test_async_gen_asyncio_shutdown_01(self):
finalized = 0 finalized = 0

View File

@ -838,6 +838,21 @@ class CoroutineTest(unittest.TestCase):
coro.close() coro.close()
self.assertEqual(CHK, 1) self.assertEqual(CHK, 1)
def test_coro_wrapper_send_tuple(self):
async def foo():
return (10,)
result = run_async__await__(foo())
self.assertEqual(result, ([], (10,)))
def test_coro_wrapper_send_stop_iterator(self):
async def foo():
return StopIteration(10)
result = run_async__await__(foo())
self.assertIsInstance(result[1], StopIteration)
self.assertEqual(result[1].value, 10)
def test_cr_await(self): def test_cr_await(self):
@types.coroutine @types.coroutine
def a(): def a():
@ -1665,6 +1680,52 @@ class CoroutineTest(unittest.TestCase):
warnings.simplefilter("error") warnings.simplefilter("error")
run_async(foo()) run_async(foo())
def test_for_tuple(self):
class Done(Exception): pass
class AIter(tuple):
i = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.i >= len(self):
raise StopAsyncIteration
self.i += 1
return self[self.i - 1]
result = []
async def foo():
async for i in AIter([42]):
result.append(i)
raise Done
with self.assertRaises(Done):
foo().send(None)
self.assertEqual(result, [42])
def test_for_stop_iteration(self):
class Done(Exception): pass
class AIter(StopIteration):
i = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.i:
raise StopAsyncIteration
self.i += 1
return self.value
result = []
async def foo():
async for i in AIter(42):
result.append(i)
raise Done
with self.assertRaises(Done):
foo().send(None)
self.assertEqual(result, [42])
def test_comp_1(self): def test_comp_1(self):
async def f(i): async def f(i):
return i return i

View File

@ -277,6 +277,27 @@ class ExceptionTest(unittest.TestCase):
# hence no warning. # hence no warning.
next(g) next(g)
def test_return_tuple(self):
def g():
return (yield 1)
gen = g()
self.assertEqual(next(gen), 1)
with self.assertRaises(StopIteration) as cm:
gen.send((2,))
self.assertEqual(cm.exception.value, (2,))
def test_return_stopiteration(self):
def g():
return (yield 1)
gen = g()
self.assertEqual(next(gen), 1)
with self.assertRaises(StopIteration) as cm:
gen.send(StopIteration(2))
self.assertIsInstance(cm.exception.value, StopIteration)
self.assertEqual(cm.exception.value.value, 2)
class YieldFromTests(unittest.TestCase): class YieldFromTests(unittest.TestCase):
def test_generator_gi_yieldfrom(self): def test_generator_gi_yieldfrom(self):

View File

@ -384,9 +384,10 @@ class TestPEP380Operation(unittest.TestCase):
trace.append("Starting g1") trace.append("Starting g1")
yield "g1 ham" yield "g1 ham"
ret = yield from g2() ret = yield from g2()
trace.append("g2 returned %s" % (ret,)) trace.append("g2 returned %r" % (ret,))
ret = yield from g2(42) for v in 1, (2,), StopIteration(3):
trace.append("g2 returned %s" % (ret,)) ret = yield from g2(v)
trace.append("g2 returned %r" % (ret,))
yield "g1 eggs" yield "g1 eggs"
trace.append("Finishing g1") trace.append("Finishing g1")
def g2(v = None): def g2(v = None):
@ -410,7 +411,17 @@ class TestPEP380Operation(unittest.TestCase):
"Yielded g2 spam", "Yielded g2 spam",
"Yielded g2 more spam", "Yielded g2 more spam",
"Finishing g2", "Finishing g2",
"g2 returned 42", "g2 returned 1",
"Starting g2",
"Yielded g2 spam",
"Yielded g2 more spam",
"Finishing g2",
"g2 returned (2,)",
"Starting g2",
"Yielded g2 spam",
"Yielded g2 more spam",
"Finishing g2",
"g2 returned StopIteration(3,)",
"Yielded g1 eggs", "Yielded g1 eggs",
"Finishing g1", "Finishing g1",
]) ])
@ -670,14 +681,16 @@ class TestPEP380Operation(unittest.TestCase):
next(gi) next(gi)
trace.append("f SHOULD NOT BE HERE") trace.append("f SHOULD NOT BE HERE")
except StopIteration as e: except StopIteration as e:
trace.append("f caught %s" % (repr(e),)) trace.append("f caught %r" % (e,))
def g(r): def g(r):
trace.append("g starting") trace.append("g starting")
yield yield
trace.append("g returning %s" % (r,)) trace.append("g returning %r" % (r,))
return r return r
f(None) f(None)
f(42) f(1)
f((2,))
f(StopIteration(3))
self.assertEqual(trace,[ self.assertEqual(trace,[
"g starting", "g starting",
"f resuming g", "f resuming g",
@ -685,8 +698,16 @@ class TestPEP380Operation(unittest.TestCase):
"f caught StopIteration()", "f caught StopIteration()",
"g starting", "g starting",
"f resuming g", "f resuming g",
"g returning 42", "g returning 1",
"f caught StopIteration(42,)", "f caught StopIteration(1,)",
"g starting",
"f resuming g",
"g returning (2,)",
"f caught StopIteration((2,),)",
"g starting",
"f resuming g",
"g returning StopIteration(3,)",
"f caught StopIteration(StopIteration(3,),)",
]) ])
def test_send_and_return_with_value(self): def test_send_and_return_with_value(self):
@ -706,22 +727,34 @@ class TestPEP380Operation(unittest.TestCase):
def g(r): def g(r):
trace.append("g starting") trace.append("g starting")
x = yield x = yield
trace.append("g received %s" % (x,)) trace.append("g received %r" % (x,))
trace.append("g returning %s" % (r,)) trace.append("g returning %r" % (r,))
return r return r
f(None) f(None)
f(42) f(1)
self.assertEqual(trace,[ f((2,))
f(StopIteration(3))
self.assertEqual(trace, [
"g starting", "g starting",
"f sending spam to g", "f sending spam to g",
"g received spam", "g received 'spam'",
"g returning None", "g returning None",
"f caught StopIteration()", "f caught StopIteration()",
"g starting", "g starting",
"f sending spam to g", "f sending spam to g",
"g received spam", "g received 'spam'",
"g returning 42", "g returning 1",
"f caught StopIteration(42,)", 'f caught StopIteration(1,)',
'g starting',
'f sending spam to g',
"g received 'spam'",
'g returning (2,)',
'f caught StopIteration((2,),)',
'g starting',
'f sending spam to g',
"g received 'spam'",
'g returning StopIteration(3,)',
'f caught StopIteration(StopIteration(3,),)'
]) ])
def test_catching_exception_from_subgen_and_returning(self): def test_catching_exception_from_subgen_and_returning(self):
@ -729,27 +762,29 @@ class TestPEP380Operation(unittest.TestCase):
Test catching an exception thrown into a Test catching an exception thrown into a
subgenerator and returning a value subgenerator and returning a value
""" """
trace = []
def inner(): def inner():
try: try:
yield 1 yield 1
except ValueError: except ValueError:
trace.append("inner caught ValueError") trace.append("inner caught ValueError")
return 2 return value
def outer(): def outer():
v = yield from inner() v = yield from inner()
trace.append("inner returned %r to outer" % v) trace.append("inner returned %r to outer" % (v,))
yield v yield v
g = outer()
trace.append(next(g)) for value in 2, (2,), StopIteration(2):
trace.append(g.throw(ValueError)) trace = []
self.assertEqual(trace,[ g = outer()
1, trace.append(next(g))
"inner caught ValueError", trace.append(repr(g.throw(ValueError)))
"inner returned 2 to outer", self.assertEqual(trace, [
2, 1,
]) "inner caught ValueError",
"inner returned %r to outer" % (value,),
repr(value),
])
def test_throwing_GeneratorExit_into_subgen_that_returns(self): def test_throwing_GeneratorExit_into_subgen_that_returns(self):
""" """

View File

@ -997,26 +997,12 @@ FutureIter_iternext(futureiterobject *it)
res = _asyncio_Future_result_impl(fut); res = _asyncio_Future_result_impl(fut);
if (res != NULL) { if (res != NULL) {
/* The result of the Future is not an exception. /* The result of the Future is not an exception. */
if (_PyGen_SetStopIterationValue(res) < 0) {
We construct an exception instance manually with Py_DECREF(res);
PyObject_CallFunctionObjArgs and pass it to PyErr_SetObject
(similarly to what genobject.c does).
We do this to handle a situation when "res" is a tuple, in which
case PyErr_SetObject would set the value of StopIteration to
the first element of the tuple.
(See PyErr_SetObject/_PyErr_CreateException code for details.)
*/
PyObject *e = PyObject_CallFunctionObjArgs(
PyExc_StopIteration, res, NULL);
Py_DECREF(res);
if (e == NULL) {
return NULL; return NULL;
} }
PyErr_SetObject(PyExc_StopIteration, e); Py_DECREF(res);
Py_DECREF(e);
} }
it->future = NULL; it->future = NULL;

View File

@ -208,16 +208,9 @@ gen_send_ex(PyGenObject *gen, PyObject *arg, int exc, int closing)
} }
} }
else { else {
PyObject *e = PyObject_CallFunctionObjArgs(
PyExc_StopIteration, result, NULL);
/* Async generators cannot return anything but None */ /* Async generators cannot return anything but None */
assert(!PyAsyncGen_CheckExact(gen)); assert(!PyAsyncGen_CheckExact(gen));
_PyGen_SetStopIterationValue(result);
if (e != NULL) {
PyErr_SetObject(PyExc_StopIteration, e);
Py_DECREF(e);
}
} }
Py_CLEAR(result); Py_CLEAR(result);
} }
@ -561,6 +554,43 @@ gen_iternext(PyGenObject *gen)
return gen_send_ex(gen, NULL, 0, 0); return gen_send_ex(gen, NULL, 0, 0);
} }
/*
* Set StopIteration with specified value. Value can be arbitrary object
* or NULL.
*
* Returns 0 if StopIteration is set and -1 if any other exception is set.
*/
int
_PyGen_SetStopIterationValue(PyObject *value)
{
PyObject *e;
if (value == NULL ||
(!PyTuple_Check(value) &&
!PyObject_TypeCheck(value, (PyTypeObject *) PyExc_StopIteration)))
{
/* Delay exception instantiation if we can */
PyErr_SetObject(PyExc_StopIteration, value);
return 0;
}
/* Construct an exception instance manually with
* PyObject_CallFunctionObjArgs and pass it to PyErr_SetObject.
*
* We do this to handle a situation when "value" is a tuple, in which
* case PyErr_SetObject would set the value of StopIteration to
* the first element of the tuple.
*
* (See PyErr_SetObject/_PyErr_CreateException code for details.)
*/
e = PyObject_CallFunctionObjArgs(PyExc_StopIteration, value, NULL);
if (e == NULL) {
return -1;
}
PyErr_SetObject(PyExc_StopIteration, e);
Py_DECREF(e);
return 0;
}
/* /*
* If StopIteration exception is set, fetches its 'value' * If StopIteration exception is set, fetches its 'value'
* attribute if any, otherwise sets pvalue to None. * attribute if any, otherwise sets pvalue to None.
@ -571,7 +601,8 @@ gen_iternext(PyGenObject *gen)
*/ */
int int
_PyGen_FetchStopIterationValue(PyObject **pvalue) { _PyGen_FetchStopIterationValue(PyObject **pvalue)
{
PyObject *et, *ev, *tb; PyObject *et, *ev, *tb;
PyObject *value = NULL; PyObject *value = NULL;
@ -583,8 +614,15 @@ _PyGen_FetchStopIterationValue(PyObject **pvalue) {
value = ((PyStopIterationObject *)ev)->value; value = ((PyStopIterationObject *)ev)->value;
Py_INCREF(value); Py_INCREF(value);
Py_DECREF(ev); Py_DECREF(ev);
} else if (et == PyExc_StopIteration) { } else if (et == PyExc_StopIteration && !PyTuple_Check(ev)) {
/* avoid normalisation and take ev as value */ /* Avoid normalisation and take ev as value.
*
* Normalization is required if the value is a tuple, in
* that case the value of StopIteration would be set to
* the first element of the tuple.
*
* (See _PyErr_CreateException code for details.)
*/
value = ev; value = ev;
} else { } else {
/* normalisation required */ /* normalisation required */
@ -1106,7 +1144,7 @@ typedef struct {
static PyObject * static PyObject *
aiter_wrapper_iternext(PyAIterWrapper *aw) aiter_wrapper_iternext(PyAIterWrapper *aw)
{ {
PyErr_SetObject(PyExc_StopIteration, aw->ags_aiter); _PyGen_SetStopIterationValue(aw->ags_aiter);
return NULL; return NULL;
} }
@ -1504,16 +1542,8 @@ async_gen_unwrap_value(PyAsyncGenObject *gen, PyObject *result)
if (_PyAsyncGenWrappedValue_CheckExact(result)) { if (_PyAsyncGenWrappedValue_CheckExact(result)) {
/* async yield */ /* async yield */
PyObject *e = PyObject_CallFunctionObjArgs( _PyGen_SetStopIterationValue(((_PyAsyncGenWrappedValue*)result)->agw_val);
PyExc_StopIteration,
((_PyAsyncGenWrappedValue*)result)->agw_val,
NULL);
Py_DECREF(result); Py_DECREF(result);
if (e == NULL) {
return NULL;
}
PyErr_SetObject(PyExc_StopIteration, e);
Py_DECREF(e);
return NULL; return NULL;
} }