Issue #24325, #24400: Add more unittests for types.coroutine; tweak wrapper implementation.

This commit is contained in:
Yury Selivanov 2015-06-24 11:44:51 -04:00
parent 66f8828bfc
commit 00e3372358
3 changed files with 228 additions and 62 deletions

View File

@ -1,6 +1,7 @@
"""Tests support for new syntax introduced by PEP 492.""" """Tests support for new syntax introduced by PEP 492."""
import collections.abc import collections.abc
import types
import unittest import unittest
from test import support from test import support
@ -164,5 +165,23 @@ class CoroutineTests(BaseTest):
self.loop.run_until_complete(start()) self.loop.run_until_complete(start())
def test_types_coroutine(self):
def gen():
yield from ()
return 'spam'
@types.coroutine
def func():
return gen()
async def coro():
wrapper = func()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
return await wrapper
data = self.loop.run_until_complete(coro())
self.assertEqual(data, 'spam')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -7,7 +7,8 @@ import pickle
import locale import locale
import sys import sys
import types import types
import unittest import unittest.mock
import weakref
class TypesTests(unittest.TestCase): class TypesTests(unittest.TestCase):
@ -1191,23 +1192,27 @@ class SimpleNamespaceTests(unittest.TestCase):
class CoroutineTests(unittest.TestCase): class CoroutineTests(unittest.TestCase):
def test_wrong_args(self): def test_wrong_args(self):
class Foo:
def __call__(self):
pass
def bar(): pass
samples = [None, 1, object()] samples = [None, 1, object()]
for sample in samples: for sample in samples:
with self.assertRaisesRegex(TypeError, with self.assertRaisesRegex(TypeError,
'types.coroutine.*expects a callable'): 'types.coroutine.*expects a callable'):
types.coroutine(sample) types.coroutine(sample)
def test_wrong_func(self): def test_non_gen_values(self):
@types.coroutine @types.coroutine
def foo(): def foo():
return 'spam' return 'spam'
self.assertEqual(foo(), 'spam') self.assertEqual(foo(), 'spam')
class Awaitable:
def __await__(self):
return ()
aw = Awaitable()
@types.coroutine
def foo():
return aw
self.assertIs(aw, foo())
def test_async_def(self): def test_async_def(self):
# Test that types.coroutine passes 'async def' coroutines # Test that types.coroutine passes 'async def' coroutines
# without modification # without modification
@ -1263,24 +1268,157 @@ class CoroutineTests(unittest.TestCase):
def send(self): pass def send(self): pass
def throw(self): pass def throw(self): pass
def close(self): pass def close(self): pass
def __iter__(self): return self def __iter__(self): pass
def __next__(self): pass def __next__(self): pass
gen = GenLike() # Setup generator mock object
gen = unittest.mock.MagicMock(GenLike)
gen.__iter__ = lambda gen: gen
gen.__name__ = 'gen'
gen.__qualname__ = 'test.gen'
self.assertIsInstance(gen, collections.abc.Generator)
self.assertIs(gen, iter(gen))
@types.coroutine @types.coroutine
def foo(): def foo(): return gen
return gen
self.assertIs(foo().__await__(), gen) wrapper = foo()
self.assertTrue(isinstance(foo(), collections.abc.Coroutine)) self.assertIsInstance(wrapper, types._GeneratorWrapper)
with self.assertRaises(AttributeError): self.assertIs(wrapper.__await__(), wrapper)
foo().gi_code # Wrapper proxies duck generators completely:
self.assertIs(iter(wrapper), wrapper)
self.assertIsInstance(wrapper, collections.abc.Coroutine)
self.assertIsInstance(wrapper, collections.abc.Awaitable)
self.assertIs(wrapper.__qualname__, gen.__qualname__)
self.assertIs(wrapper.__name__, gen.__name__)
# Test AttributeErrors
for name in {'gi_running', 'gi_frame', 'gi_code',
'cr_running', 'cr_frame', 'cr_code'}:
with self.assertRaises(AttributeError):
getattr(wrapper, name)
# Test attributes pass-through
gen.gi_running = object()
gen.gi_frame = object()
gen.gi_code = object()
self.assertIs(wrapper.gi_running, gen.gi_running)
self.assertIs(wrapper.gi_frame, gen.gi_frame)
self.assertIs(wrapper.gi_code, gen.gi_code)
self.assertIs(wrapper.cr_running, gen.gi_running)
self.assertIs(wrapper.cr_frame, gen.gi_frame)
self.assertIs(wrapper.cr_code, gen.gi_code)
wrapper.close()
gen.close.assert_called_once_with()
wrapper.send(1)
gen.send.assert_called_once_with(1)
wrapper.throw(1, 2, 3)
gen.throw.assert_called_once_with(1, 2, 3)
gen.reset_mock()
wrapper.throw(1, 2)
gen.throw.assert_called_once_with(1, 2)
gen.reset_mock()
wrapper.throw(1)
gen.throw.assert_called_once_with(1)
gen.reset_mock()
# Test exceptions propagation
error = Exception()
gen.throw.side_effect = error
try:
wrapper.throw(1)
except Exception as ex:
self.assertIs(ex, error)
else:
self.fail('wrapper did not propagate an exception')
# Test invalid args
gen.reset_mock()
with self.assertRaises(TypeError):
wrapper.throw()
self.assertFalse(gen.throw.called)
with self.assertRaises(TypeError):
wrapper.close(1)
self.assertFalse(gen.close.called)
with self.assertRaises(TypeError):
wrapper.send()
self.assertFalse(gen.send.called)
# Test that we do not double wrap
@types.coroutine
def bar(): return wrapper
self.assertIs(wrapper, bar())
# Test weakrefs support
ref = weakref.ref(wrapper)
self.assertIs(ref(), wrapper)
def test_duck_functional_gen(self):
class Generator:
"""Emulates the following generator (very clumsy):
def gen(fut):
result = yield fut
return result * 2
"""
def __init__(self, fut):
self._i = 0
self._fut = fut
def __iter__(self):
return self
def __next__(self):
return self.send(None)
def send(self, v):
try:
if self._i == 0:
assert v is None
return self._fut
if self._i == 1:
raise StopIteration(v * 2)
if self._i > 1:
raise StopIteration
finally:
self._i += 1
def throw(self, tp, *exc):
self._i = 100
if tp is not GeneratorExit:
raise tp
def close(self):
self.throw(GeneratorExit)
@types.coroutine
def foo(): return Generator('spam')
wrapper = foo()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
async def corofunc():
return await foo() + 100
coro = corofunc()
self.assertEqual(coro.send(None), 'spam')
try:
coro.send(20)
except StopIteration as ex:
self.assertEqual(ex.args[0], 140)
else:
self.fail('StopIteration was expected')
def test_gen(self): def test_gen(self):
def gen(): yield def gen(): yield
gen = gen() gen = gen()
@types.coroutine @types.coroutine
def foo(): return gen def foo(): return gen
self.assertIs(foo().__await__(), gen) wrapper = foo()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
self.assertIs(wrapper.__await__(), gen)
for name in ('__name__', '__qualname__', 'gi_code', for name in ('__name__', '__qualname__', 'gi_code',
'gi_running', 'gi_frame'): 'gi_running', 'gi_frame'):
@ -1289,19 +1427,8 @@ class CoroutineTests(unittest.TestCase):
self.assertIs(foo().cr_code, gen.gi_code) self.assertIs(foo().cr_code, gen.gi_code)
def test_genfunc(self): def test_genfunc(self):
def gen(): def gen(): yield
yield self.assertIs(types.coroutine(gen), gen)
self.assertFalse(isinstance(gen(), collections.abc.Coroutine))
self.assertFalse(isinstance(gen(), collections.abc.Awaitable))
gen_code = gen.__code__
decorated_gen = types.coroutine(gen)
self.assertIs(decorated_gen, gen)
self.assertIsNot(decorated_gen.__code__, gen_code)
decorated_gen2 = types.coroutine(decorated_gen)
self.assertIs(decorated_gen2.__code__, decorated_gen.__code__)
self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE) self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE)
self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE) self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE)
@ -1309,10 +1436,27 @@ class CoroutineTests(unittest.TestCase):
g = gen() g = gen()
self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE)
self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE) self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE)
self.assertTrue(isinstance(g, collections.abc.Coroutine)) self.assertIsInstance(g, collections.abc.Coroutine)
self.assertTrue(isinstance(g, collections.abc.Awaitable)) self.assertIsInstance(g, collections.abc.Awaitable)
g.close() # silence warning g.close() # silence warning
self.assertIs(types.coroutine(gen), gen)
def test_wrapper_object(self):
def gen():
yield
@types.coroutine
def coro():
return gen()
wrapper = coro()
self.assertIn('GeneratorWrapper', repr(wrapper))
self.assertEqual(repr(wrapper), str(wrapper))
self.assertTrue(set(dir(wrapper)).issuperset({
'__await__', '__iter__', '__next__', 'cr_code', 'cr_running',
'cr_frame', 'gi_code', 'gi_frame', 'gi_running', 'send',
'close', 'throw'}))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -166,6 +166,39 @@ class DynamicClassAttribute:
import functools as _functools import functools as _functools
import collections.abc as _collections_abc import collections.abc as _collections_abc
class _GeneratorWrapper:
# TODO: Implement this in C.
def __init__(self, gen):
self.__wrapped__ = gen
self.__isgen__ = gen.__class__ is GeneratorType
self.__name__ = getattr(gen, '__name__', None)
self.__qualname__ = getattr(gen, '__qualname__', None)
def send(self, val):
return self.__wrapped__.send(val)
def throw(self, tp, *rest):
return self.__wrapped__.throw(tp, *rest)
def close(self):
return self.__wrapped__.close()
@property
def gi_code(self):
return self.__wrapped__.gi_code
@property
def gi_frame(self):
return self.__wrapped__.gi_frame
@property
def gi_running(self):
return self.__wrapped__.gi_running
cr_code = gi_code
cr_frame = gi_frame
cr_running = gi_running
def __next__(self):
return next(self.__wrapped__)
def __iter__(self):
if self.__isgen__:
return self.__wrapped__
return self
__await__ = __iter__
def coroutine(func): def coroutine(func):
"""Convert regular generator function to a coroutine.""" """Convert regular generator function to a coroutine."""
@ -201,36 +234,6 @@ def coroutine(func):
# return generator-like objects (for instance generators # return generator-like objects (for instance generators
# compiled with Cython). # compiled with Cython).
class GeneratorWrapper:
def __init__(self, gen):
self.__wrapped__ = gen
self.__name__ = getattr(gen, '__name__', None)
self.__qualname__ = getattr(gen, '__qualname__', None)
def send(self, val):
return self.__wrapped__.send(val)
def throw(self, *args):
return self.__wrapped__.throw(*args)
def close(self):
return self.__wrapped__.close()
@property
def gi_code(self):
return self.__wrapped__.gi_code
@property
def gi_frame(self):
return self.__wrapped__.gi_frame
@property
def gi_running(self):
return self.__wrapped__.gi_running
cr_code = gi_code
cr_frame = gi_frame
cr_running = gi_running
def __next__(self):
return next(self.__wrapped__)
def __iter__(self):
return self.__wrapped__
def __await__(self):
return self.__wrapped__
@_functools.wraps(func) @_functools.wraps(func)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
coro = func(*args, **kwargs) coro = func(*args, **kwargs)
@ -243,7 +246,7 @@ def coroutine(func):
# 'coro' is either a pure Python generator iterator, or it # 'coro' is either a pure Python generator iterator, or it
# implements collections.abc.Generator (and does not implement # implements collections.abc.Generator (and does not implement
# collections.abc.Coroutine). # collections.abc.Coroutine).
return GeneratorWrapper(coro) return _GeneratorWrapper(coro)
# 'coro' is either an instance of collections.abc.Coroutine or # 'coro' is either an instance of collections.abc.Coroutine or
# some other object -- pass it through. # some other object -- pass it through.
return coro return coro