bpo-33734: asyncio/ssl: a bunch of bugfixes (#7321)
* Fix AttributeError (not all SSL exceptions have 'errno' attribute) * Increase default handshake timeout from 10 to 60 seconds * Make sure start_tls can be cancelled correctly * Make sure any error in SSLProtocol gets propagated (instead of just being logged)
This commit is contained in:
parent
a8eb58546b
commit
9602643120
|
@ -351,7 +351,7 @@ Creating connections
|
|||
|
||||
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds
|
||||
to wait for the SSL handshake to complete before aborting the connection.
|
||||
``10.0`` seconds if ``None`` (default).
|
||||
``60.0`` seconds if ``None`` (default).
|
||||
|
||||
.. versionadded:: 3.7
|
||||
|
||||
|
@ -497,7 +497,7 @@ Creating listening connections
|
|||
|
||||
* *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait
|
||||
for the SSL handshake to complete before aborting the connection.
|
||||
``10.0`` seconds if ``None`` (default).
|
||||
``60.0`` seconds if ``None`` (default).
|
||||
|
||||
* *start_serving* set to ``True`` (the default) causes the created server
|
||||
to start accepting connections immediately. When set to ``False``,
|
||||
|
@ -559,7 +559,7 @@ Creating listening connections
|
|||
|
||||
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
|
||||
wait for the SSL handshake to complete before aborting the connection.
|
||||
``10.0`` seconds if ``None`` (default).
|
||||
``60.0`` seconds if ``None`` (default).
|
||||
|
||||
When completed it returns a ``(transport, protocol)`` pair.
|
||||
|
||||
|
@ -628,7 +628,7 @@ TLS Upgrade
|
|||
|
||||
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
|
||||
wait for the SSL handshake to complete before aborting the connection.
|
||||
``10.0`` seconds if ``None`` (default).
|
||||
``60.0`` seconds if ``None`` (default).
|
||||
|
||||
.. versionadded:: 3.7
|
||||
|
||||
|
|
|
@ -1114,7 +1114,12 @@ class BaseEventLoop(events.AbstractEventLoop):
|
|||
self.call_soon(ssl_protocol.connection_made, transport)
|
||||
self.call_soon(transport.resume_reading)
|
||||
|
||||
try:
|
||||
await waiter
|
||||
except Exception:
|
||||
transport.close()
|
||||
raise
|
||||
|
||||
return ssl_protocol._app_transport
|
||||
|
||||
async def create_datagram_endpoint(self, protocol_factory,
|
||||
|
|
|
@ -12,7 +12,8 @@ ACCEPT_RETRY_DELAY = 1
|
|||
DEBUG_STACK_DEPTH = 10
|
||||
|
||||
# Number of seconds to wait for SSL handshake to complete
|
||||
SSL_HANDSHAKE_TIMEOUT = 10.0
|
||||
# The default timeout matches that of Nginx.
|
||||
SSL_HANDSHAKE_TIMEOUT = 60.0
|
||||
|
||||
# Used in sendfile fallback code. We use fallback for platforms
|
||||
# that don't support sendfile, or for TLS connections.
|
||||
|
|
|
@ -352,8 +352,7 @@ class AbstractEventLoop:
|
|||
|
||||
ssl_handshake_timeout is the time in seconds that an SSL server
|
||||
will wait for completion of the SSL handshake before aborting the
|
||||
connection. Default is 10s, longer timeouts may increase vulnerability
|
||||
to DoS attacks (see https://support.f5.com/csp/article/K13834)
|
||||
connection. Default is 60s.
|
||||
|
||||
start_serving set to True (default) causes the created server
|
||||
to start accepting connections immediately. When set to False,
|
||||
|
@ -411,7 +410,7 @@ class AbstractEventLoop:
|
|||
accepted connections.
|
||||
|
||||
ssl_handshake_timeout is the time in seconds that an SSL server
|
||||
will wait for the SSL handshake to complete (defaults to 10s).
|
||||
will wait for the SSL handshake to complete (defaults to 60s).
|
||||
|
||||
start_serving set to True (default) causes the created server
|
||||
to start accepting connections immediately. When set to False,
|
||||
|
|
|
@ -214,13 +214,14 @@ class _SSLPipe(object):
|
|||
# Drain possible plaintext data after close_notify.
|
||||
appdata.append(self._incoming.read())
|
||||
except (ssl.SSLError, ssl.CertificateError) as exc:
|
||||
if getattr(exc, 'errno', None) not in (
|
||||
exc_errno = getattr(exc, 'errno', None)
|
||||
if exc_errno not in (
|
||||
ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
|
||||
ssl.SSL_ERROR_SYSCALL):
|
||||
if self._state == _DO_HANDSHAKE and self._handshake_cb:
|
||||
self._handshake_cb(exc)
|
||||
raise
|
||||
self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)
|
||||
self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
|
||||
|
||||
# Check for record level data that needs to be sent back.
|
||||
# Happens for the initial handshake and renegotiations.
|
||||
|
@ -263,13 +264,14 @@ class _SSLPipe(object):
|
|||
# It is not allowed to call write() after unwrap() until the
|
||||
# close_notify is acknowledged. We return the condition to the
|
||||
# caller as a short write.
|
||||
exc_errno = getattr(exc, 'errno', None)
|
||||
if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
|
||||
exc.errno = ssl.SSL_ERROR_WANT_READ
|
||||
if exc.errno not in (ssl.SSL_ERROR_WANT_READ,
|
||||
exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
|
||||
if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
|
||||
ssl.SSL_ERROR_WANT_WRITE,
|
||||
ssl.SSL_ERROR_SYSCALL):
|
||||
raise
|
||||
self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)
|
||||
self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
|
||||
|
||||
# See if there's any record level data back for us.
|
||||
if self._outgoing.pending:
|
||||
|
@ -488,6 +490,12 @@ class SSLProtocol(protocols.Protocol):
|
|||
if self._session_established:
|
||||
self._session_established = False
|
||||
self._loop.call_soon(self._app_protocol.connection_lost, exc)
|
||||
else:
|
||||
# Most likely an exception occurred while in SSL handshake.
|
||||
# Just mark the app transport as closed so that its __del__
|
||||
# doesn't complain.
|
||||
if self._app_transport is not None:
|
||||
self._app_transport._closed = True
|
||||
self._transport = None
|
||||
self._app_transport = None
|
||||
self._wakeup_waiter(exc)
|
||||
|
@ -515,11 +523,8 @@ class SSLProtocol(protocols.Protocol):
|
|||
|
||||
try:
|
||||
ssldata, appdata = self._sslpipe.feed_ssldata(data)
|
||||
except ssl.SSLError as e:
|
||||
if self._loop.get_debug():
|
||||
logger.warning('%r: SSL error %s (reason %s)',
|
||||
self, e.errno, e.reason)
|
||||
self._abort()
|
||||
except Exception as e:
|
||||
self._fatal_error(e, 'SSL error in data received')
|
||||
return
|
||||
|
||||
for chunk in ssldata:
|
||||
|
@ -602,8 +607,12 @@ class SSLProtocol(protocols.Protocol):
|
|||
|
||||
def _check_handshake_timeout(self):
|
||||
if self._in_handshake is True:
|
||||
logger.warning("%r stalled during handshake", self)
|
||||
self._abort()
|
||||
msg = (
|
||||
f"SSL handshake is taking longer than "
|
||||
f"{self._ssl_handshake_timeout} seconds: "
|
||||
f"aborting the connection"
|
||||
)
|
||||
self._fatal_error(ConnectionAbortedError(msg))
|
||||
|
||||
def _on_handshake_complete(self, handshake_exc):
|
||||
self._in_handshake = False
|
||||
|
@ -615,21 +624,13 @@ class SSLProtocol(protocols.Protocol):
|
|||
raise handshake_exc
|
||||
|
||||
peercert = sslobj.getpeercert()
|
||||
except BaseException as exc:
|
||||
if self._loop.get_debug():
|
||||
except Exception as exc:
|
||||
if isinstance(exc, ssl.CertificateError):
|
||||
logger.warning("%r: SSL handshake failed "
|
||||
"on verifying the certificate",
|
||||
self, exc_info=True)
|
||||
msg = 'SSL handshake failed on verifying the certificate'
|
||||
else:
|
||||
logger.warning("%r: SSL handshake failed",
|
||||
self, exc_info=True)
|
||||
self._transport.close()
|
||||
if isinstance(exc, Exception):
|
||||
self._wakeup_waiter(exc)
|
||||
msg = 'SSL handshake failed'
|
||||
self._fatal_error(exc, msg)
|
||||
return
|
||||
else:
|
||||
raise
|
||||
|
||||
if self._loop.get_debug():
|
||||
dt = self._loop.time() - self._handshake_start_time
|
||||
|
@ -686,18 +687,14 @@ class SSLProtocol(protocols.Protocol):
|
|||
# delete it and reduce the outstanding buffer size.
|
||||
del self._write_backlog[0]
|
||||
self._write_buffer_size -= len(data)
|
||||
except BaseException as exc:
|
||||
except Exception as exc:
|
||||
if self._in_handshake:
|
||||
# BaseExceptions will be re-raised in _on_handshake_complete.
|
||||
# Exceptions will be re-raised in _on_handshake_complete.
|
||||
self._on_handshake_complete(exc)
|
||||
else:
|
||||
self._fatal_error(exc, 'Fatal error on SSL transport')
|
||||
if not isinstance(exc, Exception):
|
||||
# BaseException
|
||||
raise
|
||||
|
||||
def _fatal_error(self, exc, message='Fatal error on transport'):
|
||||
# Should be called from exception handler only.
|
||||
if isinstance(exc, base_events._FATAL_ERROR_IGNORE):
|
||||
if self._loop.get_debug():
|
||||
logger.debug("%r: %s", self, message, exc_info=True)
|
||||
|
|
|
@ -53,35 +53,6 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
ssl_proto.connection_made(transport)
|
||||
return transport
|
||||
|
||||
def test_cancel_handshake(self):
|
||||
# Python issue #23197: cancelling a handshake must not raise an
|
||||
# exception or log an error, even if the handshake failed
|
||||
waiter = asyncio.Future(loop=self.loop)
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
handshake_fut = asyncio.Future(loop=self.loop)
|
||||
|
||||
def do_handshake(callback):
|
||||
exc = Exception()
|
||||
callback(exc)
|
||||
handshake_fut.set_result(None)
|
||||
return []
|
||||
|
||||
waiter.cancel()
|
||||
self.connection_made(ssl_proto, do_handshake=do_handshake)
|
||||
|
||||
with test_utils.disable_logger():
|
||||
self.loop.run_until_complete(handshake_fut)
|
||||
|
||||
def test_handshake_timeout(self):
|
||||
# bpo-29970: Check that a connection is aborted if handshake is not
|
||||
# completed in timeout period, instead of remaining open indefinitely
|
||||
ssl_proto = self.ssl_protocol()
|
||||
transport = self.connection_made(ssl_proto)
|
||||
|
||||
with test_utils.disable_logger():
|
||||
self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
|
||||
self.assertTrue(transport.abort.called)
|
||||
|
||||
def test_handshake_timeout_zero(self):
|
||||
sslcontext = test_utils.dummy_ssl_context()
|
||||
app_proto = mock.Mock()
|
||||
|
@ -392,6 +363,67 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
|||
asyncio.wait_for(client(srv.addr),
|
||||
loop=self.loop, timeout=self.TIMEOUT))
|
||||
|
||||
def test_start_tls_slow_client_cancel(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
server_waits_on_handshake = self.loop.create_future()
|
||||
|
||||
def serve(sock):
|
||||
sock.settimeout(self.TIMEOUT)
|
||||
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
try:
|
||||
self.loop.call_soon_threadsafe(
|
||||
server_waits_on_handshake.set_result, None)
|
||||
data = sock.recv_all(1024 * 1024)
|
||||
except ConnectionAbortedError:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
class ClientProto(asyncio.Protocol):
|
||||
def __init__(self, on_data, on_eof):
|
||||
self.on_data = on_data
|
||||
self.on_eof = on_eof
|
||||
self.con_made_cnt = 0
|
||||
|
||||
def connection_made(proto, tr):
|
||||
proto.con_made_cnt += 1
|
||||
# Ensure connection_made gets called only once.
|
||||
self.assertEqual(proto.con_made_cnt, 1)
|
||||
|
||||
def data_received(self, data):
|
||||
self.on_data.set_result(data)
|
||||
|
||||
def eof_received(self):
|
||||
self.on_eof.set_result(True)
|
||||
|
||||
async def client(addr):
|
||||
await asyncio.sleep(0.5, loop=self.loop)
|
||||
|
||||
on_data = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
|
||||
tr, proto = await self.loop.create_connection(
|
||||
lambda: ClientProto(on_data, on_eof), *addr)
|
||||
|
||||
tr.write(HELLO_MSG)
|
||||
|
||||
await server_waits_on_handshake
|
||||
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(
|
||||
self.loop.start_tls(tr, proto, client_context),
|
||||
0.5,
|
||||
loop=self.loop)
|
||||
|
||||
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
|
||||
|
||||
def test_start_tls_server_1(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
|
||||
|
@ -481,6 +513,156 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
|||
|
||||
self.loop.run_until_complete(main())
|
||||
|
||||
def test_handshake_timeout(self):
|
||||
# bpo-29970: Check that a connection is aborted if handshake is not
|
||||
# completed in timeout period, instead of remaining open indefinitely
|
||||
client_sslctx = test_utils.simple_client_sslcontext()
|
||||
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
|
||||
server_side_aborted = False
|
||||
|
||||
def server(sock):
|
||||
nonlocal server_side_aborted
|
||||
try:
|
||||
sock.recv_all(1024 * 1024)
|
||||
except ConnectionAbortedError:
|
||||
server_side_aborted = True
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
async def client(addr):
|
||||
await asyncio.wait_for(
|
||||
self.loop.create_connection(
|
||||
asyncio.Protocol,
|
||||
*addr,
|
||||
ssl=client_sslctx,
|
||||
server_hostname='',
|
||||
ssl_handshake_timeout=10.0),
|
||||
0.5,
|
||||
loop=self.loop)
|
||||
|
||||
with self.tcp_server(server,
|
||||
max_clients=1,
|
||||
backlog=1) as srv:
|
||||
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
self.assertTrue(server_side_aborted)
|
||||
|
||||
# Python issue #23197: cancelling a handshake must not raise an
|
||||
# exception or log an error, even if the handshake failed
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
def test_create_connection_ssl_slow_handshake(self):
|
||||
client_sslctx = test_utils.simple_client_sslcontext()
|
||||
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
|
||||
def server(sock):
|
||||
try:
|
||||
sock.recv_all(1024 * 1024)
|
||||
except ConnectionAbortedError:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*addr,
|
||||
ssl=client_sslctx,
|
||||
server_hostname='',
|
||||
loop=self.loop,
|
||||
ssl_handshake_timeout=1.0)
|
||||
|
||||
with self.tcp_server(server,
|
||||
max_clients=1,
|
||||
backlog=1) as srv:
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ConnectionAbortedError,
|
||||
r'SSL handshake.*is taking longer'):
|
||||
|
||||
self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
def test_create_connection_ssl_failed_certificate(self):
|
||||
self.loop.set_exception_handler(lambda loop, ctx: None)
|
||||
|
||||
sslctx = test_utils.simple_server_sslcontext()
|
||||
client_sslctx = test_utils.simple_client_sslcontext(
|
||||
disable_verify=False)
|
||||
|
||||
def server(sock):
|
||||
try:
|
||||
sock.start_tls(
|
||||
sslctx,
|
||||
server_side=True)
|
||||
except ssl.SSLError:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*addr,
|
||||
ssl=client_sslctx,
|
||||
server_hostname='',
|
||||
loop=self.loop,
|
||||
ssl_handshake_timeout=1.0)
|
||||
|
||||
with self.tcp_server(server,
|
||||
max_clients=1,
|
||||
backlog=1) as srv:
|
||||
|
||||
with self.assertRaises(ssl.SSLCertVerificationError):
|
||||
self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
def test_start_tls_client_corrupted_ssl(self):
|
||||
self.loop.set_exception_handler(lambda loop, ctx: None)
|
||||
|
||||
sslctx = test_utils.simple_server_sslcontext()
|
||||
client_sslctx = test_utils.simple_client_sslcontext()
|
||||
|
||||
def server(sock):
|
||||
orig_sock = sock.dup()
|
||||
try:
|
||||
sock.start_tls(
|
||||
sslctx,
|
||||
server_side=True)
|
||||
sock.sendall(b'A\n')
|
||||
sock.recv_all(1)
|
||||
orig_sock.send(b'please corrupt the SSL connection')
|
||||
except ssl.SSLError:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*addr,
|
||||
ssl=client_sslctx,
|
||||
server_hostname='',
|
||||
loop=self.loop)
|
||||
|
||||
self.assertEqual(await reader.readline(), b'A\n')
|
||||
writer.write(b'B')
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
await reader.readline()
|
||||
return 'OK'
|
||||
|
||||
with self.tcp_server(server,
|
||||
max_clients=1,
|
||||
backlog=1) as srv:
|
||||
|
||||
res = self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
self.assertEqual(res, 'OK')
|
||||
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
|
||||
|
|
|
@ -77,9 +77,10 @@ def simple_server_sslcontext():
|
|||
return server_context
|
||||
|
||||
|
||||
def simple_client_sslcontext():
|
||||
def simple_client_sslcontext(*, disable_verify=True):
|
||||
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
client_context.check_hostname = False
|
||||
if disable_verify:
|
||||
client_context.verify_mode = ssl.CERT_NONE
|
||||
return client_context
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
asyncio/ssl: Fix AttributeError, increase default handshake timeout
|
Loading…
Reference in New Issue