asyncio, Tulip issue 126: call_soon(), call_soon_threadsafe(), call_later(),

call_at() and run_in_executor() now raise a TypeError if the callback is a
coroutine function.
This commit is contained in:
Victor Stinner 2014-02-11 11:34:30 +01:00
parent 1db2ba3a92
commit a125497ea3
6 changed files with 39 additions and 13 deletions

View File

@ -227,6 +227,8 @@ class BaseEventLoop(events.AbstractEventLoop):
def call_at(self, when, callback, *args): def call_at(self, when, callback, *args):
"""Like call_later(), but uses an absolute time.""" """Like call_later(), but uses an absolute time."""
if tasks.iscoroutinefunction(callback):
raise TypeError("coroutines cannot be used with call_at()")
timer = events.TimerHandle(when, callback, args) timer = events.TimerHandle(when, callback, args)
heapq.heappush(self._scheduled, timer) heapq.heappush(self._scheduled, timer)
return timer return timer
@ -241,6 +243,8 @@ class BaseEventLoop(events.AbstractEventLoop):
Any positional arguments after the callback will be passed to Any positional arguments after the callback will be passed to
the callback when it is called. the callback when it is called.
""" """
if tasks.iscoroutinefunction(callback):
raise TypeError("coroutines cannot be used with call_soon()")
handle = events.Handle(callback, args) handle = events.Handle(callback, args)
self._ready.append(handle) self._ready.append(handle)
return handle return handle
@ -252,6 +256,8 @@ class BaseEventLoop(events.AbstractEventLoop):
return handle return handle
def run_in_executor(self, executor, callback, *args): def run_in_executor(self, executor, callback, *args):
if tasks.iscoroutinefunction(callback):
raise TypeError("coroutines cannot be used with run_in_executor()")
if isinstance(callback, events.Handle): if isinstance(callback, events.Handle):
assert not args assert not args
assert not isinstance(callback, events.TimerHandle) assert not isinstance(callback, events.TimerHandle)

View File

@ -135,7 +135,7 @@ def make_test_protocol(base):
if name.startswith('__') and name.endswith('__'): if name.startswith('__') and name.endswith('__'):
# skip magic names # skip magic names
continue continue
dct[name] = unittest.mock.Mock(return_value=None) dct[name] = MockCallback(return_value=None)
return type('TestProtocol', (base,) + base.__bases__, dct)() return type('TestProtocol', (base,) + base.__bases__, dct)()
@ -274,3 +274,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self): def _write_to_self(self):
pass pass
def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs)

View File

@ -567,6 +567,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
m_socket.getaddrinfo.return_value = [ m_socket.getaddrinfo.return_value = [
(2, 1, 6, '', ('127.0.0.1', 10100))] (2, 1, 6, '', ('127.0.0.1', 10100))]
m_socket.getaddrinfo._is_coroutine = False
m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock = m_socket.socket.return_value = unittest.mock.Mock()
m_sock.bind.side_effect = Err m_sock.bind.side_effect = Err
@ -577,6 +578,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
@unittest.mock.patch('asyncio.base_events.socket') @unittest.mock.patch('asyncio.base_events.socket')
def test_create_datagram_endpoint_no_addrinfo(self, m_socket): def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
m_socket.getaddrinfo.return_value = [] m_socket.getaddrinfo.return_value = []
m_socket.getaddrinfo._is_coroutine = False
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
MyDatagramProto, local_addr=('localhost', 0)) MyDatagramProto, local_addr=('localhost', 0))
@ -681,6 +683,22 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
unittest.mock.ANY, unittest.mock.ANY,
MyProto, sock, None, None) MyProto, sock, None, None)
def test_call_coroutine(self):
@asyncio.coroutine
def coroutine_function():
pass
with self.assertRaises(TypeError):
self.loop.call_soon(coroutine_function)
with self.assertRaises(TypeError):
self.loop.call_soon_threadsafe(coroutine_function)
with self.assertRaises(TypeError):
self.loop.call_later(60, coroutine_function)
with self.assertRaises(TypeError):
self.loop.call_at(self.loop.time() + 60, coroutine_function)
with self.assertRaises(TypeError):
self.loop.run_in_executor(None, coroutine_function)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -402,7 +402,7 @@ class BaseProactorEventLoopTests(unittest.TestCase):
NotImplementedError, BaseProactorEventLoop, self.proactor) NotImplementedError, BaseProactorEventLoop, self.proactor)
def test_make_socket_transport(self): def test_make_socket_transport(self):
tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
self.assertIsInstance(tr, _ProactorSocketTransport) self.assertIsInstance(tr, _ProactorSocketTransport)
def test_loop_self_reading(self): def test_loop_self_reading(self):

View File

@ -44,8 +44,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
def test_make_socket_transport(self): def test_make_socket_transport(self):
m = unittest.mock.Mock() m = unittest.mock.Mock()
self.loop.add_reader = unittest.mock.Mock() self.loop.add_reader = unittest.mock.Mock()
self.assertIsInstance( transport = self.loop._make_socket_transport(m, asyncio.Protocol())
self.loop._make_socket_transport(m, m), _SelectorSocketTransport) self.assertIsInstance(transport, _SelectorSocketTransport)
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_make_ssl_transport(self): def test_make_ssl_transport(self):
@ -54,8 +54,9 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.add_writer = unittest.mock.Mock() self.loop.add_writer = unittest.mock.Mock()
self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock()
self.assertIsInstance( waiter = asyncio.Future(loop=self.loop)
self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
self.assertIsInstance(transport, _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None) @unittest.mock.patch('asyncio.selector_events.ssl', None)
def test_make_ssl_transport_without_ssl_error(self): def test_make_ssl_transport_without_ssl_error(self):

View File

@ -2,8 +2,6 @@
import gc import gc
import unittest import unittest
import unittest.mock
from unittest.mock import Mock
import asyncio import asyncio
from asyncio import test_utils from asyncio import test_utils
@ -1358,7 +1356,7 @@ class GatherTestsBase:
def _check_success(self, **kwargs): def _check_success(self, **kwargs):
a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
b.set_result(1) b.set_result(1)
a.set_result(2) a.set_result(2)
@ -1380,7 +1378,7 @@ class GatherTestsBase:
def test_one_exception(self): def test_one_exception(self):
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
exc = ZeroDivisionError() exc = ZeroDivisionError()
a.set_result(1) a.set_result(1)
@ -1399,7 +1397,7 @@ class GatherTestsBase:
a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
fut = asyncio.gather(*self.wrap_futures(a, b, c, d), fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
return_exceptions=True) return_exceptions=True)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
exc = ZeroDivisionError() exc = ZeroDivisionError()
exc2 = RuntimeError() exc2 = RuntimeError()
@ -1460,7 +1458,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
def test_one_cancellation(self): def test_one_cancellation(self):
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = asyncio.gather(a, b, c, d, e) fut = asyncio.gather(a, b, c, d, e)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
a.set_result(1) a.set_result(1)
b.cancel() b.cancel()
@ -1479,7 +1477,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
for i in range(6)] for i in range(6)]
fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
a.set_result(1) a.set_result(1)
zde = ZeroDivisionError() zde = ZeroDivisionError()