bpo-29970: Make ssh_handshake_timeout None by default (#4939)

* Make ssh_handshake_timeout None by default.
* Raise ValueError if ssl_handshake_timeout is used without ssl.
* Raise ValueError if ssl_handshake_timeout is not positive.
This commit is contained in:
Andrew Svetlov 2017-12-20 20:24:43 +02:00 committed by GitHub
parent a7a751dd7b
commit 51eb1c6b9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 108 additions and 21 deletions

View File

@ -261,7 +261,7 @@ Tasks
Creating connections Creating connections
-------------------- --------------------
.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=10.0) .. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=None)
Create a streaming transport connection to a given Internet *host* and Create a streaming transport connection to a given Internet *host* and
*port*: socket family :py:data:`~socket.AF_INET` or *port*: socket family :py:data:`~socket.AF_INET` or
@ -327,6 +327,7 @@ Creating connections
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds
to wait for the SSL handshake to complete before aborting the connection. to wait for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default).
.. versionadded:: 3.7 .. versionadded:: 3.7
@ -393,7 +394,7 @@ Creating connections
:ref:`UDP echo server protocol <asyncio-udp-echo-server-protocol>` examples. :ref:`UDP echo server protocol <asyncio-udp-echo-server-protocol>` examples.
.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None, ssl_handshake_timeout=10.0) .. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None, ssl_handshake_timeout=None)
Create UNIX connection: socket family :py:data:`~socket.AF_UNIX`, socket Create UNIX connection: socket family :py:data:`~socket.AF_UNIX`, socket
type :py:data:`~socket.SOCK_STREAM`. The :py:data:`~socket.AF_UNIX` socket type :py:data:`~socket.SOCK_STREAM`. The :py:data:`~socket.AF_UNIX` socket
@ -423,7 +424,7 @@ Creating connections
Creating listening connections Creating listening connections
------------------------------ ------------------------------
.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=10.0) .. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=None)
Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to
*host* and *port*. *host* and *port*.
@ -469,6 +470,7 @@ Creating listening connections
* *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait * *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait
for the SSL handshake to complete before aborting the connection. for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default).
.. versionadded:: 3.7 .. versionadded:: 3.7
@ -488,7 +490,7 @@ Creating listening connections
The *host* parameter can now be a sequence of strings. The *host* parameter can now be a sequence of strings.
.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=10.0) .. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=None)
Similar to :meth:`AbstractEventLoop.create_server`, but specific to the Similar to :meth:`AbstractEventLoop.create_server`, but specific to the
socket family :py:data:`~socket.AF_UNIX`. socket family :py:data:`~socket.AF_UNIX`.
@ -507,7 +509,7 @@ Creating listening connections
The *path* parameter can now be a :class:`~pathlib.Path` object. The *path* parameter can now be a :class:`~pathlib.Path` object.
.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None, ssl_handshake_timeout=10.0) .. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None, ssl_handshake_timeout=None)
Handle an accepted connection. Handle an accepted connection.
@ -524,6 +526,7 @@ Creating listening connections
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
wait for the SSL handshake to complete before aborting the connection. wait for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default).
When completed it returns a ``(transport, protocol)`` pair. When completed it returns a ``(transport, protocol)`` pair.

View File

@ -29,7 +29,6 @@ import sys
import warnings import warnings
import weakref import weakref
from . import constants
from . import coroutines from . import coroutines
from . import events from . import events
from . import futures from . import futures
@ -280,7 +279,7 @@ class BaseEventLoop(events.AbstractEventLoop):
self, rawsock, protocol, sslcontext, waiter=None, self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None, *, server_side=False, server_hostname=None,
extra=None, server=None, extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
"""Create SSL transport.""" """Create SSL transport."""
raise NotImplementedError raise NotImplementedError
@ -643,7 +642,7 @@ class BaseEventLoop(events.AbstractEventLoop):
*, ssl=None, family=0, *, ssl=None, family=0,
proto=0, flags=0, sock=None, proto=0, flags=0, sock=None,
local_addr=None, server_hostname=None, local_addr=None, server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
"""Connect to a TCP server. """Connect to a TCP server.
Create a streaming transport connection to a given Internet host and Create a streaming transport connection to a given Internet host and
@ -674,6 +673,10 @@ class BaseEventLoop(events.AbstractEventLoop):
'when using ssl without a host') 'when using ssl without a host')
server_hostname = host server_hostname = host
if ssl_handshake_timeout is not None and not ssl:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if host is not None or port is not None: if host is not None or port is not None:
if sock is not None: if sock is not None:
raise ValueError( raise ValueError(
@ -769,7 +772,7 @@ class BaseEventLoop(events.AbstractEventLoop):
async def _create_connection_transport( async def _create_connection_transport(
self, sock, protocol_factory, ssl, self, sock, protocol_factory, ssl,
server_hostname, server_side=False, server_hostname, server_side=False,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
sock.setblocking(False) sock.setblocking(False)
@ -948,7 +951,7 @@ class BaseEventLoop(events.AbstractEventLoop):
ssl=None, ssl=None,
reuse_address=None, reuse_address=None,
reuse_port=None, reuse_port=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
"""Create a TCP server. """Create a TCP server.
The host parameter can be a string, in that case the TCP server is The host parameter can be a string, in that case the TCP server is
@ -966,6 +969,11 @@ class BaseEventLoop(events.AbstractEventLoop):
""" """
if isinstance(ssl, bool): if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None') raise TypeError('ssl argument must be an SSLContext or None')
if ssl_handshake_timeout is not None and ssl is None:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if host is not None or port is not None: if host is not None or port is not None:
if sock is not None: if sock is not None:
raise ValueError( raise ValueError(
@ -1046,7 +1054,7 @@ class BaseEventLoop(events.AbstractEventLoop):
async def connect_accepted_socket( async def connect_accepted_socket(
self, protocol_factory, sock, self, protocol_factory, sock,
*, ssl=None, *, ssl=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
"""Handle an accepted connection. """Handle an accepted connection.
This is used by servers that accept connections outside of This is used by servers that accept connections outside of
@ -1058,6 +1066,10 @@ class BaseEventLoop(events.AbstractEventLoop):
if sock.type != socket.SOCK_STREAM: if sock.type != socket.SOCK_STREAM:
raise ValueError(f'A Stream Socket was expected, got {sock!r}') raise ValueError(f'A Stream Socket was expected, got {sock!r}')
if ssl_handshake_timeout is not None and not ssl:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
transport, protocol = await self._create_connection_transport( transport, protocol = await self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True, sock, protocol_factory, ssl, '', server_side=True,
ssl_handshake_timeout=ssl_handshake_timeout) ssl_handshake_timeout=ssl_handshake_timeout)

View File

@ -17,7 +17,6 @@ import subprocess
import sys import sys
import threading import threading
from . import constants
from . import format_helpers from . import format_helpers
@ -255,7 +254,7 @@ class AbstractEventLoop:
*, ssl=None, family=0, proto=0, *, ssl=None, family=0, proto=0,
flags=0, sock=None, local_addr=None, flags=0, sock=None, local_addr=None,
server_hostname=None, server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
raise NotImplementedError raise NotImplementedError
async def create_server( async def create_server(
@ -263,7 +262,7 @@ class AbstractEventLoop:
*, family=socket.AF_UNSPEC, *, family=socket.AF_UNSPEC,
flags=socket.AI_PASSIVE, sock=None, backlog=100, flags=socket.AI_PASSIVE, sock=None, backlog=100,
ssl=None, reuse_address=None, reuse_port=None, ssl=None, reuse_address=None, reuse_port=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
"""A coroutine which creates a TCP server bound to host and port. """A coroutine which creates a TCP server bound to host and port.
The return value is a Server object which can be used to stop The return value is a Server object which can be used to stop
@ -310,13 +309,13 @@ class AbstractEventLoop:
self, protocol_factory, path=None, *, self, protocol_factory, path=None, *,
ssl=None, sock=None, ssl=None, sock=None,
server_hostname=None, server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
raise NotImplementedError raise NotImplementedError
async def create_unix_server( async def create_unix_server(
self, protocol_factory, path=None, *, self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None, sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
"""A coroutine which creates a UNIX Domain Socket server. """A coroutine which creates a UNIX Domain Socket server.
The return value is a Server object, which can be used to stop The return value is a Server object, which can be used to stop

View File

@ -393,7 +393,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
self, rawsock, protocol, sslcontext, waiter=None, self, rawsock, protocol, sslcontext, waiter=None,
*, server_side=False, server_hostname=None, *, server_side=False, server_hostname=None,
extra=None, server=None, extra=None, server=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
ssl_protocol = sslproto.SSLProtocol( ssl_protocol = sslproto.SSLProtocol(
self, protocol, sslcontext, waiter, self, protocol, sslcontext, waiter,
server_side, server_hostname, server_side, server_hostname,
@ -491,7 +491,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
def _start_serving(self, protocol_factory, sock, def _start_serving(self, protocol_factory, sock,
sslcontext=None, server=None, backlog=100, sslcontext=None, server=None, backlog=100,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
def loop(f=None): def loop(f=None):
try: try:

View File

@ -402,10 +402,17 @@ class SSLProtocol(protocols.Protocol):
def __init__(self, loop, app_protocol, sslcontext, waiter, def __init__(self, loop, app_protocol, sslcontext, waiter,
server_side=False, server_hostname=None, server_side=False, server_hostname=None,
call_connection_made=True, call_connection_made=True,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
if ssl is None: if ssl is None:
raise RuntimeError('stdlib ssl module not available') raise RuntimeError('stdlib ssl module not available')
if ssl_handshake_timeout is None:
ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
elif ssl_handshake_timeout <= 0:
raise ValueError(
f"ssl_handshake_timeout should be a positive number, "
f"got {ssl_handshake_timeout}")
if not sslcontext: if not sslcontext:
sslcontext = _create_transport_context( sslcontext = _create_transport_context(
server_side, server_hostname) server_side, server_hostname)

View File

@ -196,7 +196,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
self, protocol_factory, path=None, *, self, protocol_factory, path=None, *,
ssl=None, sock=None, ssl=None, sock=None,
server_hostname=None, server_hostname=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
assert server_hostname is None or isinstance(server_hostname, str) assert server_hostname is None or isinstance(server_hostname, str)
if ssl: if ssl:
if server_hostname is None: if server_hostname is None:
@ -205,6 +205,9 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
else: else:
if server_hostname is not None: if server_hostname is not None:
raise ValueError('server_hostname is only meaningful with ssl') raise ValueError('server_hostname is only meaningful with ssl')
if ssl_handshake_timeout is not None:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if path is not None: if path is not None:
if sock is not None: if sock is not None:
@ -237,10 +240,14 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
async def create_unix_server( async def create_unix_server(
self, protocol_factory, path=None, *, self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None, sock=None, backlog=100, ssl=None,
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): ssl_handshake_timeout=None):
if isinstance(ssl, bool): if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None') raise TypeError('ssl argument must be an SSLContext or None')
if ssl_handshake_timeout is not None and not ssl:
raise ValueError(
'ssl_handshake_timeout is only meaningful with ssl')
if path is not None: if path is not None:
if sock is not None: if sock is not None:
raise ValueError( raise ValueError(

View File

@ -1053,6 +1053,14 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
'A Stream Socket was expected'): 'A Stream Socket was expected'):
self.loop.run_until_complete(coro) self.loop.run_until_complete(coro)
def test_create_server_ssl_timeout_for_plain_socket(self):
coro = self.loop.create_server(
MyProto, 'example.com', 80, ssl_handshake_timeout=1)
with self.assertRaisesRegex(
ValueError,
'ssl_handshake_timeout is only meaningful with ssl'):
self.loop.run_until_complete(coro)
@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'), @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
'no socket.SOCK_NONBLOCK (linux only)') 'no socket.SOCK_NONBLOCK (linux only)')
def test_create_server_stream_bittype(self): def test_create_server_stream_bittype(self):
@ -1362,6 +1370,14 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.addCleanup(sock.close) self.addCleanup(sock.close)
self.assertRaises(ValueError, self.loop.run_until_complete, coro) self.assertRaises(ValueError, self.loop.run_until_complete, coro)
def test_create_connection_ssl_timeout_for_plain_socket(self):
coro = self.loop.create_connection(
MyProto, 'example.com', 80, ssl_handshake_timeout=1)
with self.assertRaisesRegex(
ValueError,
'ssl_handshake_timeout is only meaningful with ssl'):
self.loop.run_until_complete(coro)
def test_create_server_empty_host(self): def test_create_server_empty_host(self):
# if host is empty string use None instead # if host is empty string use None instead
host = object() host = object()

View File

@ -852,6 +852,16 @@ class EventLoopTestsMixin:
self.test_connect_accepted_socket(server_context, client_context) self.test_connect_accepted_socket(server_context, client_context)
def test_connect_accepted_socket_ssl_timeout_for_plain_socket(self):
sock = socket.socket()
self.addCleanup(sock.close)
coro = self.loop.connect_accepted_socket(
MyProto, sock, ssl_handshake_timeout=1)
with self.assertRaisesRegex(
ValueError,
'ssl_handshake_timeout is only meaningful with ssl'):
self.loop.run_until_complete(coro)
@mock.patch('asyncio.base_events.socket') @mock.patch('asyncio.base_events.socket')
def create_server_multiple_hosts(self, family, hosts, mock_sock): def create_server_multiple_hosts(self, family, hosts, mock_sock):
@asyncio.coroutine @asyncio.coroutine

View File

@ -75,6 +75,22 @@ class SslProtoHandshakeTests(test_utils.TestCase):
self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop)) self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
self.assertTrue(transport.abort.called) self.assertTrue(transport.abort.called)
def test_handshake_timeout_zero(self):
sslcontext = test_utils.dummy_ssl_context()
app_proto = mock.Mock()
waiter = mock.Mock()
with self.assertRaisesRegex(ValueError, 'a positive number'):
sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
ssl_handshake_timeout=0)
def test_handshake_timeout_negative(self):
sslcontext = test_utils.dummy_ssl_context()
app_proto = mock.Mock()
waiter = mock.Mock()
with self.assertRaisesRegex(ValueError, 'a positive number'):
sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
ssl_handshake_timeout=-10)
def test_eof_received_waiter(self): def test_eof_received_waiter(self):
waiter = asyncio.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
ssl_proto = self.ssl_protocol(waiter) ssl_proto = self.ssl_protocol(waiter)

View File

@ -327,6 +327,14 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
finally: finally:
os.unlink(fn) os.unlink(fn)
def test_create_unix_server_ssl_timeout_with_plain_sock(self):
coro = self.loop.create_unix_server(lambda: None, path='spam',
ssl_handshake_timeout=1)
with self.assertRaisesRegex(
ValueError,
'ssl_handshake_timeout is only meaningful with ssl'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_path_inetsock(self): def test_create_unix_connection_path_inetsock(self):
sock = socket.socket() sock = socket.socket()
with sock: with sock:
@ -383,6 +391,15 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
self.loop.run_until_complete(coro) self.loop.run_until_complete(coro)
def test_create_unix_connection_ssl_timeout_with_plain_sock(self):
coro = self.loop.create_unix_connection(lambda: None, path='spam',
ssl_handshake_timeout=1)
with self.assertRaisesRegex(
ValueError,
'ssl_handshake_timeout is only meaningful with ssl'):
self.loop.run_until_complete(coro)
class UnixReadPipeTransportTests(test_utils.TestCase): class UnixReadPipeTransportTests(test_utils.TestCase):