asyncio: ensure_future() now understands awaitables

This commit is contained in:
Yury Selivanov 2015-10-02 15:00:19 -04:00
parent e2382c598c
commit 620279b9ac
2 changed files with 32 additions and 2 deletions

View File

@ -512,7 +512,7 @@ def async(coro_or_future, *, loop=None):
def ensure_future(coro_or_future, *, loop=None):
"""Wrap a coroutine in a future.
"""Wrap a coroutine or an awaitable in a future.
If the argument is a Future, it is returned directly.
"""
@ -527,8 +527,20 @@ def ensure_future(coro_or_future, *, loop=None):
if task._source_traceback:
del task._source_traceback[-1]
return task
elif compat.PY35 and inspect.isawaitable(coro_or_future):
return ensure_future(_wrap_awaitable(coro_or_future), loop=loop)
else:
raise TypeError('A Future or coroutine is required')
raise TypeError('A Future, a coroutine or an awaitable is required')
@coroutine
def _wrap_awaitable(awaitable):
"""Helper for asyncio.ensure_future().
Wraps awaitable (an object with __await__) into a coroutine
that will later be wrapped in a Task by ensure_future().
"""
return (yield from awaitable.__await__())
class _GatheringFuture(futures.Future):

View File

@ -153,6 +153,24 @@ class TaskTests(test_utils.TestCase):
t = asyncio.ensure_future(t_orig, loop=self.loop)
self.assertIs(t, t_orig)
@unittest.skipUnless(PY35, 'need python 3.5 or later')
def test_ensure_future_awaitable(self):
class Aw:
def __init__(self, coro):
self.coro = coro
def __await__(self):
return (yield from self.coro)
@asyncio.coroutine
def coro():
return 'ok'
loop = asyncio.new_event_loop()
self.set_event_loop(loop)
fut = asyncio.ensure_future(Aw(coro()), loop=loop)
loop.run_until_complete(fut)
assert fut.result() == 'ok'
def test_ensure_future_neither(self):
with self.assertRaises(TypeError):
asyncio.ensure_future('ok')