diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 575d205404a..52fef181cec 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -334,6 +334,15 @@ def wait_for(fut, timeout, *, loop=None): if timeout is None: return (yield from fut) + if timeout <= 0: + fut = ensure_future(fut, loop=loop) + + if fut.done(): + return fut.result() + + fut.cancel() + raise futures.TimeoutError() + waiter = loop.create_future() timeout_handle = loop.call_later(timeout, _release_waiter, waiter) cb = functools.partial(_release_waiter, waiter) diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 36082ec7c6a..7ff56b560b6 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -661,6 +661,76 @@ class BaseTaskTests: t.cancel() self.assertRaises(asyncio.CancelledError, loop.run_until_complete, t) + def test_wait_for_timeout_less_then_0_or_0_future_done(self): + def gen(): + when = yield + self.assertAlmostEqual(0, when) + + loop = self.new_test_loop(gen) + + fut = self.new_future(loop) + fut.set_result('done') + + ret = loop.run_until_complete(asyncio.wait_for(fut, 0, loop=loop)) + + self.assertEqual(ret, 'done') + self.assertTrue(fut.done()) + self.assertAlmostEqual(0, loop.time()) + + def test_wait_for_timeout_less_then_0_or_0_coroutine_do_not_started(self): + def gen(): + when = yield + self.assertAlmostEqual(0, when) + + loop = self.new_test_loop(gen) + + foo_started = False + + @asyncio.coroutine + def foo(): + nonlocal foo_started + foo_started = True + + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for(foo(), 0, loop=loop)) + + self.assertAlmostEqual(0, loop.time()) + self.assertEqual(foo_started, False) + + def test_wait_for_timeout_less_then_0_or_0(self): + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0, when) + + for timeout in [0, -1]: + with self.subTest(timeout=timeout): + loop = self.new_test_loop(gen) + + foo_running = None + + @asyncio.coroutine + def foo(): + nonlocal foo_running + foo_running = True + try: + yield from asyncio.sleep(0.2, loop=loop) + finally: + foo_running = False + return 'done' + + fut = self.new_task(loop, foo()) + + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for( + fut, timeout, loop=loop)) + self.assertTrue(fut.done()) + # it should have been cancelled due to the timeout + self.assertTrue(fut.cancelled()) + self.assertAlmostEqual(0, loop.time()) + self.assertEqual(foo_running, False) + def test_wait_for(self): def gen(): diff --git a/Misc/NEWS.d/next/Library/2017-09-22-23-48-49.bpo-31556.9J0u5H.rst b/Misc/NEWS.d/next/Library/2017-09-22-23-48-49.bpo-31556.9J0u5H.rst new file mode 100644 index 00000000000..2e6b0284696 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2017-09-22-23-48-49.bpo-31556.9J0u5H.rst @@ -0,0 +1 @@ +Cancel asyncio.wait_for future faster if timeout <= 0