From a125497ea302aff937a5c59f98c39dba4f1ab59b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 11 Feb 2014 11:34:30 +0100 Subject: [PATCH] asyncio, Tulip issue 126: call_soon(), call_soon_threadsafe(), call_later(), call_at() and run_in_executor() now raise a TypeError if the callback is a coroutine function. --- Lib/asyncio/base_events.py | 6 ++++++ Lib/asyncio/test_utils.py | 5 ++++- Lib/test/test_asyncio/test_base_events.py | 18 ++++++++++++++++++ Lib/test/test_asyncio/test_proactor_events.py | 2 +- Lib/test/test_asyncio/test_selector_events.py | 9 +++++---- Lib/test/test_asyncio/test_tasks.py | 12 +++++------- 6 files changed, 39 insertions(+), 13 deletions(-) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 48b3ee3e9df..4b7b161ecaa 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -227,6 +227,8 @@ class BaseEventLoop(events.AbstractEventLoop): def call_at(self, when, callback, *args): """Like call_later(), but uses an absolute time.""" + if tasks.iscoroutinefunction(callback): + raise TypeError("coroutines cannot be used with call_at()") timer = events.TimerHandle(when, callback, args) heapq.heappush(self._scheduled, timer) return timer @@ -241,6 +243,8 @@ class BaseEventLoop(events.AbstractEventLoop): Any positional arguments after the callback will be passed to the callback when it is called. """ + if tasks.iscoroutinefunction(callback): + raise TypeError("coroutines cannot be used with call_soon()") handle = events.Handle(callback, args) self._ready.append(handle) return handle @@ -252,6 +256,8 @@ class BaseEventLoop(events.AbstractEventLoop): return handle def run_in_executor(self, executor, callback, *args): + if tasks.iscoroutinefunction(callback): + raise TypeError("coroutines cannot be used with run_in_executor()") if isinstance(callback, events.Handle): assert not args assert not isinstance(callback, events.TimerHandle) diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py index 7c8e1dcbd6c..deab7c33122 100644 --- a/Lib/asyncio/test_utils.py +++ b/Lib/asyncio/test_utils.py @@ -135,7 +135,7 @@ def make_test_protocol(base): if name.startswith('__') and name.endswith('__'): # skip magic names continue - dct[name] = unittest.mock.Mock(return_value=None) + dct[name] = MockCallback(return_value=None) return type('TestProtocol', (base,) + base.__bases__, dct)() @@ -274,3 +274,6 @@ class TestLoop(base_events.BaseEventLoop): def _write_to_self(self): pass + +def MockCallback(**kwargs): + return unittest.mock.Mock(spec=['__call__'], **kwargs) diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 5b05684723d..c6950ab3fa8 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -567,6 +567,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): m_socket.getaddrinfo.return_value = [ (2, 1, 6, '', ('127.0.0.1', 10100))] + m_socket.getaddrinfo._is_coroutine = False m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock.bind.side_effect = Err @@ -577,6 +578,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): @unittest.mock.patch('asyncio.base_events.socket') def test_create_datagram_endpoint_no_addrinfo(self, m_socket): m_socket.getaddrinfo.return_value = [] + m_socket.getaddrinfo._is_coroutine = False coro = self.loop.create_datagram_endpoint( MyDatagramProto, local_addr=('localhost', 0)) @@ -681,6 +683,22 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): unittest.mock.ANY, MyProto, sock, None, None) + def test_call_coroutine(self): + @asyncio.coroutine + def coroutine_function(): + pass + + with self.assertRaises(TypeError): + self.loop.call_soon(coroutine_function) + with self.assertRaises(TypeError): + self.loop.call_soon_threadsafe(coroutine_function) + with self.assertRaises(TypeError): + self.loop.call_later(60, coroutine_function) + with self.assertRaises(TypeError): + self.loop.call_at(self.loop.time() + 60, coroutine_function) + with self.assertRaises(TypeError): + self.loop.run_in_executor(None, coroutine_function) + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py index 9964f425d21..6bea1a33685 100644 --- a/Lib/test/test_asyncio/test_proactor_events.py +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -402,7 +402,7 @@ class BaseProactorEventLoopTests(unittest.TestCase): NotImplementedError, BaseProactorEventLoop, self.proactor) def test_make_socket_transport(self): - tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol()) self.assertIsInstance(tr, _ProactorSocketTransport) def test_loop_self_reading(self): diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index ad0b0be81e5..855a8954e86 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -44,8 +44,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase): def test_make_socket_transport(self): m = unittest.mock.Mock() self.loop.add_reader = unittest.mock.Mock() - self.assertIsInstance( - self.loop._make_socket_transport(m, m), _SelectorSocketTransport) + transport = self.loop._make_socket_transport(m, asyncio.Protocol()) + self.assertIsInstance(transport, _SelectorSocketTransport) @unittest.skipIf(ssl is None, 'No ssl module') def test_make_ssl_transport(self): @@ -54,8 +54,9 @@ class BaseSelectorEventLoopTests(unittest.TestCase): self.loop.add_writer = unittest.mock.Mock() self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock() - self.assertIsInstance( - self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) + waiter = asyncio.Future(loop=self.loop) + transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter) + self.assertIsInstance(transport, _SelectorSslTransport) @unittest.mock.patch('asyncio.selector_events.ssl', None) def test_make_ssl_transport_without_ssl_error(self): diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index 9abdfa5bc13..29bdaf5bd4f 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -2,8 +2,6 @@ import gc import unittest -import unittest.mock -from unittest.mock import Mock import asyncio from asyncio import test_utils @@ -1358,7 +1356,7 @@ class GatherTestsBase: def _check_success(self, **kwargs): a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) b.set_result(1) a.set_result(2) @@ -1380,7 +1378,7 @@ class GatherTestsBase: def test_one_exception(self): a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) exc = ZeroDivisionError() a.set_result(1) @@ -1399,7 +1397,7 @@ class GatherTestsBase: a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] fut = asyncio.gather(*self.wrap_futures(a, b, c, d), return_exceptions=True) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) exc = ZeroDivisionError() exc2 = RuntimeError() @@ -1460,7 +1458,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase): def test_one_cancellation(self): a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] fut = asyncio.gather(a, b, c, d, e) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) a.set_result(1) b.cancel() @@ -1479,7 +1477,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase): a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) for i in range(6)] fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) a.set_result(1) zde = ZeroDivisionError()