diff --git a/Include/genobject.h b/Include/genobject.h index 1ff32a8eafa..61e708a73ab 100644 --- a/Include/genobject.h +++ b/Include/genobject.h @@ -41,6 +41,7 @@ PyAPI_FUNC(PyObject *) PyGen_New(struct _frame *); PyAPI_FUNC(PyObject *) PyGen_NewWithQualName(struct _frame *, PyObject *name, PyObject *qualname); PyAPI_FUNC(int) PyGen_NeedsFinalizing(PyGenObject *); +PyAPI_FUNC(int) _PyGen_SetStopIterationValue(PyObject *); PyAPI_FUNC(int) _PyGen_FetchStopIterationValue(PyObject **); PyObject *_PyGen_Send(PyGenObject *, PyObject *); PyObject *_PyGen_yf(PyGenObject *); diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py index e52654c14d0..4a327b5ba9f 100644 --- a/Lib/test/test_coroutines.py +++ b/Lib/test/test_coroutines.py @@ -710,6 +710,21 @@ class CoroutineTest(unittest.TestCase): coro.close() self.assertEqual(CHK, 1) + def test_coro_wrapper_send_tuple(self): + async def foo(): + return (10,) + + result = run_async__await__(foo()) + self.assertEqual(result, ([], (10,))) + + def test_coro_wrapper_send_stop_iterator(self): + async def foo(): + return StopIteration(10) + + result = run_async__await__(foo()) + self.assertIsInstance(result[1], StopIteration) + self.assertEqual(result[1].value, 10) + def test_cr_await(self): @types.coroutine def a(): @@ -1537,6 +1552,52 @@ class CoroutineTest(unittest.TestCase): warnings.simplefilter("error") run_async(foo()) + def test_for_tuple(self): + class Done(Exception): pass + + class AIter(tuple): + i = 0 + def __aiter__(self): + return self + async def __anext__(self): + if self.i >= len(self): + raise StopAsyncIteration + self.i += 1 + return self[self.i - 1] + + result = [] + async def foo(): + async for i in AIter([42]): + result.append(i) + raise Done + + with self.assertRaises(Done): + foo().send(None) + self.assertEqual(result, [42]) + + def test_for_stop_iteration(self): + class Done(Exception): pass + + class AIter(StopIteration): + i = 0 + def __aiter__(self): + return self + async def __anext__(self): + if self.i: + raise StopAsyncIteration + self.i += 1 + return self.value + + result = [] + async def foo(): + async for i in AIter(42): + result.append(i) + raise Done + + with self.assertRaises(Done): + foo().send(None) + self.assertEqual(result, [42]) + def test_copy(self): async def func(): pass coro = func() diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py index 3f82462478d..cd6a43d753e 100644 --- a/Lib/test/test_generators.py +++ b/Lib/test/test_generators.py @@ -277,6 +277,27 @@ class ExceptionTest(unittest.TestCase): # hence no warning. next(g) + def test_return_tuple(self): + def g(): + return (yield 1) + + gen = g() + self.assertEqual(next(gen), 1) + with self.assertRaises(StopIteration) as cm: + gen.send((2,)) + self.assertEqual(cm.exception.value, (2,)) + + def test_return_stopiteration(self): + def g(): + return (yield 1) + + gen = g() + self.assertEqual(next(gen), 1) + with self.assertRaises(StopIteration) as cm: + gen.send(StopIteration(2)) + self.assertIsInstance(cm.exception.value, StopIteration) + self.assertEqual(cm.exception.value.value, 2) + class YieldFromTests(unittest.TestCase): def test_generator_gi_yieldfrom(self): diff --git a/Objects/genobject.c b/Objects/genobject.c index 9172e6a0102..0d5d54fdbaf 100644 --- a/Objects/genobject.c +++ b/Objects/genobject.c @@ -154,12 +154,7 @@ gen_send_ex(PyGenObject *gen, PyObject *arg, int exc, int closing) /* Delay exception instantiation if we can */ PyErr_SetNone(PyExc_StopIteration); } else { - PyObject *e = PyObject_CallFunctionObjArgs( - PyExc_StopIteration, result, NULL); - if (e != NULL) { - PyErr_SetObject(PyExc_StopIteration, e); - Py_DECREF(e); - } + _PyGen_SetStopIterationValue(result); } Py_CLEAR(result); } @@ -459,6 +454,43 @@ gen_iternext(PyGenObject *gen) return gen_send_ex(gen, NULL, 0, 0); } +/* + * Set StopIteration with specified value. Value can be arbitrary object + * or NULL. + * + * Returns 0 if StopIteration is set and -1 if any other exception is set. + */ +int +_PyGen_SetStopIterationValue(PyObject *value) +{ + PyObject *e; + + if (value == NULL || + (!PyTuple_Check(value) && + !PyObject_TypeCheck(value, (PyTypeObject *) PyExc_StopIteration))) + { + /* Delay exception instantiation if we can */ + PyErr_SetObject(PyExc_StopIteration, value); + return 0; + } + /* Construct an exception instance manually with + * PyObject_CallFunctionObjArgs and pass it to PyErr_SetObject. + * + * We do this to handle a situation when "value" is a tuple, in which + * case PyErr_SetObject would set the value of StopIteration to + * the first element of the tuple. + * + * (See PyErr_SetObject/_PyErr_CreateException code for details.) + */ + e = PyObject_CallFunctionObjArgs(PyExc_StopIteration, value, NULL); + if (e == NULL) { + return -1; + } + PyErr_SetObject(PyExc_StopIteration, e); + Py_DECREF(e); + return 0; +} + /* * If StopIteration exception is set, fetches its 'value' * attribute if any, otherwise sets pvalue to None. @@ -469,7 +501,8 @@ gen_iternext(PyGenObject *gen) */ int -_PyGen_FetchStopIterationValue(PyObject **pvalue) { +_PyGen_FetchStopIterationValue(PyObject **pvalue) +{ PyObject *et, *ev, *tb; PyObject *value = NULL; @@ -481,8 +514,15 @@ _PyGen_FetchStopIterationValue(PyObject **pvalue) { value = ((PyStopIterationObject *)ev)->value; Py_INCREF(value); Py_DECREF(ev); - } else if (et == PyExc_StopIteration) { - /* avoid normalisation and take ev as value */ + } else if (et == PyExc_StopIteration && !PyTuple_Check(ev)) { + /* Avoid normalisation and take ev as value. + * + * Normalization is required if the value is a tuple, in + * that case the value of StopIteration would be set to + * the first element of the tuple. + * + * (See _PyErr_CreateException code for details.) + */ value = ev; } else { /* normalisation required */ @@ -1012,7 +1052,7 @@ typedef struct { static PyObject * aiter_wrapper_iternext(PyAIterWrapper *aw) { - PyErr_SetObject(PyExc_StopIteration, aw->aw_aiter); + _PyGen_SetStopIterationValue(aw->aw_aiter); return NULL; }