This commit is contained in:
parent
66f8828bfc
commit
00e3372358
|
@ -1,6 +1,7 @@
|
|||
"""Tests support for new syntax introduced by PEP 492."""
|
||||
|
||||
import collections.abc
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from test import support
|
||||
|
@ -164,5 +165,23 @@ class CoroutineTests(BaseTest):
|
|||
self.loop.run_until_complete(start())
|
||||
|
||||
|
||||
def test_types_coroutine(self):
|
||||
def gen():
|
||||
yield from ()
|
||||
return 'spam'
|
||||
|
||||
@types.coroutine
|
||||
def func():
|
||||
return gen()
|
||||
|
||||
async def coro():
|
||||
wrapper = func()
|
||||
self.assertIsInstance(wrapper, types._GeneratorWrapper)
|
||||
return await wrapper
|
||||
|
||||
data = self.loop.run_until_complete(coro())
|
||||
self.assertEqual(data, 'spam')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -7,7 +7,8 @@ import pickle
|
|||
import locale
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
import unittest.mock
|
||||
import weakref
|
||||
|
||||
class TypesTests(unittest.TestCase):
|
||||
|
||||
|
@ -1191,23 +1192,27 @@ class SimpleNamespaceTests(unittest.TestCase):
|
|||
|
||||
class CoroutineTests(unittest.TestCase):
|
||||
def test_wrong_args(self):
|
||||
class Foo:
|
||||
def __call__(self):
|
||||
pass
|
||||
def bar(): pass
|
||||
|
||||
samples = [None, 1, object()]
|
||||
for sample in samples:
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'types.coroutine.*expects a callable'):
|
||||
types.coroutine(sample)
|
||||
|
||||
def test_wrong_func(self):
|
||||
def test_non_gen_values(self):
|
||||
@types.coroutine
|
||||
def foo():
|
||||
return 'spam'
|
||||
self.assertEqual(foo(), 'spam')
|
||||
|
||||
class Awaitable:
|
||||
def __await__(self):
|
||||
return ()
|
||||
aw = Awaitable()
|
||||
@types.coroutine
|
||||
def foo():
|
||||
return aw
|
||||
self.assertIs(aw, foo())
|
||||
|
||||
def test_async_def(self):
|
||||
# Test that types.coroutine passes 'async def' coroutines
|
||||
# without modification
|
||||
|
@ -1263,24 +1268,157 @@ class CoroutineTests(unittest.TestCase):
|
|||
def send(self): pass
|
||||
def throw(self): pass
|
||||
def close(self): pass
|
||||
def __iter__(self): return self
|
||||
def __iter__(self): pass
|
||||
def __next__(self): pass
|
||||
|
||||
gen = GenLike()
|
||||
# Setup generator mock object
|
||||
gen = unittest.mock.MagicMock(GenLike)
|
||||
gen.__iter__ = lambda gen: gen
|
||||
gen.__name__ = 'gen'
|
||||
gen.__qualname__ = 'test.gen'
|
||||
self.assertIsInstance(gen, collections.abc.Generator)
|
||||
self.assertIs(gen, iter(gen))
|
||||
|
||||
@types.coroutine
|
||||
def foo():
|
||||
return gen
|
||||
self.assertIs(foo().__await__(), gen)
|
||||
self.assertTrue(isinstance(foo(), collections.abc.Coroutine))
|
||||
def foo(): return gen
|
||||
|
||||
wrapper = foo()
|
||||
self.assertIsInstance(wrapper, types._GeneratorWrapper)
|
||||
self.assertIs(wrapper.__await__(), wrapper)
|
||||
# Wrapper proxies duck generators completely:
|
||||
self.assertIs(iter(wrapper), wrapper)
|
||||
|
||||
self.assertIsInstance(wrapper, collections.abc.Coroutine)
|
||||
self.assertIsInstance(wrapper, collections.abc.Awaitable)
|
||||
|
||||
self.assertIs(wrapper.__qualname__, gen.__qualname__)
|
||||
self.assertIs(wrapper.__name__, gen.__name__)
|
||||
|
||||
# Test AttributeErrors
|
||||
for name in {'gi_running', 'gi_frame', 'gi_code',
|
||||
'cr_running', 'cr_frame', 'cr_code'}:
|
||||
with self.assertRaises(AttributeError):
|
||||
foo().gi_code
|
||||
getattr(wrapper, name)
|
||||
|
||||
# Test attributes pass-through
|
||||
gen.gi_running = object()
|
||||
gen.gi_frame = object()
|
||||
gen.gi_code = object()
|
||||
self.assertIs(wrapper.gi_running, gen.gi_running)
|
||||
self.assertIs(wrapper.gi_frame, gen.gi_frame)
|
||||
self.assertIs(wrapper.gi_code, gen.gi_code)
|
||||
self.assertIs(wrapper.cr_running, gen.gi_running)
|
||||
self.assertIs(wrapper.cr_frame, gen.gi_frame)
|
||||
self.assertIs(wrapper.cr_code, gen.gi_code)
|
||||
|
||||
wrapper.close()
|
||||
gen.close.assert_called_once_with()
|
||||
|
||||
wrapper.send(1)
|
||||
gen.send.assert_called_once_with(1)
|
||||
|
||||
wrapper.throw(1, 2, 3)
|
||||
gen.throw.assert_called_once_with(1, 2, 3)
|
||||
gen.reset_mock()
|
||||
|
||||
wrapper.throw(1, 2)
|
||||
gen.throw.assert_called_once_with(1, 2)
|
||||
gen.reset_mock()
|
||||
|
||||
wrapper.throw(1)
|
||||
gen.throw.assert_called_once_with(1)
|
||||
gen.reset_mock()
|
||||
|
||||
# Test exceptions propagation
|
||||
error = Exception()
|
||||
gen.throw.side_effect = error
|
||||
try:
|
||||
wrapper.throw(1)
|
||||
except Exception as ex:
|
||||
self.assertIs(ex, error)
|
||||
else:
|
||||
self.fail('wrapper did not propagate an exception')
|
||||
|
||||
# Test invalid args
|
||||
gen.reset_mock()
|
||||
with self.assertRaises(TypeError):
|
||||
wrapper.throw()
|
||||
self.assertFalse(gen.throw.called)
|
||||
with self.assertRaises(TypeError):
|
||||
wrapper.close(1)
|
||||
self.assertFalse(gen.close.called)
|
||||
with self.assertRaises(TypeError):
|
||||
wrapper.send()
|
||||
self.assertFalse(gen.send.called)
|
||||
|
||||
# Test that we do not double wrap
|
||||
@types.coroutine
|
||||
def bar(): return wrapper
|
||||
self.assertIs(wrapper, bar())
|
||||
|
||||
# Test weakrefs support
|
||||
ref = weakref.ref(wrapper)
|
||||
self.assertIs(ref(), wrapper)
|
||||
|
||||
def test_duck_functional_gen(self):
|
||||
class Generator:
|
||||
"""Emulates the following generator (very clumsy):
|
||||
|
||||
def gen(fut):
|
||||
result = yield fut
|
||||
return result * 2
|
||||
"""
|
||||
def __init__(self, fut):
|
||||
self._i = 0
|
||||
self._fut = fut
|
||||
def __iter__(self):
|
||||
return self
|
||||
def __next__(self):
|
||||
return self.send(None)
|
||||
def send(self, v):
|
||||
try:
|
||||
if self._i == 0:
|
||||
assert v is None
|
||||
return self._fut
|
||||
if self._i == 1:
|
||||
raise StopIteration(v * 2)
|
||||
if self._i > 1:
|
||||
raise StopIteration
|
||||
finally:
|
||||
self._i += 1
|
||||
def throw(self, tp, *exc):
|
||||
self._i = 100
|
||||
if tp is not GeneratorExit:
|
||||
raise tp
|
||||
def close(self):
|
||||
self.throw(GeneratorExit)
|
||||
|
||||
@types.coroutine
|
||||
def foo(): return Generator('spam')
|
||||
|
||||
wrapper = foo()
|
||||
self.assertIsInstance(wrapper, types._GeneratorWrapper)
|
||||
|
||||
async def corofunc():
|
||||
return await foo() + 100
|
||||
coro = corofunc()
|
||||
|
||||
self.assertEqual(coro.send(None), 'spam')
|
||||
try:
|
||||
coro.send(20)
|
||||
except StopIteration as ex:
|
||||
self.assertEqual(ex.args[0], 140)
|
||||
else:
|
||||
self.fail('StopIteration was expected')
|
||||
|
||||
def test_gen(self):
|
||||
def gen(): yield
|
||||
gen = gen()
|
||||
@types.coroutine
|
||||
def foo(): return gen
|
||||
self.assertIs(foo().__await__(), gen)
|
||||
wrapper = foo()
|
||||
self.assertIsInstance(wrapper, types._GeneratorWrapper)
|
||||
self.assertIs(wrapper.__await__(), gen)
|
||||
|
||||
for name in ('__name__', '__qualname__', 'gi_code',
|
||||
'gi_running', 'gi_frame'):
|
||||
|
@ -1289,19 +1427,8 @@ class CoroutineTests(unittest.TestCase):
|
|||
self.assertIs(foo().cr_code, gen.gi_code)
|
||||
|
||||
def test_genfunc(self):
|
||||
def gen():
|
||||
yield
|
||||
|
||||
self.assertFalse(isinstance(gen(), collections.abc.Coroutine))
|
||||
self.assertFalse(isinstance(gen(), collections.abc.Awaitable))
|
||||
|
||||
gen_code = gen.__code__
|
||||
decorated_gen = types.coroutine(gen)
|
||||
self.assertIs(decorated_gen, gen)
|
||||
self.assertIsNot(decorated_gen.__code__, gen_code)
|
||||
|
||||
decorated_gen2 = types.coroutine(decorated_gen)
|
||||
self.assertIs(decorated_gen2.__code__, decorated_gen.__code__)
|
||||
def gen(): yield
|
||||
self.assertIs(types.coroutine(gen), gen)
|
||||
|
||||
self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE)
|
||||
self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE)
|
||||
|
@ -1309,10 +1436,27 @@ class CoroutineTests(unittest.TestCase):
|
|||
g = gen()
|
||||
self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE)
|
||||
self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE)
|
||||
self.assertTrue(isinstance(g, collections.abc.Coroutine))
|
||||
self.assertTrue(isinstance(g, collections.abc.Awaitable))
|
||||
self.assertIsInstance(g, collections.abc.Coroutine)
|
||||
self.assertIsInstance(g, collections.abc.Awaitable)
|
||||
g.close() # silence warning
|
||||
|
||||
self.assertIs(types.coroutine(gen), gen)
|
||||
|
||||
def test_wrapper_object(self):
|
||||
def gen():
|
||||
yield
|
||||
@types.coroutine
|
||||
def coro():
|
||||
return gen()
|
||||
|
||||
wrapper = coro()
|
||||
self.assertIn('GeneratorWrapper', repr(wrapper))
|
||||
self.assertEqual(repr(wrapper), str(wrapper))
|
||||
self.assertTrue(set(dir(wrapper)).issuperset({
|
||||
'__await__', '__iter__', '__next__', 'cr_code', 'cr_running',
|
||||
'cr_frame', 'gi_code', 'gi_frame', 'gi_running', 'send',
|
||||
'close', 'throw'}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
65
Lib/types.py
65
Lib/types.py
|
@ -166,6 +166,39 @@ class DynamicClassAttribute:
|
|||
import functools as _functools
|
||||
import collections.abc as _collections_abc
|
||||
|
||||
class _GeneratorWrapper:
|
||||
# TODO: Implement this in C.
|
||||
def __init__(self, gen):
|
||||
self.__wrapped__ = gen
|
||||
self.__isgen__ = gen.__class__ is GeneratorType
|
||||
self.__name__ = getattr(gen, '__name__', None)
|
||||
self.__qualname__ = getattr(gen, '__qualname__', None)
|
||||
def send(self, val):
|
||||
return self.__wrapped__.send(val)
|
||||
def throw(self, tp, *rest):
|
||||
return self.__wrapped__.throw(tp, *rest)
|
||||
def close(self):
|
||||
return self.__wrapped__.close()
|
||||
@property
|
||||
def gi_code(self):
|
||||
return self.__wrapped__.gi_code
|
||||
@property
|
||||
def gi_frame(self):
|
||||
return self.__wrapped__.gi_frame
|
||||
@property
|
||||
def gi_running(self):
|
||||
return self.__wrapped__.gi_running
|
||||
cr_code = gi_code
|
||||
cr_frame = gi_frame
|
||||
cr_running = gi_running
|
||||
def __next__(self):
|
||||
return next(self.__wrapped__)
|
||||
def __iter__(self):
|
||||
if self.__isgen__:
|
||||
return self.__wrapped__
|
||||
return self
|
||||
__await__ = __iter__
|
||||
|
||||
def coroutine(func):
|
||||
"""Convert regular generator function to a coroutine."""
|
||||
|
||||
|
@ -201,36 +234,6 @@ def coroutine(func):
|
|||
# return generator-like objects (for instance generators
|
||||
# compiled with Cython).
|
||||
|
||||
class GeneratorWrapper:
|
||||
def __init__(self, gen):
|
||||
self.__wrapped__ = gen
|
||||
self.__name__ = getattr(gen, '__name__', None)
|
||||
self.__qualname__ = getattr(gen, '__qualname__', None)
|
||||
def send(self, val):
|
||||
return self.__wrapped__.send(val)
|
||||
def throw(self, *args):
|
||||
return self.__wrapped__.throw(*args)
|
||||
def close(self):
|
||||
return self.__wrapped__.close()
|
||||
@property
|
||||
def gi_code(self):
|
||||
return self.__wrapped__.gi_code
|
||||
@property
|
||||
def gi_frame(self):
|
||||
return self.__wrapped__.gi_frame
|
||||
@property
|
||||
def gi_running(self):
|
||||
return self.__wrapped__.gi_running
|
||||
cr_code = gi_code
|
||||
cr_frame = gi_frame
|
||||
cr_running = gi_running
|
||||
def __next__(self):
|
||||
return next(self.__wrapped__)
|
||||
def __iter__(self):
|
||||
return self.__wrapped__
|
||||
def __await__(self):
|
||||
return self.__wrapped__
|
||||
|
||||
@_functools.wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
coro = func(*args, **kwargs)
|
||||
|
@ -243,7 +246,7 @@ def coroutine(func):
|
|||
# 'coro' is either a pure Python generator iterator, or it
|
||||
# implements collections.abc.Generator (and does not implement
|
||||
# collections.abc.Coroutine).
|
||||
return GeneratorWrapper(coro)
|
||||
return _GeneratorWrapper(coro)
|
||||
# 'coro' is either an instance of collections.abc.Coroutine or
|
||||
# some other object -- pass it through.
|
||||
return coro
|
||||
|
|
Loading…
Reference in New Issue