asyncio, Tulip issue 126: call_soon(), call_soon_threadsafe(), call_later(),
call_at() and run_in_executor() now raise a TypeError if the callback is a coroutine function.
This commit is contained in:
parent
262a458b8a
commit
9af4a246f9
|
@ -227,6 +227,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
|
|
||||||
def call_at(self, when, callback, *args):
|
def call_at(self, when, callback, *args):
|
||||||
"""Like call_later(), but uses an absolute time."""
|
"""Like call_later(), but uses an absolute time."""
|
||||||
|
if tasks.iscoroutinefunction(callback):
|
||||||
|
raise TypeError("coroutines cannot be used with call_at()")
|
||||||
timer = events.TimerHandle(when, callback, args)
|
timer = events.TimerHandle(when, callback, args)
|
||||||
heapq.heappush(self._scheduled, timer)
|
heapq.heappush(self._scheduled, timer)
|
||||||
return timer
|
return timer
|
||||||
|
@ -241,6 +243,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
Any positional arguments after the callback will be passed to
|
Any positional arguments after the callback will be passed to
|
||||||
the callback when it is called.
|
the callback when it is called.
|
||||||
"""
|
"""
|
||||||
|
if tasks.iscoroutinefunction(callback):
|
||||||
|
raise TypeError("coroutines cannot be used with call_soon()")
|
||||||
handle = events.Handle(callback, args)
|
handle = events.Handle(callback, args)
|
||||||
self._ready.append(handle)
|
self._ready.append(handle)
|
||||||
return handle
|
return handle
|
||||||
|
@ -252,6 +256,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
def run_in_executor(self, executor, callback, *args):
|
def run_in_executor(self, executor, callback, *args):
|
||||||
|
if tasks.iscoroutinefunction(callback):
|
||||||
|
raise TypeError("coroutines cannot be used with run_in_executor()")
|
||||||
if isinstance(callback, events.Handle):
|
if isinstance(callback, events.Handle):
|
||||||
assert not args
|
assert not args
|
||||||
assert not isinstance(callback, events.TimerHandle)
|
assert not isinstance(callback, events.TimerHandle)
|
||||||
|
|
|
@ -135,7 +135,7 @@ def make_test_protocol(base):
|
||||||
if name.startswith('__') and name.endswith('__'):
|
if name.startswith('__') and name.endswith('__'):
|
||||||
# skip magic names
|
# skip magic names
|
||||||
continue
|
continue
|
||||||
dct[name] = unittest.mock.Mock(return_value=None)
|
dct[name] = MockCallback(return_value=None)
|
||||||
return type('TestProtocol', (base,) + base.__bases__, dct)()
|
return type('TestProtocol', (base,) + base.__bases__, dct)()
|
||||||
|
|
||||||
|
|
||||||
|
@ -274,3 +274,6 @@ class TestLoop(base_events.BaseEventLoop):
|
||||||
|
|
||||||
def _write_to_self(self):
|
def _write_to_self(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def MockCallback(**kwargs):
|
||||||
|
return unittest.mock.Mock(spec=['__call__'], **kwargs)
|
||||||
|
|
|
@ -567,6 +567,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
|
||||||
|
|
||||||
m_socket.getaddrinfo.return_value = [
|
m_socket.getaddrinfo.return_value = [
|
||||||
(2, 1, 6, '', ('127.0.0.1', 10100))]
|
(2, 1, 6, '', ('127.0.0.1', 10100))]
|
||||||
|
m_socket.getaddrinfo._is_coroutine = False
|
||||||
m_sock = m_socket.socket.return_value = unittest.mock.Mock()
|
m_sock = m_socket.socket.return_value = unittest.mock.Mock()
|
||||||
m_sock.bind.side_effect = Err
|
m_sock.bind.side_effect = Err
|
||||||
|
|
||||||
|
@ -577,6 +578,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
|
||||||
@unittest.mock.patch('asyncio.base_events.socket')
|
@unittest.mock.patch('asyncio.base_events.socket')
|
||||||
def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
|
def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
|
||||||
m_socket.getaddrinfo.return_value = []
|
m_socket.getaddrinfo.return_value = []
|
||||||
|
m_socket.getaddrinfo._is_coroutine = False
|
||||||
|
|
||||||
coro = self.loop.create_datagram_endpoint(
|
coro = self.loop.create_datagram_endpoint(
|
||||||
MyDatagramProto, local_addr=('localhost', 0))
|
MyDatagramProto, local_addr=('localhost', 0))
|
||||||
|
@ -681,6 +683,22 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
|
||||||
unittest.mock.ANY,
|
unittest.mock.ANY,
|
||||||
MyProto, sock, None, None)
|
MyProto, sock, None, None)
|
||||||
|
|
||||||
|
def test_call_coroutine(self):
|
||||||
|
@asyncio.coroutine
|
||||||
|
def coroutine_function():
|
||||||
|
pass
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
self.loop.call_soon(coroutine_function)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
self.loop.call_soon_threadsafe(coroutine_function)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
self.loop.call_later(60, coroutine_function)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
self.loop.call_at(self.loop.time() + 60, coroutine_function)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
self.loop.run_in_executor(None, coroutine_function)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -402,7 +402,7 @@ class BaseProactorEventLoopTests(unittest.TestCase):
|
||||||
NotImplementedError, BaseProactorEventLoop, self.proactor)
|
NotImplementedError, BaseProactorEventLoop, self.proactor)
|
||||||
|
|
||||||
def test_make_socket_transport(self):
|
def test_make_socket_transport(self):
|
||||||
tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock())
|
tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
|
||||||
self.assertIsInstance(tr, _ProactorSocketTransport)
|
self.assertIsInstance(tr, _ProactorSocketTransport)
|
||||||
|
|
||||||
def test_loop_self_reading(self):
|
def test_loop_self_reading(self):
|
||||||
|
|
|
@ -44,8 +44,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
|
||||||
def test_make_socket_transport(self):
|
def test_make_socket_transport(self):
|
||||||
m = unittest.mock.Mock()
|
m = unittest.mock.Mock()
|
||||||
self.loop.add_reader = unittest.mock.Mock()
|
self.loop.add_reader = unittest.mock.Mock()
|
||||||
self.assertIsInstance(
|
transport = self.loop._make_socket_transport(m, asyncio.Protocol())
|
||||||
self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
|
self.assertIsInstance(transport, _SelectorSocketTransport)
|
||||||
|
|
||||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||||
def test_make_ssl_transport(self):
|
def test_make_ssl_transport(self):
|
||||||
|
@ -54,8 +54,9 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
|
||||||
self.loop.add_writer = unittest.mock.Mock()
|
self.loop.add_writer = unittest.mock.Mock()
|
||||||
self.loop.remove_reader = unittest.mock.Mock()
|
self.loop.remove_reader = unittest.mock.Mock()
|
||||||
self.loop.remove_writer = unittest.mock.Mock()
|
self.loop.remove_writer = unittest.mock.Mock()
|
||||||
self.assertIsInstance(
|
waiter = asyncio.Future(loop=self.loop)
|
||||||
self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
|
transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
|
||||||
|
self.assertIsInstance(transport, _SelectorSslTransport)
|
||||||
|
|
||||||
@unittest.mock.patch('asyncio.selector_events.ssl', None)
|
@unittest.mock.patch('asyncio.selector_events.ssl', None)
|
||||||
def test_make_ssl_transport_without_ssl_error(self):
|
def test_make_ssl_transport_without_ssl_error(self):
|
||||||
|
|
|
@ -2,8 +2,6 @@
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
import unittest.mock
|
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from asyncio import test_utils
|
from asyncio import test_utils
|
||||||
|
@ -1358,7 +1356,7 @@ class GatherTestsBase:
|
||||||
def _check_success(self, **kwargs):
|
def _check_success(self, **kwargs):
|
||||||
a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
|
a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
|
||||||
fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
|
fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
|
||||||
cb = Mock()
|
cb = test_utils.MockCallback()
|
||||||
fut.add_done_callback(cb)
|
fut.add_done_callback(cb)
|
||||||
b.set_result(1)
|
b.set_result(1)
|
||||||
a.set_result(2)
|
a.set_result(2)
|
||||||
|
@ -1380,7 +1378,7 @@ class GatherTestsBase:
|
||||||
def test_one_exception(self):
|
def test_one_exception(self):
|
||||||
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
|
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
|
||||||
fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
|
fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
|
||||||
cb = Mock()
|
cb = test_utils.MockCallback()
|
||||||
fut.add_done_callback(cb)
|
fut.add_done_callback(cb)
|
||||||
exc = ZeroDivisionError()
|
exc = ZeroDivisionError()
|
||||||
a.set_result(1)
|
a.set_result(1)
|
||||||
|
@ -1399,7 +1397,7 @@ class GatherTestsBase:
|
||||||
a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
|
a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
|
||||||
fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
|
fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
|
||||||
return_exceptions=True)
|
return_exceptions=True)
|
||||||
cb = Mock()
|
cb = test_utils.MockCallback()
|
||||||
fut.add_done_callback(cb)
|
fut.add_done_callback(cb)
|
||||||
exc = ZeroDivisionError()
|
exc = ZeroDivisionError()
|
||||||
exc2 = RuntimeError()
|
exc2 = RuntimeError()
|
||||||
|
@ -1460,7 +1458,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
|
||||||
def test_one_cancellation(self):
|
def test_one_cancellation(self):
|
||||||
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
|
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
|
||||||
fut = asyncio.gather(a, b, c, d, e)
|
fut = asyncio.gather(a, b, c, d, e)
|
||||||
cb = Mock()
|
cb = test_utils.MockCallback()
|
||||||
fut.add_done_callback(cb)
|
fut.add_done_callback(cb)
|
||||||
a.set_result(1)
|
a.set_result(1)
|
||||||
b.cancel()
|
b.cancel()
|
||||||
|
@ -1479,7 +1477,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
|
||||||
a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
|
a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
|
||||||
for i in range(6)]
|
for i in range(6)]
|
||||||
fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
|
fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
|
||||||
cb = Mock()
|
cb = test_utils.MockCallback()
|
||||||
fut.add_done_callback(cb)
|
fut.add_done_callback(cb)
|
||||||
a.set_result(1)
|
a.set_result(1)
|
||||||
zde = ZeroDivisionError()
|
zde = ZeroDivisionError()
|
||||||
|
|
Loading…
Reference in New Issue