mirror of https://github.com/python/cpython
bpo-29970: Add timeout for SSL handshake in asyncio
10 seconds by default.
This commit is contained in:
parent
4b965930e8
commit
f7686c1f55
|
@ -261,7 +261,7 @@ Tasks
|
|||
Creating connections
|
||||
--------------------
|
||||
|
||||
.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None)
|
||||
.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=10.0)
|
||||
|
||||
Create a streaming transport connection to a given Internet *host* and
|
||||
*port*: socket family :py:data:`~socket.AF_INET` or
|
||||
|
@ -325,6 +325,13 @@ Creating connections
|
|||
to bind the socket to locally. The *local_host* and *local_port*
|
||||
are looked up using getaddrinfo(), similarly to *host* and *port*.
|
||||
|
||||
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds
|
||||
to wait for the SSL handshake to complete before aborting the connection.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
|
||||
The *ssl_handshake_timeout* parameter.
|
||||
|
||||
.. versionchanged:: 3.5
|
||||
|
||||
On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported.
|
||||
|
@ -386,7 +393,7 @@ Creating connections
|
|||
:ref:`UDP echo server protocol <asyncio-udp-echo-server-protocol>` examples.
|
||||
|
||||
|
||||
.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None)
|
||||
.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None, ssl_handshake_timeout=10.0)
|
||||
|
||||
Create UNIX connection: socket family :py:data:`~socket.AF_UNIX`, socket
|
||||
type :py:data:`~socket.SOCK_STREAM`. The :py:data:`~socket.AF_UNIX` socket
|
||||
|
@ -404,6 +411,10 @@ Creating connections
|
|||
|
||||
Availability: UNIX.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
|
||||
The *ssl_handshake_timeout* parameter.
|
||||
|
||||
.. versionchanged:: 3.7
|
||||
|
||||
The *path* parameter can now be a :class:`~pathlib.Path` object.
|
||||
|
@ -412,7 +423,7 @@ Creating connections
|
|||
Creating listening connections
|
||||
------------------------------
|
||||
|
||||
.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None)
|
||||
.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=10.0)
|
||||
|
||||
Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to
|
||||
*host* and *port*.
|
||||
|
@ -456,6 +467,13 @@ Creating listening connections
|
|||
set this flag when being created. This option is not supported on
|
||||
Windows.
|
||||
|
||||
* *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait
|
||||
for the SSL handshake to complete before aborting the connection.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
|
||||
The *ssl_handshake_timeout* parameter.
|
||||
|
||||
.. versionchanged:: 3.5
|
||||
|
||||
On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported.
|
||||
|
@ -470,7 +488,7 @@ Creating listening connections
|
|||
The *host* parameter can now be a sequence of strings.
|
||||
|
||||
|
||||
.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None)
|
||||
.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=10.0)
|
||||
|
||||
Similar to :meth:`AbstractEventLoop.create_server`, but specific to the
|
||||
socket family :py:data:`~socket.AF_UNIX`.
|
||||
|
@ -481,11 +499,15 @@ Creating listening connections
|
|||
|
||||
Availability: UNIX.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
|
||||
The *ssl_handshake_timeout* parameter.
|
||||
|
||||
.. versionchanged:: 3.7
|
||||
|
||||
The *path* parameter can now be a :class:`~pathlib.Path` object.
|
||||
|
||||
.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None)
|
||||
.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None, ssl_handshake_timeout=10.0)
|
||||
|
||||
Handle an accepted connection.
|
||||
|
||||
|
@ -500,8 +522,15 @@ Creating listening connections
|
|||
* *ssl* can be set to an :class:`~ssl.SSLContext` to enable SSL over the
|
||||
accepted connections.
|
||||
|
||||
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
|
||||
wait for the SSL handshake to complete before aborting the connection.
|
||||
|
||||
When completed it returns a ``(transport, protocol)`` pair.
|
||||
|
||||
.. versionadded:: 3.7
|
||||
|
||||
The *ssl_handshake_timeout* parameter.
|
||||
|
||||
.. versionadded:: 3.5.3
|
||||
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ import sys
|
|||
import warnings
|
||||
import weakref
|
||||
|
||||
from . import constants
|
||||
from . import coroutines
|
||||
from . import events
|
||||
from . import futures
|
||||
|
@ -275,9 +276,11 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
"""Create socket transport."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None):
|
||||
def _make_ssl_transport(
|
||||
self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""Create SSL transport."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -635,10 +638,12 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
return await self.run_in_executor(
|
||||
None, socket.getnameinfo, sockaddr, flags)
|
||||
|
||||
async def create_connection(self, protocol_factory, host=None, port=None,
|
||||
*, ssl=None, family=0,
|
||||
proto=0, flags=0, sock=None,
|
||||
local_addr=None, server_hostname=None):
|
||||
async def create_connection(
|
||||
self, protocol_factory, host=None, port=None,
|
||||
*, ssl=None, family=0,
|
||||
proto=0, flags=0, sock=None,
|
||||
local_addr=None, server_hostname=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""Connect to a TCP server.
|
||||
|
||||
Create a streaming transport connection to a given Internet host and
|
||||
|
@ -751,7 +756,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
f'A Stream Socket was expected, got {sock!r}')
|
||||
|
||||
transport, protocol = await self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, server_hostname)
|
||||
sock, protocol_factory, ssl, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
if self._debug:
|
||||
# Get the socket from the transport because SSL transport closes
|
||||
# the old socket and creates a new SSL socket
|
||||
|
@ -760,8 +766,10 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
sock, host, port, transport, protocol)
|
||||
return transport, protocol
|
||||
|
||||
async def _create_connection_transport(self, sock, protocol_factory, ssl,
|
||||
server_hostname, server_side=False):
|
||||
async def _create_connection_transport(
|
||||
self, sock, protocol_factory, ssl,
|
||||
server_hostname, server_side=False,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
|
||||
sock.setblocking(False)
|
||||
|
||||
|
@ -771,7 +779,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
sslcontext = None if isinstance(ssl, bool) else ssl
|
||||
transport = self._make_ssl_transport(
|
||||
sock, protocol, sslcontext, waiter,
|
||||
server_side=server_side, server_hostname=server_hostname)
|
||||
server_side=server_side, server_hostname=server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
else:
|
||||
transport = self._make_socket_transport(sock, protocol, waiter)
|
||||
|
||||
|
@ -929,15 +938,17 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
raise OSError(f'getaddrinfo({host!r}) returned empty list')
|
||||
return infos
|
||||
|
||||
async def create_server(self, protocol_factory, host=None, port=None,
|
||||
*,
|
||||
family=socket.AF_UNSPEC,
|
||||
flags=socket.AI_PASSIVE,
|
||||
sock=None,
|
||||
backlog=100,
|
||||
ssl=None,
|
||||
reuse_address=None,
|
||||
reuse_port=None):
|
||||
async def create_server(
|
||||
self, protocol_factory, host=None, port=None,
|
||||
*,
|
||||
family=socket.AF_UNSPEC,
|
||||
flags=socket.AI_PASSIVE,
|
||||
sock=None,
|
||||
backlog=100,
|
||||
ssl=None,
|
||||
reuse_address=None,
|
||||
reuse_port=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""Create a TCP server.
|
||||
|
||||
The host parameter can be a string, in that case the TCP server is
|
||||
|
@ -1026,13 +1037,16 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
for sock in sockets:
|
||||
sock.listen(backlog)
|
||||
sock.setblocking(False)
|
||||
self._start_serving(protocol_factory, sock, ssl, server, backlog)
|
||||
self._start_serving(protocol_factory, sock, ssl, server, backlog,
|
||||
ssl_handshake_timeout)
|
||||
if self._debug:
|
||||
logger.info("%r is serving", server)
|
||||
return server
|
||||
|
||||
async def connect_accepted_socket(self, protocol_factory, sock,
|
||||
*, ssl=None):
|
||||
async def connect_accepted_socket(
|
||||
self, protocol_factory, sock,
|
||||
*, ssl=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""Handle an accepted connection.
|
||||
|
||||
This is used by servers that accept connections outside of
|
||||
|
@ -1045,7 +1059,8 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
raise ValueError(f'A Stream Socket was expected, got {sock!r}')
|
||||
|
||||
transport, protocol = await self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, '', server_side=True)
|
||||
sock, protocol_factory, ssl, '', server_side=True,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
if self._debug:
|
||||
# Get the socket from the transport because SSL transport closes
|
||||
# the old socket and creates a new SSL socket
|
||||
|
|
|
@ -8,3 +8,6 @@ ACCEPT_RETRY_DELAY = 1
|
|||
# The larger the number, the slower the operation in debug mode
|
||||
# (see extract_stack() in format_helpers.py).
|
||||
DEBUG_STACK_DEPTH = 10
|
||||
|
||||
# Number of seconds to wait for SSL handshake to complete
|
||||
SSL_HANDSHAKE_TIMEOUT = 10.0
|
||||
|
|
|
@ -250,16 +250,20 @@ class AbstractEventLoop:
|
|||
async def getnameinfo(self, sockaddr, flags=0):
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_connection(self, protocol_factory, host=None, port=None,
|
||||
*, ssl=None, family=0, proto=0,
|
||||
flags=0, sock=None, local_addr=None,
|
||||
server_hostname=None):
|
||||
async def create_connection(
|
||||
self, protocol_factory, host=None, port=None,
|
||||
*, ssl=None, family=0, proto=0,
|
||||
flags=0, sock=None, local_addr=None,
|
||||
server_hostname=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_server(self, protocol_factory, host=None, port=None,
|
||||
*, family=socket.AF_UNSPEC,
|
||||
flags=socket.AI_PASSIVE, sock=None, backlog=100,
|
||||
ssl=None, reuse_address=None, reuse_port=None):
|
||||
async def create_server(
|
||||
self, protocol_factory, host=None, port=None,
|
||||
*, family=socket.AF_UNSPEC,
|
||||
flags=socket.AI_PASSIVE, sock=None, backlog=100,
|
||||
ssl=None, reuse_address=None, reuse_port=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""A coroutine which creates a TCP server bound to host and port.
|
||||
|
||||
The return value is a Server object which can be used to stop
|
||||
|
@ -294,16 +298,25 @@ class AbstractEventLoop:
|
|||
the same port as other existing endpoints are bound to, so long as
|
||||
they all set this flag when being created. This option is not
|
||||
supported on Windows.
|
||||
|
||||
ssl_handshake_timeout is the time in seconds that an SSL server
|
||||
will wait for completion of the SSL handshake before aborting the
|
||||
connection. Default is 10s, longer timeouts may increase vulnerability
|
||||
to DoS attacks (see https://support.f5.com/csp/article/K13834)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_unix_connection(self, protocol_factory, path=None, *,
|
||||
ssl=None, sock=None,
|
||||
server_hostname=None):
|
||||
async def create_unix_connection(
|
||||
self, protocol_factory, path=None, *,
|
||||
ssl=None, sock=None,
|
||||
server_hostname=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_unix_server(self, protocol_factory, path=None, *,
|
||||
sock=None, backlog=100, ssl=None):
|
||||
async def create_unix_server(
|
||||
self, protocol_factory, path=None, *,
|
||||
sock=None, backlog=100, ssl=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
"""A coroutine which creates a UNIX Domain Socket server.
|
||||
|
||||
The return value is a Server object, which can be used to stop
|
||||
|
@ -320,6 +333,9 @@ class AbstractEventLoop:
|
|||
|
||||
ssl can be set to an SSLContext to enable SSL over the
|
||||
accepted connections.
|
||||
|
||||
ssl_handshake_timeout is the time in seconds that an SSL server
|
||||
will wait for the SSL handshake to complete (defaults to 10s).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -389,11 +389,15 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
|
|||
return _ProactorSocketTransport(self, sock, protocol, waiter,
|
||||
extra, server)
|
||||
|
||||
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None):
|
||||
ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
|
||||
server_side, server_hostname)
|
||||
def _make_ssl_transport(
|
||||
self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
ssl_protocol = sslproto.SSLProtocol(
|
||||
self, protocol, sslcontext, waiter,
|
||||
server_side, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
_ProactorSocketTransport(self, rawsock, ssl_protocol,
|
||||
extra=extra, server=server)
|
||||
return ssl_protocol._app_transport
|
||||
|
@ -486,7 +490,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
|
|||
self._csock.send(b'\0')
|
||||
|
||||
def _start_serving(self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100):
|
||||
sslcontext=None, server=None, backlog=100,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
|
||||
def loop(f=None):
|
||||
try:
|
||||
|
@ -499,7 +504,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
|
|||
if sslcontext is not None:
|
||||
self._make_ssl_transport(
|
||||
conn, protocol, sslcontext, server_side=True,
|
||||
extra={'peername': addr}, server=server)
|
||||
extra={'peername': addr}, server=server,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
else:
|
||||
self._make_socket_transport(
|
||||
conn, protocol,
|
||||
|
|
|
@ -70,11 +70,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
return _SelectorSocketTransport(self, sock, protocol, waiter,
|
||||
extra, server)
|
||||
|
||||
def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None):
|
||||
ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
|
||||
server_side, server_hostname)
|
||||
def _make_ssl_transport(
|
||||
self, rawsock, protocol, sslcontext, waiter=None,
|
||||
*, server_side=False, server_hostname=None,
|
||||
extra=None, server=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
ssl_protocol = sslproto.SSLProtocol(
|
||||
self, protocol, sslcontext, waiter,
|
||||
server_side, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
_SelectorSocketTransport(self, rawsock, ssl_protocol,
|
||||
extra=extra, server=server)
|
||||
return ssl_protocol._app_transport
|
||||
|
@ -143,12 +147,16 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
exc_info=True)
|
||||
|
||||
def _start_serving(self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100):
|
||||
sslcontext=None, server=None, backlog=100,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
self._add_reader(sock.fileno(), self._accept_connection,
|
||||
protocol_factory, sock, sslcontext, server, backlog)
|
||||
protocol_factory, sock, sslcontext, server, backlog,
|
||||
ssl_handshake_timeout)
|
||||
|
||||
def _accept_connection(self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100):
|
||||
def _accept_connection(
|
||||
self, protocol_factory, sock,
|
||||
sslcontext=None, server=None, backlog=100,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
# This method is only called once for each event loop tick where the
|
||||
# listening socket has triggered an EVENT_READ. There may be multiple
|
||||
# connections waiting for an .accept() so it is called in a loop.
|
||||
|
@ -179,17 +187,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
self.call_later(constants.ACCEPT_RETRY_DELAY,
|
||||
self._start_serving,
|
||||
protocol_factory, sock, sslcontext, server,
|
||||
backlog)
|
||||
backlog, ssl_handshake_timeout)
|
||||
else:
|
||||
raise # The event loop will catch, log and ignore it.
|
||||
else:
|
||||
extra = {'peername': addr}
|
||||
accept = self._accept_connection2(
|
||||
protocol_factory, conn, extra, sslcontext, server)
|
||||
protocol_factory, conn, extra, sslcontext, server,
|
||||
ssl_handshake_timeout)
|
||||
self.create_task(accept)
|
||||
|
||||
async def _accept_connection2(self, protocol_factory, conn, extra,
|
||||
sslcontext=None, server=None):
|
||||
async def _accept_connection2(
|
||||
self, protocol_factory, conn, extra,
|
||||
sslcontext=None, server=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
protocol = None
|
||||
transport = None
|
||||
try:
|
||||
|
@ -198,7 +209,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
if sslcontext:
|
||||
transport = self._make_ssl_transport(
|
||||
conn, protocol, sslcontext, waiter=waiter,
|
||||
server_side=True, extra=extra, server=server)
|
||||
server_side=True, extra=extra, server=server,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
else:
|
||||
transport = self._make_socket_transport(
|
||||
conn, protocol, waiter=waiter, extra=extra,
|
||||
|
|
|
@ -6,6 +6,7 @@ except ImportError: # pragma: no cover
|
|||
ssl = None
|
||||
|
||||
from . import base_events
|
||||
from . import constants
|
||||
from . import protocols
|
||||
from . import transports
|
||||
from .log import logger
|
||||
|
@ -400,7 +401,8 @@ class SSLProtocol(protocols.Protocol):
|
|||
|
||||
def __init__(self, loop, app_protocol, sslcontext, waiter,
|
||||
server_side=False, server_hostname=None,
|
||||
call_connection_made=True):
|
||||
call_connection_made=True,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
if ssl is None:
|
||||
raise RuntimeError('stdlib ssl module not available')
|
||||
|
||||
|
@ -434,6 +436,7 @@ class SSLProtocol(protocols.Protocol):
|
|||
# transport, ex: SelectorSocketTransport
|
||||
self._transport = None
|
||||
self._call_connection_made = call_connection_made
|
||||
self._ssl_handshake_timeout = ssl_handshake_timeout
|
||||
|
||||
def _wakeup_waiter(self, exc=None):
|
||||
if self._waiter is None:
|
||||
|
@ -561,9 +564,18 @@ class SSLProtocol(protocols.Protocol):
|
|||
# the SSL handshake
|
||||
self._write_backlog.append((b'', 1))
|
||||
self._loop.call_soon(self._process_write_backlog)
|
||||
self._handshake_timeout_handle = \
|
||||
self._loop.call_later(self._ssl_handshake_timeout,
|
||||
self._check_handshake_timeout)
|
||||
|
||||
def _check_handshake_timeout(self):
|
||||
if self._in_handshake is True:
|
||||
logger.warning("%r stalled during handshake", self)
|
||||
self._abort()
|
||||
|
||||
def _on_handshake_complete(self, handshake_exc):
|
||||
self._in_handshake = False
|
||||
self._handshake_timeout_handle.cancel()
|
||||
|
||||
sslobj = self._sslpipe.ssl_object
|
||||
try:
|
||||
|
|
|
@ -192,9 +192,11 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
|||
def _child_watcher_callback(self, pid, returncode, transp):
|
||||
self.call_soon_threadsafe(transp._process_exited, returncode)
|
||||
|
||||
async def create_unix_connection(self, protocol_factory, path=None, *,
|
||||
ssl=None, sock=None,
|
||||
server_hostname=None):
|
||||
async def create_unix_connection(
|
||||
self, protocol_factory, path=None, *,
|
||||
ssl=None, sock=None,
|
||||
server_hostname=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
assert server_hostname is None or isinstance(server_hostname, str)
|
||||
if ssl:
|
||||
if server_hostname is None:
|
||||
|
@ -228,11 +230,14 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
|||
sock.setblocking(False)
|
||||
|
||||
transport, protocol = await self._create_connection_transport(
|
||||
sock, protocol_factory, ssl, server_hostname)
|
||||
sock, protocol_factory, ssl, server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
return transport, protocol
|
||||
|
||||
async def create_unix_server(self, protocol_factory, path=None, *,
|
||||
sock=None, backlog=100, ssl=None):
|
||||
async def create_unix_server(
|
||||
self, protocol_factory, path=None, *,
|
||||
sock=None, backlog=100, ssl=None,
|
||||
ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
|
||||
if isinstance(ssl, bool):
|
||||
raise TypeError('ssl argument must be an SSLContext or None')
|
||||
|
||||
|
@ -283,7 +288,8 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
|
|||
server = base_events.Server(self, [sock])
|
||||
sock.listen(backlog)
|
||||
sock.setblocking(False)
|
||||
self._start_serving(protocol_factory, sock, ssl, server)
|
||||
self._start_serving(protocol_factory, sock, ssl, server,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout)
|
||||
return server
|
||||
|
||||
|
||||
|
|
|
@ -1301,34 +1301,45 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
|
||||
self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
|
||||
ANY = mock.ANY
|
||||
handshake_timeout = object()
|
||||
# First try the default server_hostname.
|
||||
self.loop._make_ssl_transport.reset_mock()
|
||||
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True)
|
||||
coro = self.loop.create_connection(
|
||||
MyProto, 'python.org', 80, ssl=True,
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
transport, _ = self.loop.run_until_complete(coro)
|
||||
transport.close()
|
||||
self.loop._make_ssl_transport.assert_called_with(
|
||||
ANY, ANY, ANY, ANY,
|
||||
server_side=False,
|
||||
server_hostname='python.org')
|
||||
server_hostname='python.org',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
# Next try an explicit server_hostname.
|
||||
self.loop._make_ssl_transport.reset_mock()
|
||||
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
|
||||
server_hostname='perl.com')
|
||||
coro = self.loop.create_connection(
|
||||
MyProto, 'python.org', 80, ssl=True,
|
||||
server_hostname='perl.com',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
transport, _ = self.loop.run_until_complete(coro)
|
||||
transport.close()
|
||||
self.loop._make_ssl_transport.assert_called_with(
|
||||
ANY, ANY, ANY, ANY,
|
||||
server_side=False,
|
||||
server_hostname='perl.com')
|
||||
server_hostname='perl.com',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
# Finally try an explicit empty server_hostname.
|
||||
self.loop._make_ssl_transport.reset_mock()
|
||||
coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
|
||||
server_hostname='')
|
||||
coro = self.loop.create_connection(
|
||||
MyProto, 'python.org', 80, ssl=True,
|
||||
server_hostname='',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
transport, _ = self.loop.run_until_complete(coro)
|
||||
transport.close()
|
||||
self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
|
||||
server_side=False,
|
||||
server_hostname='')
|
||||
self.loop._make_ssl_transport.assert_called_with(
|
||||
ANY, ANY, ANY, ANY,
|
||||
server_side=False,
|
||||
server_hostname='',
|
||||
ssl_handshake_timeout=handshake_timeout)
|
||||
|
||||
def test_create_connection_no_ssl_server_hostname_errors(self):
|
||||
# When not using ssl, server_hostname must be None.
|
||||
|
@ -1687,7 +1698,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
|||
constants.ACCEPT_RETRY_DELAY,
|
||||
# self.loop._start_serving
|
||||
mock.ANY,
|
||||
MyProto, sock, None, None, mock.ANY)
|
||||
MyProto, sock, None, None, mock.ANY, mock.ANY)
|
||||
|
||||
def test_call_coroutine(self):
|
||||
@asyncio.coroutine
|
||||
|
|
|
@ -11,6 +11,7 @@ except ImportError:
|
|||
import asyncio
|
||||
from asyncio import log
|
||||
from asyncio import sslproto
|
||||
from asyncio import tasks
|
||||
from test.test_asyncio import utils as test_utils
|
||||
|
||||
|
||||
|
@ -25,7 +26,8 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
def ssl_protocol(self, waiter=None):
|
||||
sslcontext = test_utils.dummy_ssl_context()
|
||||
app_proto = asyncio.Protocol()
|
||||
proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter)
|
||||
proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
|
||||
ssl_handshake_timeout=0.1)
|
||||
self.assertIs(proto._app_transport.get_protocol(), app_proto)
|
||||
self.addCleanup(proto._app_transport.close)
|
||||
return proto
|
||||
|
@ -63,6 +65,16 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
with test_utils.disable_logger():
|
||||
self.loop.run_until_complete(handshake_fut)
|
||||
|
||||
def test_handshake_timeout(self):
|
||||
# bpo-29970: Check that a connection is aborted if handshake is not
|
||||
# completed in timeout period, instead of remaining open indefinitely
|
||||
ssl_proto = self.ssl_protocol()
|
||||
transport = self.connection_made(ssl_proto)
|
||||
|
||||
with test_utils.disable_logger():
|
||||
self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
|
||||
self.assertTrue(transport.abort.called)
|
||||
|
||||
def test_eof_received_waiter(self):
|
||||
waiter = asyncio.Future(loop=self.loop)
|
||||
ssl_proto = self.ssl_protocol(waiter)
|
||||
|
|
|
@ -63,6 +63,7 @@ Jeffrey Armstrong
|
|||
Jason Asbahr
|
||||
David Ascher
|
||||
Ammar Askar
|
||||
Neil Aspinall
|
||||
Chris AtLee
|
||||
Aymeric Augustin
|
||||
Cathy Avery
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Abort asyncio SSLProtocol connection if handshake not complete within 10s
|
Loading…
Reference in New Issue