diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index ed170622144..0f533a5e590 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -38,7 +38,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, self._server._attach() self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 24f8461509a..f4996293621 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -578,10 +578,12 @@ class _SelectorSocketTransport(_SelectorTransport): self._eof = False self._paused = False - self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._sock_fd, self._read_ready) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def pause_reading(self): @@ -732,6 +734,16 @@ class _SelectorSslTransport(_SelectorTransport): start_time = None self._on_handshake(start_time) + def _wakeup_waiter(self, exc=None): + if self._waiter is None: + return + if not self._waiter.cancelled(): + if exc is not None: + self._waiter.set_exception(exc) + else: + self._waiter.set_result(None) + self._waiter = None + def _on_handshake(self, start_time): try: self._sock.do_handshake() @@ -750,8 +762,7 @@ class _SelectorSslTransport(_SelectorTransport): self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) self._sock.close() - if self._waiter is not None and not self._waiter.cancelled(): - self._waiter.set_exception(exc) + self._wakeup_waiter(exc) if isinstance(exc, Exception): return else: @@ -774,9 +785,7 @@ class _SelectorSslTransport(_SelectorTransport): "on matching the hostname", self, exc_info=True) self._sock.close() - if (self._waiter is not None - and not self._waiter.cancelled()): - self._waiter.set_exception(exc) + self._wakeup_waiter(exc) return # Add extra info that becomes available after handshake. @@ -789,10 +798,8 @@ class _SelectorSslTransport(_SelectorTransport): self._write_wants_read = False self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) - if self._waiter is not None: - # wait until protocol.connection_made() has been called - self._loop.call_soon(self._waiter._set_result_unless_cancelled, - None) + # only wake up the waiter when connection_made() has been called + self._loop.call_soon(self._wakeup_waiter) if self._loop.get_debug(): dt = self._loop.time() - start_time @@ -924,7 +931,7 @@ class _SelectorDatagramTransport(_SelectorTransport): self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def get_write_buffer_size(self): diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index f2b856c40cb..fc809b9831d 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -418,6 +418,16 @@ class SSLProtocol(protocols.Protocol): self._in_shutdown = False self._transport = None + def _wakeup_waiter(self, exc=None): + if self._waiter is None: + return + if not self._waiter.cancelled(): + if exc is not None: + self._waiter.set_exception(exc) + else: + self._waiter.set_result(None) + self._waiter = None + def connection_made(self, transport): """Called when the low-level connection is made. @@ -489,6 +499,9 @@ class SSLProtocol(protocols.Protocol): try: if self._loop.get_debug(): logger.debug("%r received EOF", self) + + self._wakeup_waiter(ConnectionResetError) + if not self._in_handshake: keep_open = self._app_protocol.eof_received() if keep_open: @@ -552,8 +565,7 @@ class SSLProtocol(protocols.Protocol): self, exc_info=True) self._transport.close() if isinstance(exc, Exception): - if self._waiter is not None and not self._waiter.cancelled(): - self._waiter.set_exception(exc) + self._wakeup_waiter(exc) return else: raise @@ -568,9 +580,7 @@ class SSLProtocol(protocols.Protocol): compression=sslobj.compression(), ) self._app_protocol.connection_made(self._app_transport) - if self._waiter is not None: - # wait until protocol.connection_made() has been called - self._waiter._set_result_unless_cancelled(None) + self._wakeup_waiter() self._session_established = True # In case transport.write() was already called. Don't call # immediatly _process_write_backlog(), but schedule it: diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index 97f9addde88..67973f14f3f 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -301,7 +301,7 @@ class _UnixReadPipeTransport(transports.ReadTransport): self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): @@ -409,7 +409,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index ad86ada3425..51526163953 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -59,6 +59,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): def test_make_socket_transport(self): m = mock.Mock() self.loop.add_reader = mock.Mock() + self.loop.add_reader._is_coroutine = False transport = self.loop._make_socket_transport(m, asyncio.Protocol()) self.assertIsInstance(transport, _SelectorSocketTransport) close_transport(transport) @@ -67,6 +68,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): def test_make_ssl_transport(self): m = mock.Mock() self.loop.add_reader = mock.Mock() + self.loop.add_reader._is_coroutine = False self.loop.add_writer = mock.Mock() self.loop.remove_reader = mock.Mock() self.loop.remove_writer = mock.Mock() @@ -770,20 +772,24 @@ class SelectorSocketTransportTests(test_utils.TestCase): return transport def test_ctor(self): - tr = self.socket_transport() + waiter = asyncio.Future(loop=self.loop) + tr = self.socket_transport(waiter=waiter) + self.loop.run_until_complete(waiter) + self.loop.assert_reader(7, tr._read_ready) test_utils.run_briefly(self.loop) self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): - fut = asyncio.Future(loop=self.loop) + waiter = asyncio.Future(loop=self.loop) + self.socket_transport(waiter=waiter) + self.loop.run_until_complete(waiter) - self.socket_transport(waiter=fut) - test_utils.run_briefly(self.loop) - self.assertIsNone(fut.result()) + self.assertIsNone(waiter.result()) def test_pause_resume_reading(self): tr = self.socket_transport() + test_utils.run_briefly(self.loop) self.assertFalse(tr._paused) self.loop.assert_reader(7, tr._read_ready) tr.pause_reading() diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index b1a61c483d2..148e30dffeb 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -12,21 +12,36 @@ from asyncio import sslproto from asyncio import test_utils +@unittest.skipIf(ssl is None, 'No ssl module') class SslProtoHandshakeTests(test_utils.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() self.set_event_loop(self.loop) - @unittest.skipIf(ssl is None, 'No ssl module') + def ssl_protocol(self, waiter=None): + sslcontext = test_utils.dummy_ssl_context() + app_proto = asyncio.Protocol() + return sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter) + + def connection_made(self, ssl_proto, do_handshake=None): + transport = mock.Mock() + sslpipe = mock.Mock() + sslpipe.shutdown.return_value = b'' + if do_handshake: + sslpipe.do_handshake.side_effect = do_handshake + else: + def mock_handshake(callback): + return [] + sslpipe.do_handshake.side_effect = mock_handshake + with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): + ssl_proto.connection_made(transport) + def test_cancel_handshake(self): # Python issue #23197: cancelling an handshake must not raise an # exception or log an error, even if the handshake failed - sslcontext = test_utils.dummy_ssl_context() - app_proto = asyncio.Protocol() waiter = asyncio.Future(loop=self.loop) - ssl_proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, - waiter) + ssl_proto = self.ssl_protocol(waiter) handshake_fut = asyncio.Future(loop=self.loop) def do_handshake(callback): @@ -36,12 +51,7 @@ class SslProtoHandshakeTests(test_utils.TestCase): return [] waiter.cancel() - transport = mock.Mock() - sslpipe = mock.Mock() - sslpipe.shutdown.return_value = b'' - sslpipe.do_handshake.side_effect = do_handshake - with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): - ssl_proto.connection_made(transport) + self.connection_made(ssl_proto, do_handshake) with test_utils.disable_logger(): self.loop.run_until_complete(handshake_fut) @@ -49,6 +59,14 @@ class SslProtoHandshakeTests(test_utils.TestCase): # Close the transport ssl_proto._app_transport.close() + def test_eof_received_waiter(self): + waiter = asyncio.Future(loop=self.loop) + ssl_proto = self.ssl_protocol(waiter) + self.connection_made(ssl_proto) + ssl_proto.eof_received() + test_utils.run_briefly(self.loop) + self.assertIsInstance(waiter.exception(), ConnectionResetError) + if __name__ == '__main__': unittest.main()