bpo-33734: asyncio/ssl: a bunch of bugfixes (#7321)

* Fix AttributeError (not all SSL exceptions have 'errno' attribute)

* Increase default handshake timeout from 10 to 60 seconds
* Make sure start_tls can be cancelled correctly
* Make sure any error in SSLProtocol gets propagated (instead of just being logged)
This commit is contained in:
Yury Selivanov 2018-06-04 11:32:35 -04:00 committed by GitHub
parent a8eb58546b
commit 9602643120
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 258 additions and 72 deletions

View File

@ -351,7 +351,7 @@ Creating connections
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds
to wait for the SSL handshake to complete before aborting the connection. to wait for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default). ``60.0`` seconds if ``None`` (default).
.. versionadded:: 3.7 .. versionadded:: 3.7
@ -497,7 +497,7 @@ Creating listening connections
* *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait * *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait
for the SSL handshake to complete before aborting the connection. for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default). ``60.0`` seconds if ``None`` (default).
* *start_serving* set to ``True`` (the default) causes the created server * *start_serving* set to ``True`` (the default) causes the created server
to start accepting connections immediately. When set to ``False``, to start accepting connections immediately. When set to ``False``,
@ -559,7 +559,7 @@ Creating listening connections
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
wait for the SSL handshake to complete before aborting the connection. wait for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default). ``60.0`` seconds if ``None`` (default).
When completed it returns a ``(transport, protocol)`` pair. When completed it returns a ``(transport, protocol)`` pair.
@ -628,7 +628,7 @@ TLS Upgrade
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
wait for the SSL handshake to complete before aborting the connection. wait for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default). ``60.0`` seconds if ``None`` (default).
.. versionadded:: 3.7 .. versionadded:: 3.7

View File

@ -1114,7 +1114,12 @@ class BaseEventLoop(events.AbstractEventLoop):
self.call_soon(ssl_protocol.connection_made, transport) self.call_soon(ssl_protocol.connection_made, transport)
self.call_soon(transport.resume_reading) self.call_soon(transport.resume_reading)
await waiter try:
await waiter
except Exception:
transport.close()
raise
return ssl_protocol._app_transport return ssl_protocol._app_transport
async def create_datagram_endpoint(self, protocol_factory, async def create_datagram_endpoint(self, protocol_factory,

View File

@ -12,7 +12,8 @@ ACCEPT_RETRY_DELAY = 1
DEBUG_STACK_DEPTH = 10 DEBUG_STACK_DEPTH = 10
# Number of seconds to wait for SSL handshake to complete # Number of seconds to wait for SSL handshake to complete
SSL_HANDSHAKE_TIMEOUT = 10.0 # The default timeout matches that of Nginx.
SSL_HANDSHAKE_TIMEOUT = 60.0
# Used in sendfile fallback code. We use fallback for platforms # Used in sendfile fallback code. We use fallback for platforms
# that don't support sendfile, or for TLS connections. # that don't support sendfile, or for TLS connections.

View File

@ -352,8 +352,7 @@ class AbstractEventLoop:
ssl_handshake_timeout is the time in seconds that an SSL server ssl_handshake_timeout is the time in seconds that an SSL server
will wait for completion of the SSL handshake before aborting the will wait for completion of the SSL handshake before aborting the
connection. Default is 10s, longer timeouts may increase vulnerability connection. Default is 60s.
to DoS attacks (see https://support.f5.com/csp/article/K13834)
start_serving set to True (default) causes the created server start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False, to start accepting connections immediately. When set to False,
@ -411,7 +410,7 @@ class AbstractEventLoop:
accepted connections. accepted connections.
ssl_handshake_timeout is the time in seconds that an SSL server ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 10s). will wait for the SSL handshake to complete (defaults to 60s).
start_serving set to True (default) causes the created server start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False, to start accepting connections immediately. When set to False,

View File

@ -214,13 +214,14 @@ class _SSLPipe(object):
# Drain possible plaintext data after close_notify. # Drain possible plaintext data after close_notify.
appdata.append(self._incoming.read()) appdata.append(self._incoming.read())
except (ssl.SSLError, ssl.CertificateError) as exc: except (ssl.SSLError, ssl.CertificateError) as exc:
if getattr(exc, 'errno', None) not in ( exc_errno = getattr(exc, 'errno', None)
if exc_errno not in (
ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
ssl.SSL_ERROR_SYSCALL): ssl.SSL_ERROR_SYSCALL):
if self._state == _DO_HANDSHAKE and self._handshake_cb: if self._state == _DO_HANDSHAKE and self._handshake_cb:
self._handshake_cb(exc) self._handshake_cb(exc)
raise raise
self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
# Check for record level data that needs to be sent back. # Check for record level data that needs to be sent back.
# Happens for the initial handshake and renegotiations. # Happens for the initial handshake and renegotiations.
@ -263,13 +264,14 @@ class _SSLPipe(object):
# It is not allowed to call write() after unwrap() until the # It is not allowed to call write() after unwrap() until the
# close_notify is acknowledged. We return the condition to the # close_notify is acknowledged. We return the condition to the
# caller as a short write. # caller as a short write.
exc_errno = getattr(exc, 'errno', None)
if exc.reason == 'PROTOCOL_IS_SHUTDOWN': if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
exc.errno = ssl.SSL_ERROR_WANT_READ exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
if exc.errno not in (ssl.SSL_ERROR_WANT_READ, if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE, ssl.SSL_ERROR_WANT_WRITE,
ssl.SSL_ERROR_SYSCALL): ssl.SSL_ERROR_SYSCALL):
raise raise
self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
# See if there's any record level data back for us. # See if there's any record level data back for us.
if self._outgoing.pending: if self._outgoing.pending:
@ -488,6 +490,12 @@ class SSLProtocol(protocols.Protocol):
if self._session_established: if self._session_established:
self._session_established = False self._session_established = False
self._loop.call_soon(self._app_protocol.connection_lost, exc) self._loop.call_soon(self._app_protocol.connection_lost, exc)
else:
# Most likely an exception occurred while in SSL handshake.
# Just mark the app transport as closed so that its __del__
# doesn't complain.
if self._app_transport is not None:
self._app_transport._closed = True
self._transport = None self._transport = None
self._app_transport = None self._app_transport = None
self._wakeup_waiter(exc) self._wakeup_waiter(exc)
@ -515,11 +523,8 @@ class SSLProtocol(protocols.Protocol):
try: try:
ssldata, appdata = self._sslpipe.feed_ssldata(data) ssldata, appdata = self._sslpipe.feed_ssldata(data)
except ssl.SSLError as e: except Exception as e:
if self._loop.get_debug(): self._fatal_error(e, 'SSL error in data received')
logger.warning('%r: SSL error %s (reason %s)',
self, e.errno, e.reason)
self._abort()
return return
for chunk in ssldata: for chunk in ssldata:
@ -602,8 +607,12 @@ class SSLProtocol(protocols.Protocol):
def _check_handshake_timeout(self): def _check_handshake_timeout(self):
if self._in_handshake is True: if self._in_handshake is True:
logger.warning("%r stalled during handshake", self) msg = (
self._abort() f"SSL handshake is taking longer than "
f"{self._ssl_handshake_timeout} seconds: "
f"aborting the connection"
)
self._fatal_error(ConnectionAbortedError(msg))
def _on_handshake_complete(self, handshake_exc): def _on_handshake_complete(self, handshake_exc):
self._in_handshake = False self._in_handshake = False
@ -615,21 +624,13 @@ class SSLProtocol(protocols.Protocol):
raise handshake_exc raise handshake_exc
peercert = sslobj.getpeercert() peercert = sslobj.getpeercert()
except BaseException as exc: except Exception as exc:
if self._loop.get_debug(): if isinstance(exc, ssl.CertificateError):
if isinstance(exc, ssl.CertificateError): msg = 'SSL handshake failed on verifying the certificate'
logger.warning("%r: SSL handshake failed "
"on verifying the certificate",
self, exc_info=True)
else:
logger.warning("%r: SSL handshake failed",
self, exc_info=True)
self._transport.close()
if isinstance(exc, Exception):
self._wakeup_waiter(exc)
return
else: else:
raise msg = 'SSL handshake failed'
self._fatal_error(exc, msg)
return
if self._loop.get_debug(): if self._loop.get_debug():
dt = self._loop.time() - self._handshake_start_time dt = self._loop.time() - self._handshake_start_time
@ -686,18 +687,14 @@ class SSLProtocol(protocols.Protocol):
# delete it and reduce the outstanding buffer size. # delete it and reduce the outstanding buffer size.
del self._write_backlog[0] del self._write_backlog[0]
self._write_buffer_size -= len(data) self._write_buffer_size -= len(data)
except BaseException as exc: except Exception as exc:
if self._in_handshake: if self._in_handshake:
# BaseExceptions will be re-raised in _on_handshake_complete. # Exceptions will be re-raised in _on_handshake_complete.
self._on_handshake_complete(exc) self._on_handshake_complete(exc)
else: else:
self._fatal_error(exc, 'Fatal error on SSL transport') self._fatal_error(exc, 'Fatal error on SSL transport')
if not isinstance(exc, Exception):
# BaseException
raise
def _fatal_error(self, exc, message='Fatal error on transport'): def _fatal_error(self, exc, message='Fatal error on transport'):
# Should be called from exception handler only.
if isinstance(exc, base_events._FATAL_ERROR_IGNORE): if isinstance(exc, base_events._FATAL_ERROR_IGNORE):
if self._loop.get_debug(): if self._loop.get_debug():
logger.debug("%r: %s", self, message, exc_info=True) logger.debug("%r: %s", self, message, exc_info=True)

View File

@ -53,35 +53,6 @@ class SslProtoHandshakeTests(test_utils.TestCase):
ssl_proto.connection_made(transport) ssl_proto.connection_made(transport)
return transport return transport
def test_cancel_handshake(self):
# Python issue #23197: cancelling a handshake must not raise an
# exception or log an error, even if the handshake failed
waiter = asyncio.Future(loop=self.loop)
ssl_proto = self.ssl_protocol(waiter=waiter)
handshake_fut = asyncio.Future(loop=self.loop)
def do_handshake(callback):
exc = Exception()
callback(exc)
handshake_fut.set_result(None)
return []
waiter.cancel()
self.connection_made(ssl_proto, do_handshake=do_handshake)
with test_utils.disable_logger():
self.loop.run_until_complete(handshake_fut)
def test_handshake_timeout(self):
# bpo-29970: Check that a connection is aborted if handshake is not
# completed in timeout period, instead of remaining open indefinitely
ssl_proto = self.ssl_protocol()
transport = self.connection_made(ssl_proto)
with test_utils.disable_logger():
self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
self.assertTrue(transport.abort.called)
def test_handshake_timeout_zero(self): def test_handshake_timeout_zero(self):
sslcontext = test_utils.dummy_ssl_context() sslcontext = test_utils.dummy_ssl_context()
app_proto = mock.Mock() app_proto = mock.Mock()
@ -392,6 +363,67 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
asyncio.wait_for(client(srv.addr), asyncio.wait_for(client(srv.addr),
loop=self.loop, timeout=self.TIMEOUT)) loop=self.loop, timeout=self.TIMEOUT))
def test_start_tls_slow_client_cancel(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
client_context = test_utils.simple_client_sslcontext()
server_waits_on_handshake = self.loop.create_future()
def serve(sock):
sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
try:
self.loop.call_soon_threadsafe(
server_waits_on_handshake.set_result, None)
data = sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
pass
finally:
sock.close()
class ClientProto(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)
def data_received(self, data):
self.on_data.set_result(data)
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)
on_data = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)
tr.write(HELLO_MSG)
await server_waits_on_handshake
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(
self.loop.start_tls(tr, proto, client_context),
0.5,
loop=self.loop)
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
def test_start_tls_server_1(self): def test_start_tls_server_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE HELLO_MSG = b'1' * self.PAYLOAD_SIZE
@ -481,6 +513,156 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
self.loop.run_until_complete(main()) self.loop.run_until_complete(main())
def test_handshake_timeout(self):
# bpo-29970: Check that a connection is aborted if handshake is not
# completed in timeout period, instead of remaining open indefinitely
client_sslctx = test_utils.simple_client_sslcontext()
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
server_side_aborted = False
def server(sock):
nonlocal server_side_aborted
try:
sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
server_side_aborted = True
finally:
sock.close()
async def client(addr):
await asyncio.wait_for(
self.loop.create_connection(
asyncio.Protocol,
*addr,
ssl=client_sslctx,
server_hostname='',
ssl_handshake_timeout=10.0),
0.5,
loop=self.loop)
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
with self.assertRaises(asyncio.TimeoutError):
self.loop.run_until_complete(client(srv.addr))
self.assertTrue(server_side_aborted)
# Python issue #23197: cancelling a handshake must not raise an
# exception or log an error, even if the handshake failed
self.assertEqual(messages, [])
def test_create_connection_ssl_slow_handshake(self):
client_sslctx = test_utils.simple_client_sslcontext()
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
def server(sock):
try:
sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
pass
finally:
sock.close()
async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
loop=self.loop,
ssl_handshake_timeout=1.0)
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
with self.assertRaisesRegex(
ConnectionAbortedError,
r'SSL handshake.*is taking longer'):
self.loop.run_until_complete(client(srv.addr))
self.assertEqual(messages, [])
def test_create_connection_ssl_failed_certificate(self):
self.loop.set_exception_handler(lambda loop, ctx: None)
sslctx = test_utils.simple_server_sslcontext()
client_sslctx = test_utils.simple_client_sslcontext(
disable_verify=False)
def server(sock):
try:
sock.start_tls(
sslctx,
server_side=True)
except ssl.SSLError:
pass
finally:
sock.close()
async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
loop=self.loop,
ssl_handshake_timeout=1.0)
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
with self.assertRaises(ssl.SSLCertVerificationError):
self.loop.run_until_complete(client(srv.addr))
def test_start_tls_client_corrupted_ssl(self):
self.loop.set_exception_handler(lambda loop, ctx: None)
sslctx = test_utils.simple_server_sslcontext()
client_sslctx = test_utils.simple_client_sslcontext()
def server(sock):
orig_sock = sock.dup()
try:
sock.start_tls(
sslctx,
server_side=True)
sock.sendall(b'A\n')
sock.recv_all(1)
orig_sock.send(b'please corrupt the SSL connection')
except ssl.SSLError:
pass
finally:
sock.close()
async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
loop=self.loop)
self.assertEqual(await reader.readline(), b'A\n')
writer.write(b'B')
with self.assertRaises(ssl.SSLError):
await reader.readline()
return 'OK'
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
res = self.loop.run_until_complete(client(srv.addr))
self.assertEqual(res, 'OK')
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase): class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):

View File

@ -77,10 +77,11 @@ def simple_server_sslcontext():
return server_context return server_context
def simple_client_sslcontext(): def simple_client_sslcontext(*, disable_verify=True):
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.check_hostname = False client_context.check_hostname = False
client_context.verify_mode = ssl.CERT_NONE if disable_verify:
client_context.verify_mode = ssl.CERT_NONE
return client_context return client_context

View File

@ -0,0 +1 @@
asyncio/ssl: Fix AttributeError, increase default handshake timeout