From 00e337235855999bfa5339a7d87322b9e0f07148 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 24 Jun 2015 11:44:51 -0400 Subject: [PATCH] Issue #24325, #24400: Add more unittests for types.coroutine; tweak wrapper implementation. --- Lib/test/test_asyncio/test_pep492.py | 19 +++ Lib/test/test_types.py | 206 +++++++++++++++++++++++---- Lib/types.py | 65 +++++---- 3 files changed, 228 insertions(+), 62 deletions(-) diff --git a/Lib/test/test_asyncio/test_pep492.py b/Lib/test/test_asyncio/test_pep492.py index fe69d32896f..5c7e9aecffc 100644 --- a/Lib/test/test_asyncio/test_pep492.py +++ b/Lib/test/test_asyncio/test_pep492.py @@ -1,6 +1,7 @@ """Tests support for new syntax introduced by PEP 492.""" import collections.abc +import types import unittest from test import support @@ -164,5 +165,23 @@ class CoroutineTests(BaseTest): self.loop.run_until_complete(start()) + def test_types_coroutine(self): + def gen(): + yield from () + return 'spam' + + @types.coroutine + def func(): + return gen() + + async def coro(): + wrapper = func() + self.assertIsInstance(wrapper, types._GeneratorWrapper) + return await wrapper + + data = self.loop.run_until_complete(coro()) + self.assertEqual(data, 'spam') + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 5b971d194dc..e489898d947 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -7,7 +7,8 @@ import pickle import locale import sys import types -import unittest +import unittest.mock +import weakref class TypesTests(unittest.TestCase): @@ -1191,23 +1192,27 @@ class SimpleNamespaceTests(unittest.TestCase): class CoroutineTests(unittest.TestCase): def test_wrong_args(self): - class Foo: - def __call__(self): - pass - def bar(): pass - samples = [None, 1, object()] for sample in samples: with self.assertRaisesRegex(TypeError, 'types.coroutine.*expects a callable'): types.coroutine(sample) - def test_wrong_func(self): + def test_non_gen_values(self): @types.coroutine def foo(): return 'spam' self.assertEqual(foo(), 'spam') + class Awaitable: + def __await__(self): + return () + aw = Awaitable() + @types.coroutine + def foo(): + return aw + self.assertIs(aw, foo()) + def test_async_def(self): # Test that types.coroutine passes 'async def' coroutines # without modification @@ -1263,24 +1268,157 @@ class CoroutineTests(unittest.TestCase): def send(self): pass def throw(self): pass def close(self): pass - def __iter__(self): return self + def __iter__(self): pass def __next__(self): pass - gen = GenLike() + # Setup generator mock object + gen = unittest.mock.MagicMock(GenLike) + gen.__iter__ = lambda gen: gen + gen.__name__ = 'gen' + gen.__qualname__ = 'test.gen' + self.assertIsInstance(gen, collections.abc.Generator) + self.assertIs(gen, iter(gen)) + @types.coroutine - def foo(): - return gen - self.assertIs(foo().__await__(), gen) - self.assertTrue(isinstance(foo(), collections.abc.Coroutine)) - with self.assertRaises(AttributeError): - foo().gi_code + def foo(): return gen + + wrapper = foo() + self.assertIsInstance(wrapper, types._GeneratorWrapper) + self.assertIs(wrapper.__await__(), wrapper) + # Wrapper proxies duck generators completely: + self.assertIs(iter(wrapper), wrapper) + + self.assertIsInstance(wrapper, collections.abc.Coroutine) + self.assertIsInstance(wrapper, collections.abc.Awaitable) + + self.assertIs(wrapper.__qualname__, gen.__qualname__) + self.assertIs(wrapper.__name__, gen.__name__) + + # Test AttributeErrors + for name in {'gi_running', 'gi_frame', 'gi_code', + 'cr_running', 'cr_frame', 'cr_code'}: + with self.assertRaises(AttributeError): + getattr(wrapper, name) + + # Test attributes pass-through + gen.gi_running = object() + gen.gi_frame = object() + gen.gi_code = object() + self.assertIs(wrapper.gi_running, gen.gi_running) + self.assertIs(wrapper.gi_frame, gen.gi_frame) + self.assertIs(wrapper.gi_code, gen.gi_code) + self.assertIs(wrapper.cr_running, gen.gi_running) + self.assertIs(wrapper.cr_frame, gen.gi_frame) + self.assertIs(wrapper.cr_code, gen.gi_code) + + wrapper.close() + gen.close.assert_called_once_with() + + wrapper.send(1) + gen.send.assert_called_once_with(1) + + wrapper.throw(1, 2, 3) + gen.throw.assert_called_once_with(1, 2, 3) + gen.reset_mock() + + wrapper.throw(1, 2) + gen.throw.assert_called_once_with(1, 2) + gen.reset_mock() + + wrapper.throw(1) + gen.throw.assert_called_once_with(1) + gen.reset_mock() + + # Test exceptions propagation + error = Exception() + gen.throw.side_effect = error + try: + wrapper.throw(1) + except Exception as ex: + self.assertIs(ex, error) + else: + self.fail('wrapper did not propagate an exception') + + # Test invalid args + gen.reset_mock() + with self.assertRaises(TypeError): + wrapper.throw() + self.assertFalse(gen.throw.called) + with self.assertRaises(TypeError): + wrapper.close(1) + self.assertFalse(gen.close.called) + with self.assertRaises(TypeError): + wrapper.send() + self.assertFalse(gen.send.called) + + # Test that we do not double wrap + @types.coroutine + def bar(): return wrapper + self.assertIs(wrapper, bar()) + + # Test weakrefs support + ref = weakref.ref(wrapper) + self.assertIs(ref(), wrapper) + + def test_duck_functional_gen(self): + class Generator: + """Emulates the following generator (very clumsy): + + def gen(fut): + result = yield fut + return result * 2 + """ + def __init__(self, fut): + self._i = 0 + self._fut = fut + def __iter__(self): + return self + def __next__(self): + return self.send(None) + def send(self, v): + try: + if self._i == 0: + assert v is None + return self._fut + if self._i == 1: + raise StopIteration(v * 2) + if self._i > 1: + raise StopIteration + finally: + self._i += 1 + def throw(self, tp, *exc): + self._i = 100 + if tp is not GeneratorExit: + raise tp + def close(self): + self.throw(GeneratorExit) + + @types.coroutine + def foo(): return Generator('spam') + + wrapper = foo() + self.assertIsInstance(wrapper, types._GeneratorWrapper) + + async def corofunc(): + return await foo() + 100 + coro = corofunc() + + self.assertEqual(coro.send(None), 'spam') + try: + coro.send(20) + except StopIteration as ex: + self.assertEqual(ex.args[0], 140) + else: + self.fail('StopIteration was expected') def test_gen(self): def gen(): yield gen = gen() @types.coroutine def foo(): return gen - self.assertIs(foo().__await__(), gen) + wrapper = foo() + self.assertIsInstance(wrapper, types._GeneratorWrapper) + self.assertIs(wrapper.__await__(), gen) for name in ('__name__', '__qualname__', 'gi_code', 'gi_running', 'gi_frame'): @@ -1289,19 +1427,8 @@ class CoroutineTests(unittest.TestCase): self.assertIs(foo().cr_code, gen.gi_code) def test_genfunc(self): - def gen(): - yield - - self.assertFalse(isinstance(gen(), collections.abc.Coroutine)) - self.assertFalse(isinstance(gen(), collections.abc.Awaitable)) - - gen_code = gen.__code__ - decorated_gen = types.coroutine(gen) - self.assertIs(decorated_gen, gen) - self.assertIsNot(decorated_gen.__code__, gen_code) - - decorated_gen2 = types.coroutine(decorated_gen) - self.assertIs(decorated_gen2.__code__, decorated_gen.__code__) + def gen(): yield + self.assertIs(types.coroutine(gen), gen) self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE) self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE) @@ -1309,10 +1436,27 @@ class CoroutineTests(unittest.TestCase): g = gen() self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE) - self.assertTrue(isinstance(g, collections.abc.Coroutine)) - self.assertTrue(isinstance(g, collections.abc.Awaitable)) + self.assertIsInstance(g, collections.abc.Coroutine) + self.assertIsInstance(g, collections.abc.Awaitable) g.close() # silence warning + self.assertIs(types.coroutine(gen), gen) + + def test_wrapper_object(self): + def gen(): + yield + @types.coroutine + def coro(): + return gen() + + wrapper = coro() + self.assertIn('GeneratorWrapper', repr(wrapper)) + self.assertEqual(repr(wrapper), str(wrapper)) + self.assertTrue(set(dir(wrapper)).issuperset({ + '__await__', '__iter__', '__next__', 'cr_code', 'cr_running', + 'cr_frame', 'gi_code', 'gi_frame', 'gi_running', 'send', + 'close', 'throw'})) + if __name__ == '__main__': unittest.main() diff --git a/Lib/types.py b/Lib/types.py index dc1b040f89e..1d44653ef97 100644 --- a/Lib/types.py +++ b/Lib/types.py @@ -166,6 +166,39 @@ class DynamicClassAttribute: import functools as _functools import collections.abc as _collections_abc +class _GeneratorWrapper: + # TODO: Implement this in C. + def __init__(self, gen): + self.__wrapped__ = gen + self.__isgen__ = gen.__class__ is GeneratorType + self.__name__ = getattr(gen, '__name__', None) + self.__qualname__ = getattr(gen, '__qualname__', None) + def send(self, val): + return self.__wrapped__.send(val) + def throw(self, tp, *rest): + return self.__wrapped__.throw(tp, *rest) + def close(self): + return self.__wrapped__.close() + @property + def gi_code(self): + return self.__wrapped__.gi_code + @property + def gi_frame(self): + return self.__wrapped__.gi_frame + @property + def gi_running(self): + return self.__wrapped__.gi_running + cr_code = gi_code + cr_frame = gi_frame + cr_running = gi_running + def __next__(self): + return next(self.__wrapped__) + def __iter__(self): + if self.__isgen__: + return self.__wrapped__ + return self + __await__ = __iter__ + def coroutine(func): """Convert regular generator function to a coroutine.""" @@ -201,36 +234,6 @@ def coroutine(func): # return generator-like objects (for instance generators # compiled with Cython). - class GeneratorWrapper: - def __init__(self, gen): - self.__wrapped__ = gen - self.__name__ = getattr(gen, '__name__', None) - self.__qualname__ = getattr(gen, '__qualname__', None) - def send(self, val): - return self.__wrapped__.send(val) - def throw(self, *args): - return self.__wrapped__.throw(*args) - def close(self): - return self.__wrapped__.close() - @property - def gi_code(self): - return self.__wrapped__.gi_code - @property - def gi_frame(self): - return self.__wrapped__.gi_frame - @property - def gi_running(self): - return self.__wrapped__.gi_running - cr_code = gi_code - cr_frame = gi_frame - cr_running = gi_running - def __next__(self): - return next(self.__wrapped__) - def __iter__(self): - return self.__wrapped__ - def __await__(self): - return self.__wrapped__ - @_functools.wraps(func) def wrapped(*args, **kwargs): coro = func(*args, **kwargs) @@ -243,7 +246,7 @@ def coroutine(func): # 'coro' is either a pure Python generator iterator, or it # implements collections.abc.Generator (and does not implement # collections.abc.Coroutine). - return GeneratorWrapper(coro) + return _GeneratorWrapper(coro) # 'coro' is either an instance of collections.abc.Coroutine or # some other object -- pass it through. return coro