From 51eb1c6b9c0b382dfd6e0428eacff0c7891a6fc3 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 20 Dec 2017 20:24:43 +0200 Subject: [PATCH] bpo-29970: Make ssh_handshake_timeout None by default (#4939) * Make ssh_handshake_timeout None by default. * Raise ValueError if ssl_handshake_timeout is used without ssl. * Raise ValueError if ssl_handshake_timeout is not positive. --- Doc/library/asyncio-eventloop.rst | 13 +++++++----- Lib/asyncio/base_events.py | 24 +++++++++++++++++------ Lib/asyncio/events.py | 9 ++++----- Lib/asyncio/proactor_events.py | 4 ++-- Lib/asyncio/sslproto.py | 9 ++++++++- Lib/asyncio/unix_events.py | 11 +++++++++-- Lib/test/test_asyncio/test_base_events.py | 16 +++++++++++++++ Lib/test/test_asyncio/test_events.py | 10 ++++++++++ Lib/test/test_asyncio/test_sslproto.py | 16 +++++++++++++++ Lib/test/test_asyncio/test_unix_events.py | 17 ++++++++++++++++ 10 files changed, 108 insertions(+), 21 deletions(-) diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index d20e995d353..5dd258df312 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -261,7 +261,7 @@ Tasks Creating connections -------------------- -.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=10.0) +.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=None) Create a streaming transport connection to a given Internet *host* and *port*: socket family :py:data:`~socket.AF_INET` or @@ -327,6 +327,7 @@ Creating connections * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to wait for the SSL handshake to complete before aborting the connection. + ``10.0`` seconds if ``None`` (default). .. versionadded:: 3.7 @@ -393,7 +394,7 @@ Creating connections :ref:`UDP echo server protocol ` examples. -.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None, ssl_handshake_timeout=10.0) +.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None, ssl_handshake_timeout=None) Create UNIX connection: socket family :py:data:`~socket.AF_UNIX`, socket type :py:data:`~socket.SOCK_STREAM`. The :py:data:`~socket.AF_UNIX` socket @@ -423,7 +424,7 @@ Creating connections Creating listening connections ------------------------------ -.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=10.0) +.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=None) Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to *host* and *port*. @@ -469,6 +470,7 @@ Creating listening connections * *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait for the SSL handshake to complete before aborting the connection. + ``10.0`` seconds if ``None`` (default). .. versionadded:: 3.7 @@ -488,7 +490,7 @@ Creating listening connections The *host* parameter can now be a sequence of strings. -.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=10.0) +.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=None) Similar to :meth:`AbstractEventLoop.create_server`, but specific to the socket family :py:data:`~socket.AF_UNIX`. @@ -507,7 +509,7 @@ Creating listening connections The *path* parameter can now be a :class:`~pathlib.Path` object. -.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None, ssl_handshake_timeout=10.0) +.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None, ssl_handshake_timeout=None) Handle an accepted connection. @@ -524,6 +526,7 @@ Creating listening connections * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to wait for the SSL handshake to complete before aborting the connection. + ``10.0`` seconds if ``None`` (default). When completed it returns a ``(transport, protocol)`` pair. diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 6246f4e221c..2ab8a76e0c5 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -29,7 +29,6 @@ import sys import warnings import weakref -from . import constants from . import coroutines from . import events from . import futures @@ -280,7 +279,7 @@ class BaseEventLoop(events.AbstractEventLoop): self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): """Create SSL transport.""" raise NotImplementedError @@ -643,7 +642,7 @@ class BaseEventLoop(events.AbstractEventLoop): *, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): """Connect to a TCP server. Create a streaming transport connection to a given Internet host and @@ -674,6 +673,10 @@ class BaseEventLoop(events.AbstractEventLoop): 'when using ssl without a host') server_hostname = host + if ssl_handshake_timeout is not None and not ssl: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -769,7 +772,7 @@ class BaseEventLoop(events.AbstractEventLoop): async def _create_connection_transport( self, sock, protocol_factory, ssl, server_hostname, server_side=False, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): sock.setblocking(False) @@ -948,7 +951,7 @@ class BaseEventLoop(events.AbstractEventLoop): ssl=None, reuse_address=None, reuse_port=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): """Create a TCP server. The host parameter can be a string, in that case the TCP server is @@ -966,6 +969,11 @@ class BaseEventLoop(events.AbstractEventLoop): """ if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') + + if ssl_handshake_timeout is not None and ssl is None: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -1046,7 +1054,7 @@ class BaseEventLoop(events.AbstractEventLoop): async def connect_accepted_socket( self, protocol_factory, sock, *, ssl=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): """Handle an accepted connection. This is used by servers that accept connections outside of @@ -1058,6 +1066,10 @@ class BaseEventLoop(events.AbstractEventLoop): if sock.type != socket.SOCK_STREAM: raise ValueError(f'A Stream Socket was expected, got {sock!r}') + if ssl_handshake_timeout is not None and not ssl: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + transport, protocol = await self._create_connection_transport( sock, protocol_factory, ssl, '', server_side=True, ssl_handshake_timeout=ssl_handshake_timeout) diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index c9033c020f3..3a5dbadbb10 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -17,7 +17,6 @@ import subprocess import sys import threading -from . import constants from . import format_helpers @@ -255,7 +254,7 @@ class AbstractEventLoop: *, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): raise NotImplementedError async def create_server( @@ -263,7 +262,7 @@ class AbstractEventLoop: *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): """A coroutine which creates a TCP server bound to host and port. The return value is a Server object which can be used to stop @@ -310,13 +309,13 @@ class AbstractEventLoop: self, protocol_factory, path=None, *, ssl=None, sock=None, server_hostname=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): raise NotImplementedError async def create_unix_server( self, protocol_factory, path=None, *, sock=None, backlog=100, ssl=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): """A coroutine which creates a UNIX Domain Socket server. The return value is a Server object, which can be used to stop diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index bc319b06ed6..2661cddef73 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -393,7 +393,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): ssl_protocol = sslproto.SSLProtocol( self, protocol, sslcontext, waiter, server_side, server_hostname, @@ -491,7 +491,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): def loop(f=None): try: diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 8bcc6cc0433..2d377c4ae39 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -402,10 +402,17 @@ class SSLProtocol(protocols.Protocol): def __init__(self, loop, app_protocol, sslcontext, waiter, server_side=False, server_hostname=None, call_connection_made=True, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): if ssl is None: raise RuntimeError('stdlib ssl module not available') + if ssl_handshake_timeout is None: + ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT + elif ssl_handshake_timeout <= 0: + raise ValueError( + f"ssl_handshake_timeout should be a positive number, " + f"got {ssl_handshake_timeout}") + if not sslcontext: sslcontext = _create_transport_context( server_side, server_hostname) diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index e2344582268..69c719c3239 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -196,7 +196,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): self, protocol_factory, path=None, *, ssl=None, sock=None, server_hostname=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): assert server_hostname is None or isinstance(server_hostname, str) if ssl: if server_hostname is None: @@ -205,6 +205,9 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): else: if server_hostname is not None: raise ValueError('server_hostname is only meaningful with ssl') + if ssl_handshake_timeout is not None: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') if path is not None: if sock is not None: @@ -237,10 +240,14 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): async def create_unix_server( self, protocol_factory, path=None, *, sock=None, backlog=100, ssl=None, - ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_handshake_timeout=None): if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') + if ssl_handshake_timeout is not None and not ssl: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + if path is not None: if sock is not None: raise ValueError( diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 488257b341f..1fc74737026 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -1053,6 +1053,14 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): 'A Stream Socket was expected'): self.loop.run_until_complete(coro) + def test_create_server_ssl_timeout_for_plain_socket(self): + coro = self.loop.create_server( + MyProto, 'example.com', 80, ssl_handshake_timeout=1) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), 'no socket.SOCK_NONBLOCK (linux only)') def test_create_server_stream_bittype(self): @@ -1362,6 +1370,14 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): self.addCleanup(sock.close) self.assertRaises(ValueError, self.loop.run_until_complete, coro) + def test_create_connection_ssl_timeout_for_plain_socket(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, ssl_handshake_timeout=1) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + def test_create_server_empty_host(self): # if host is empty string use None instead host = object() diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 58e94d46da8..aefa33ba251 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -852,6 +852,16 @@ class EventLoopTestsMixin: self.test_connect_accepted_socket(server_context, client_context) + def test_connect_accepted_socket_ssl_timeout_for_plain_socket(self): + sock = socket.socket() + self.addCleanup(sock.close) + coro = self.loop.connect_accepted_socket( + MyProto, sock, ssl_handshake_timeout=1) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + @mock.patch('asyncio.base_events.socket') def create_server_multiple_hosts(self, family, hosts, mock_sock): @asyncio.coroutine diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index 1c42a35128f..a7498e85c25 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -75,6 +75,22 @@ class SslProtoHandshakeTests(test_utils.TestCase): self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop)) self.assertTrue(transport.abort.called) + def test_handshake_timeout_zero(self): + sslcontext = test_utils.dummy_ssl_context() + app_proto = mock.Mock() + waiter = mock.Mock() + with self.assertRaisesRegex(ValueError, 'a positive number'): + sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter, + ssl_handshake_timeout=0) + + def test_handshake_timeout_negative(self): + sslcontext = test_utils.dummy_ssl_context() + app_proto = mock.Mock() + waiter = mock.Mock() + with self.assertRaisesRegex(ValueError, 'a positive number'): + sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter, + ssl_handshake_timeout=-10) + def test_eof_received_waiter(self): waiter = asyncio.Future(loop=self.loop) ssl_proto = self.ssl_protocol(waiter) diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py index e13cc3749da..097d0ef73dc 100644 --- a/Lib/test/test_asyncio/test_unix_events.py +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -327,6 +327,14 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase): finally: os.unlink(fn) + def test_create_unix_server_ssl_timeout_with_plain_sock(self): + coro = self.loop.create_unix_server(lambda: None, path='spam', + ssl_handshake_timeout=1) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + def test_create_unix_connection_path_inetsock(self): sock = socket.socket() with sock: @@ -383,6 +391,15 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase): self.loop.run_until_complete(coro) + def test_create_unix_connection_ssl_timeout_with_plain_sock(self): + coro = self.loop.create_unix_connection(lambda: None, path='spam', + ssl_handshake_timeout=1) + with self.assertRaisesRegex( + ValueError, + 'ssl_handshake_timeout is only meaningful with ssl'): + self.loop.run_until_complete(coro) + + class UnixReadPipeTransportTests(test_utils.TestCase):