Merge 3.4 (asyncio)

This commit is contained in:
Victor Stinner 2015-01-29 00:55:46 +01:00
commit bea0a28439
6 changed files with 77 additions and 36 deletions

View File

@ -38,7 +38,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
self._server._attach() self._server._attach()
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def __repr__(self): def __repr__(self):

View File

@ -578,10 +578,12 @@ class _SelectorSocketTransport(_SelectorTransport):
self._eof = False self._eof = False
self._paused = False self._paused = False
self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
self._sock_fd, self._read_ready)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def pause_reading(self): def pause_reading(self):
@ -732,6 +734,16 @@ class _SelectorSslTransport(_SelectorTransport):
start_time = None start_time = None
self._on_handshake(start_time) self._on_handshake(start_time)
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
return
if not self._waiter.cancelled():
if exc is not None:
self._waiter.set_exception(exc)
else:
self._waiter.set_result(None)
self._waiter = None
def _on_handshake(self, start_time): def _on_handshake(self, start_time):
try: try:
self._sock.do_handshake() self._sock.do_handshake()
@ -750,8 +762,7 @@ class _SelectorSslTransport(_SelectorTransport):
self._loop.remove_reader(self._sock_fd) self._loop.remove_reader(self._sock_fd)
self._loop.remove_writer(self._sock_fd) self._loop.remove_writer(self._sock_fd)
self._sock.close() self._sock.close()
if self._waiter is not None and not self._waiter.cancelled(): self._wakeup_waiter(exc)
self._waiter.set_exception(exc)
if isinstance(exc, Exception): if isinstance(exc, Exception):
return return
else: else:
@ -774,9 +785,7 @@ class _SelectorSslTransport(_SelectorTransport):
"on matching the hostname", "on matching the hostname",
self, exc_info=True) self, exc_info=True)
self._sock.close() self._sock.close()
if (self._waiter is not None self._wakeup_waiter(exc)
and not self._waiter.cancelled()):
self._waiter.set_exception(exc)
return return
# Add extra info that becomes available after handshake. # Add extra info that becomes available after handshake.
@ -789,10 +798,8 @@ class _SelectorSslTransport(_SelectorTransport):
self._write_wants_read = False self._write_wants_read = False
self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if self._waiter is not None: # only wake up the waiter when connection_made() has been called
# wait until protocol.connection_made() has been called self._loop.call_soon(self._wakeup_waiter)
self._loop.call_soon(self._waiter._set_result_unless_cancelled,
None)
if self._loop.get_debug(): if self._loop.get_debug():
dt = self._loop.time() - start_time dt = self._loop.time() - start_time
@ -924,7 +931,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def get_write_buffer_size(self): def get_write_buffer_size(self):

View File

@ -418,6 +418,16 @@ class SSLProtocol(protocols.Protocol):
self._in_shutdown = False self._in_shutdown = False
self._transport = None self._transport = None
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
return
if not self._waiter.cancelled():
if exc is not None:
self._waiter.set_exception(exc)
else:
self._waiter.set_result(None)
self._waiter = None
def connection_made(self, transport): def connection_made(self, transport):
"""Called when the low-level connection is made. """Called when the low-level connection is made.
@ -489,6 +499,9 @@ class SSLProtocol(protocols.Protocol):
try: try:
if self._loop.get_debug(): if self._loop.get_debug():
logger.debug("%r received EOF", self) logger.debug("%r received EOF", self)
self._wakeup_waiter(ConnectionResetError)
if not self._in_handshake: if not self._in_handshake:
keep_open = self._app_protocol.eof_received() keep_open = self._app_protocol.eof_received()
if keep_open: if keep_open:
@ -552,8 +565,7 @@ class SSLProtocol(protocols.Protocol):
self, exc_info=True) self, exc_info=True)
self._transport.close() self._transport.close()
if isinstance(exc, Exception): if isinstance(exc, Exception):
if self._waiter is not None and not self._waiter.cancelled(): self._wakeup_waiter(exc)
self._waiter.set_exception(exc)
return return
else: else:
raise raise
@ -568,9 +580,7 @@ class SSLProtocol(protocols.Protocol):
compression=sslobj.compression(), compression=sslobj.compression(),
) )
self._app_protocol.connection_made(self._app_transport) self._app_protocol.connection_made(self._app_transport)
if self._waiter is not None: self._wakeup_waiter()
# wait until protocol.connection_made() has been called
self._waiter._set_result_unless_cancelled(None)
self._session_established = True self._session_established = True
# In case transport.write() was already called. Don't call # In case transport.write() was already called. Don't call
# immediatly _process_write_backlog(), but schedule it: # immediatly _process_write_backlog(), but schedule it:

View File

@ -301,7 +301,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
self._loop.add_reader(self._fileno, self._read_ready) self._loop.add_reader(self._fileno, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def __repr__(self): def __repr__(self):
@ -409,7 +409,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
self._loop.call_soon(self._protocol.connection_made, self) self._loop.call_soon(self._protocol.connection_made, self)
if waiter is not None: if waiter is not None:
# wait until protocol.connection_made() has been called # only wake up the waiter when connection_made() has been called
self._loop.call_soon(waiter._set_result_unless_cancelled, None) self._loop.call_soon(waiter._set_result_unless_cancelled, None)
def __repr__(self): def __repr__(self):

View File

@ -59,6 +59,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
def test_make_socket_transport(self): def test_make_socket_transport(self):
m = mock.Mock() m = mock.Mock()
self.loop.add_reader = mock.Mock() self.loop.add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False
transport = self.loop._make_socket_transport(m, asyncio.Protocol()) transport = self.loop._make_socket_transport(m, asyncio.Protocol())
self.assertIsInstance(transport, _SelectorSocketTransport) self.assertIsInstance(transport, _SelectorSocketTransport)
close_transport(transport) close_transport(transport)
@ -67,6 +68,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
def test_make_ssl_transport(self): def test_make_ssl_transport(self):
m = mock.Mock() m = mock.Mock()
self.loop.add_reader = mock.Mock() self.loop.add_reader = mock.Mock()
self.loop.add_reader._is_coroutine = False
self.loop.add_writer = mock.Mock() self.loop.add_writer = mock.Mock()
self.loop.remove_reader = mock.Mock() self.loop.remove_reader = mock.Mock()
self.loop.remove_writer = mock.Mock() self.loop.remove_writer = mock.Mock()
@ -770,20 +772,24 @@ class SelectorSocketTransportTests(test_utils.TestCase):
return transport return transport
def test_ctor(self): def test_ctor(self):
tr = self.socket_transport() waiter = asyncio.Future(loop=self.loop)
tr = self.socket_transport(waiter=waiter)
self.loop.run_until_complete(waiter)
self.loop.assert_reader(7, tr._read_ready) self.loop.assert_reader(7, tr._read_ready)
test_utils.run_briefly(self.loop) test_utils.run_briefly(self.loop)
self.protocol.connection_made.assert_called_with(tr) self.protocol.connection_made.assert_called_with(tr)
def test_ctor_with_waiter(self): def test_ctor_with_waiter(self):
fut = asyncio.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
self.socket_transport(waiter=waiter)
self.loop.run_until_complete(waiter)
self.socket_transport(waiter=fut) self.assertIsNone(waiter.result())
test_utils.run_briefly(self.loop)
self.assertIsNone(fut.result())
def test_pause_resume_reading(self): def test_pause_resume_reading(self):
tr = self.socket_transport() tr = self.socket_transport()
test_utils.run_briefly(self.loop)
self.assertFalse(tr._paused) self.assertFalse(tr._paused)
self.loop.assert_reader(7, tr._read_ready) self.loop.assert_reader(7, tr._read_ready)
tr.pause_reading() tr.pause_reading()

View File

@ -12,21 +12,36 @@ from asyncio import sslproto
from asyncio import test_utils from asyncio import test_utils
@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase): class SslProtoHandshakeTests(test_utils.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop) self.set_event_loop(self.loop)
@unittest.skipIf(ssl is None, 'No ssl module') def ssl_protocol(self, waiter=None):
sslcontext = test_utils.dummy_ssl_context()
app_proto = asyncio.Protocol()
return sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter)
def connection_made(self, ssl_proto, do_handshake=None):
transport = mock.Mock()
sslpipe = mock.Mock()
sslpipe.shutdown.return_value = b''
if do_handshake:
sslpipe.do_handshake.side_effect = do_handshake
else:
def mock_handshake(callback):
return []
sslpipe.do_handshake.side_effect = mock_handshake
with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
ssl_proto.connection_made(transport)
def test_cancel_handshake(self): def test_cancel_handshake(self):
# Python issue #23197: cancelling an handshake must not raise an # Python issue #23197: cancelling an handshake must not raise an
# exception or log an error, even if the handshake failed # exception or log an error, even if the handshake failed
sslcontext = test_utils.dummy_ssl_context()
app_proto = asyncio.Protocol()
waiter = asyncio.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
ssl_proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, ssl_proto = self.ssl_protocol(waiter)
waiter)
handshake_fut = asyncio.Future(loop=self.loop) handshake_fut = asyncio.Future(loop=self.loop)
def do_handshake(callback): def do_handshake(callback):
@ -36,12 +51,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
return [] return []
waiter.cancel() waiter.cancel()
transport = mock.Mock() self.connection_made(ssl_proto, do_handshake)
sslpipe = mock.Mock()
sslpipe.shutdown.return_value = b''
sslpipe.do_handshake.side_effect = do_handshake
with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
ssl_proto.connection_made(transport)
with test_utils.disable_logger(): with test_utils.disable_logger():
self.loop.run_until_complete(handshake_fut) self.loop.run_until_complete(handshake_fut)
@ -49,6 +59,14 @@ class SslProtoHandshakeTests(test_utils.TestCase):
# Close the transport # Close the transport
ssl_proto._app_transport.close() ssl_proto._app_transport.close()
def test_eof_received_waiter(self):
waiter = asyncio.Future(loop=self.loop)
ssl_proto = self.ssl_protocol(waiter)
self.connection_made(ssl_proto)
ssl_proto.eof_received()
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionResetError)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()