Issue #23996: Added _PyGen_SetStopIterationValue for safe raising

StopIteration with value. More safely handle non-normalized exceptions
in -_PyGen_FetchStopIterationValue.
This commit is contained in:
Serhiy Storchaka 2016-11-06 18:44:42 +02:00
parent 04b3d8b697
commit 24411f8a8d
4 changed files with 133 additions and 10 deletions

View File

@ -41,6 +41,7 @@ PyAPI_FUNC(PyObject *) PyGen_New(struct _frame *);
PyAPI_FUNC(PyObject *) PyGen_NewWithQualName(struct _frame *, PyAPI_FUNC(PyObject *) PyGen_NewWithQualName(struct _frame *,
PyObject *name, PyObject *qualname); PyObject *name, PyObject *qualname);
PyAPI_FUNC(int) PyGen_NeedsFinalizing(PyGenObject *); PyAPI_FUNC(int) PyGen_NeedsFinalizing(PyGenObject *);
PyAPI_FUNC(int) _PyGen_SetStopIterationValue(PyObject *);
PyAPI_FUNC(int) _PyGen_FetchStopIterationValue(PyObject **); PyAPI_FUNC(int) _PyGen_FetchStopIterationValue(PyObject **);
PyObject *_PyGen_Send(PyGenObject *, PyObject *); PyObject *_PyGen_Send(PyGenObject *, PyObject *);
PyObject *_PyGen_yf(PyGenObject *); PyObject *_PyGen_yf(PyGenObject *);

View File

@ -710,6 +710,21 @@ class CoroutineTest(unittest.TestCase):
coro.close() coro.close()
self.assertEqual(CHK, 1) self.assertEqual(CHK, 1)
def test_coro_wrapper_send_tuple(self):
async def foo():
return (10,)
result = run_async__await__(foo())
self.assertEqual(result, ([], (10,)))
def test_coro_wrapper_send_stop_iterator(self):
async def foo():
return StopIteration(10)
result = run_async__await__(foo())
self.assertIsInstance(result[1], StopIteration)
self.assertEqual(result[1].value, 10)
def test_cr_await(self): def test_cr_await(self):
@types.coroutine @types.coroutine
def a(): def a():
@ -1537,6 +1552,52 @@ class CoroutineTest(unittest.TestCase):
warnings.simplefilter("error") warnings.simplefilter("error")
run_async(foo()) run_async(foo())
def test_for_tuple(self):
class Done(Exception): pass
class AIter(tuple):
i = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.i >= len(self):
raise StopAsyncIteration
self.i += 1
return self[self.i - 1]
result = []
async def foo():
async for i in AIter([42]):
result.append(i)
raise Done
with self.assertRaises(Done):
foo().send(None)
self.assertEqual(result, [42])
def test_for_stop_iteration(self):
class Done(Exception): pass
class AIter(StopIteration):
i = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.i:
raise StopAsyncIteration
self.i += 1
return self.value
result = []
async def foo():
async for i in AIter(42):
result.append(i)
raise Done
with self.assertRaises(Done):
foo().send(None)
self.assertEqual(result, [42])
def test_copy(self): def test_copy(self):
async def func(): pass async def func(): pass
coro = func() coro = func()

View File

@ -277,6 +277,27 @@ class ExceptionTest(unittest.TestCase):
# hence no warning. # hence no warning.
next(g) next(g)
def test_return_tuple(self):
def g():
return (yield 1)
gen = g()
self.assertEqual(next(gen), 1)
with self.assertRaises(StopIteration) as cm:
gen.send((2,))
self.assertEqual(cm.exception.value, (2,))
def test_return_stopiteration(self):
def g():
return (yield 1)
gen = g()
self.assertEqual(next(gen), 1)
with self.assertRaises(StopIteration) as cm:
gen.send(StopIteration(2))
self.assertIsInstance(cm.exception.value, StopIteration)
self.assertEqual(cm.exception.value.value, 2)
class YieldFromTests(unittest.TestCase): class YieldFromTests(unittest.TestCase):
def test_generator_gi_yieldfrom(self): def test_generator_gi_yieldfrom(self):

View File

@ -154,12 +154,7 @@ gen_send_ex(PyGenObject *gen, PyObject *arg, int exc, int closing)
/* Delay exception instantiation if we can */ /* Delay exception instantiation if we can */
PyErr_SetNone(PyExc_StopIteration); PyErr_SetNone(PyExc_StopIteration);
} else { } else {
PyObject *e = PyObject_CallFunctionObjArgs( _PyGen_SetStopIterationValue(result);
PyExc_StopIteration, result, NULL);
if (e != NULL) {
PyErr_SetObject(PyExc_StopIteration, e);
Py_DECREF(e);
}
} }
Py_CLEAR(result); Py_CLEAR(result);
} }
@ -459,6 +454,43 @@ gen_iternext(PyGenObject *gen)
return gen_send_ex(gen, NULL, 0, 0); return gen_send_ex(gen, NULL, 0, 0);
} }
/*
* Set StopIteration with specified value. Value can be arbitrary object
* or NULL.
*
* Returns 0 if StopIteration is set and -1 if any other exception is set.
*/
int
_PyGen_SetStopIterationValue(PyObject *value)
{
PyObject *e;
if (value == NULL ||
(!PyTuple_Check(value) &&
!PyObject_TypeCheck(value, (PyTypeObject *) PyExc_StopIteration)))
{
/* Delay exception instantiation if we can */
PyErr_SetObject(PyExc_StopIteration, value);
return 0;
}
/* Construct an exception instance manually with
* PyObject_CallFunctionObjArgs and pass it to PyErr_SetObject.
*
* We do this to handle a situation when "value" is a tuple, in which
* case PyErr_SetObject would set the value of StopIteration to
* the first element of the tuple.
*
* (See PyErr_SetObject/_PyErr_CreateException code for details.)
*/
e = PyObject_CallFunctionObjArgs(PyExc_StopIteration, value, NULL);
if (e == NULL) {
return -1;
}
PyErr_SetObject(PyExc_StopIteration, e);
Py_DECREF(e);
return 0;
}
/* /*
* If StopIteration exception is set, fetches its 'value' * If StopIteration exception is set, fetches its 'value'
* attribute if any, otherwise sets pvalue to None. * attribute if any, otherwise sets pvalue to None.
@ -469,7 +501,8 @@ gen_iternext(PyGenObject *gen)
*/ */
int int
_PyGen_FetchStopIterationValue(PyObject **pvalue) { _PyGen_FetchStopIterationValue(PyObject **pvalue)
{
PyObject *et, *ev, *tb; PyObject *et, *ev, *tb;
PyObject *value = NULL; PyObject *value = NULL;
@ -481,8 +514,15 @@ _PyGen_FetchStopIterationValue(PyObject **pvalue) {
value = ((PyStopIterationObject *)ev)->value; value = ((PyStopIterationObject *)ev)->value;
Py_INCREF(value); Py_INCREF(value);
Py_DECREF(ev); Py_DECREF(ev);
} else if (et == PyExc_StopIteration) { } else if (et == PyExc_StopIteration && !PyTuple_Check(ev)) {
/* avoid normalisation and take ev as value */ /* Avoid normalisation and take ev as value.
*
* Normalization is required if the value is a tuple, in
* that case the value of StopIteration would be set to
* the first element of the tuple.
*
* (See _PyErr_CreateException code for details.)
*/
value = ev; value = ev;
} else { } else {
/* normalisation required */ /* normalisation required */
@ -1012,7 +1052,7 @@ typedef struct {
static PyObject * static PyObject *
aiter_wrapper_iternext(PyAIterWrapper *aw) aiter_wrapper_iternext(PyAIterWrapper *aw)
{ {
PyErr_SetObject(PyExc_StopIteration, aw->aw_aiter); _PyGen_SetStopIterationValue(aw->aw_aiter);
return NULL; return NULL;
} }