bpo-33734: asyncio/ssl: a bunch of bugfixes (#7321)

* Fix AttributeError (not all SSL exceptions have 'errno' attribute)

* Increase default handshake timeout from 10 to 60 seconds
* Make sure start_tls can be cancelled correctly
* Make sure any error in SSLProtocol gets propagated (instead of just being logged)
This commit is contained in:
Yury Selivanov 2018-06-04 11:32:35 -04:00 committed by GitHub
parent a8eb58546b
commit 9602643120
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 258 additions and 72 deletions

View File

@ -351,7 +351,7 @@ Creating connections
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds
to wait for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default).
``60.0`` seconds if ``None`` (default).
.. versionadded:: 3.7
@ -497,7 +497,7 @@ Creating listening connections
* *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait
for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default).
``60.0`` seconds if ``None`` (default).
* *start_serving* set to ``True`` (the default) causes the created server
to start accepting connections immediately. When set to ``False``,
@ -559,7 +559,7 @@ Creating listening connections
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
wait for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default).
``60.0`` seconds if ``None`` (default).
When completed it returns a ``(transport, protocol)`` pair.
@ -628,7 +628,7 @@ TLS Upgrade
* *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
wait for the SSL handshake to complete before aborting the connection.
``10.0`` seconds if ``None`` (default).
``60.0`` seconds if ``None`` (default).
.. versionadded:: 3.7

View File

@ -1114,7 +1114,12 @@ class BaseEventLoop(events.AbstractEventLoop):
self.call_soon(ssl_protocol.connection_made, transport)
self.call_soon(transport.resume_reading)
await waiter
try:
await waiter
except Exception:
transport.close()
raise
return ssl_protocol._app_transport
async def create_datagram_endpoint(self, protocol_factory,

View File

@ -12,7 +12,8 @@ ACCEPT_RETRY_DELAY = 1
DEBUG_STACK_DEPTH = 10
# Number of seconds to wait for SSL handshake to complete
SSL_HANDSHAKE_TIMEOUT = 10.0
# The default timeout matches that of Nginx.
SSL_HANDSHAKE_TIMEOUT = 60.0
# Used in sendfile fallback code. We use fallback for platforms
# that don't support sendfile, or for TLS connections.

View File

@ -352,8 +352,7 @@ class AbstractEventLoop:
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for completion of the SSL handshake before aborting the
connection. Default is 10s, longer timeouts may increase vulnerability
to DoS attacks (see https://support.f5.com/csp/article/K13834)
connection. Default is 60s.
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,
@ -411,7 +410,7 @@ class AbstractEventLoop:
accepted connections.
ssl_handshake_timeout is the time in seconds that an SSL server
will wait for the SSL handshake to complete (defaults to 10s).
will wait for the SSL handshake to complete (defaults to 60s).
start_serving set to True (default) causes the created server
to start accepting connections immediately. When set to False,

View File

@ -214,13 +214,14 @@ class _SSLPipe(object):
# Drain possible plaintext data after close_notify.
appdata.append(self._incoming.read())
except (ssl.SSLError, ssl.CertificateError) as exc:
if getattr(exc, 'errno', None) not in (
exc_errno = getattr(exc, 'errno', None)
if exc_errno not in (
ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
ssl.SSL_ERROR_SYSCALL):
if self._state == _DO_HANDSHAKE and self._handshake_cb:
self._handshake_cb(exc)
raise
self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)
self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
# Check for record level data that needs to be sent back.
# Happens for the initial handshake and renegotiations.
@ -263,13 +264,14 @@ class _SSLPipe(object):
# It is not allowed to call write() after unwrap() until the
# close_notify is acknowledged. We return the condition to the
# caller as a short write.
exc_errno = getattr(exc, 'errno', None)
if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
exc.errno = ssl.SSL_ERROR_WANT_READ
if exc.errno not in (ssl.SSL_ERROR_WANT_READ,
exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE,
ssl.SSL_ERROR_SYSCALL):
raise
self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)
self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
# See if there's any record level data back for us.
if self._outgoing.pending:
@ -488,6 +490,12 @@ class SSLProtocol(protocols.Protocol):
if self._session_established:
self._session_established = False
self._loop.call_soon(self._app_protocol.connection_lost, exc)
else:
# Most likely an exception occurred while in SSL handshake.
# Just mark the app transport as closed so that its __del__
# doesn't complain.
if self._app_transport is not None:
self._app_transport._closed = True
self._transport = None
self._app_transport = None
self._wakeup_waiter(exc)
@ -515,11 +523,8 @@ class SSLProtocol(protocols.Protocol):
try:
ssldata, appdata = self._sslpipe.feed_ssldata(data)
except ssl.SSLError as e:
if self._loop.get_debug():
logger.warning('%r: SSL error %s (reason %s)',
self, e.errno, e.reason)
self._abort()
except Exception as e:
self._fatal_error(e, 'SSL error in data received')
return
for chunk in ssldata:
@ -602,8 +607,12 @@ class SSLProtocol(protocols.Protocol):
def _check_handshake_timeout(self):
if self._in_handshake is True:
logger.warning("%r stalled during handshake", self)
self._abort()
msg = (
f"SSL handshake is taking longer than "
f"{self._ssl_handshake_timeout} seconds: "
f"aborting the connection"
)
self._fatal_error(ConnectionAbortedError(msg))
def _on_handshake_complete(self, handshake_exc):
self._in_handshake = False
@ -615,21 +624,13 @@ class SSLProtocol(protocols.Protocol):
raise handshake_exc
peercert = sslobj.getpeercert()
except BaseException as exc:
if self._loop.get_debug():
if isinstance(exc, ssl.CertificateError):
logger.warning("%r: SSL handshake failed "
"on verifying the certificate",
self, exc_info=True)
else:
logger.warning("%r: SSL handshake failed",
self, exc_info=True)
self._transport.close()
if isinstance(exc, Exception):
self._wakeup_waiter(exc)
return
except Exception as exc:
if isinstance(exc, ssl.CertificateError):
msg = 'SSL handshake failed on verifying the certificate'
else:
raise
msg = 'SSL handshake failed'
self._fatal_error(exc, msg)
return
if self._loop.get_debug():
dt = self._loop.time() - self._handshake_start_time
@ -686,18 +687,14 @@ class SSLProtocol(protocols.Protocol):
# delete it and reduce the outstanding buffer size.
del self._write_backlog[0]
self._write_buffer_size -= len(data)
except BaseException as exc:
except Exception as exc:
if self._in_handshake:
# BaseExceptions will be re-raised in _on_handshake_complete.
# Exceptions will be re-raised in _on_handshake_complete.
self._on_handshake_complete(exc)
else:
self._fatal_error(exc, 'Fatal error on SSL transport')
if not isinstance(exc, Exception):
# BaseException
raise
def _fatal_error(self, exc, message='Fatal error on transport'):
# Should be called from exception handler only.
if isinstance(exc, base_events._FATAL_ERROR_IGNORE):
if self._loop.get_debug():
logger.debug("%r: %s", self, message, exc_info=True)

View File

@ -53,35 +53,6 @@ class SslProtoHandshakeTests(test_utils.TestCase):
ssl_proto.connection_made(transport)
return transport
def test_cancel_handshake(self):
# Python issue #23197: cancelling a handshake must not raise an
# exception or log an error, even if the handshake failed
waiter = asyncio.Future(loop=self.loop)
ssl_proto = self.ssl_protocol(waiter=waiter)
handshake_fut = asyncio.Future(loop=self.loop)
def do_handshake(callback):
exc = Exception()
callback(exc)
handshake_fut.set_result(None)
return []
waiter.cancel()
self.connection_made(ssl_proto, do_handshake=do_handshake)
with test_utils.disable_logger():
self.loop.run_until_complete(handshake_fut)
def test_handshake_timeout(self):
# bpo-29970: Check that a connection is aborted if handshake is not
# completed in timeout period, instead of remaining open indefinitely
ssl_proto = self.ssl_protocol()
transport = self.connection_made(ssl_proto)
with test_utils.disable_logger():
self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
self.assertTrue(transport.abort.called)
def test_handshake_timeout_zero(self):
sslcontext = test_utils.dummy_ssl_context()
app_proto = mock.Mock()
@ -392,6 +363,67 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
asyncio.wait_for(client(srv.addr),
loop=self.loop, timeout=self.TIMEOUT))
def test_start_tls_slow_client_cancel(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
client_context = test_utils.simple_client_sslcontext()
server_waits_on_handshake = self.loop.create_future()
def serve(sock):
sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
try:
self.loop.call_soon_threadsafe(
server_waits_on_handshake.set_result, None)
data = sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
pass
finally:
sock.close()
class ClientProto(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)
def data_received(self, data):
self.on_data.set_result(data)
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)
on_data = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)
tr.write(HELLO_MSG)
await server_waits_on_handshake
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(
self.loop.start_tls(tr, proto, client_context),
0.5,
loop=self.loop)
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
def test_start_tls_server_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
@ -481,6 +513,156 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
self.loop.run_until_complete(main())
def test_handshake_timeout(self):
# bpo-29970: Check that a connection is aborted if handshake is not
# completed in timeout period, instead of remaining open indefinitely
client_sslctx = test_utils.simple_client_sslcontext()
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
server_side_aborted = False
def server(sock):
nonlocal server_side_aborted
try:
sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
server_side_aborted = True
finally:
sock.close()
async def client(addr):
await asyncio.wait_for(
self.loop.create_connection(
asyncio.Protocol,
*addr,
ssl=client_sslctx,
server_hostname='',
ssl_handshake_timeout=10.0),
0.5,
loop=self.loop)
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
with self.assertRaises(asyncio.TimeoutError):
self.loop.run_until_complete(client(srv.addr))
self.assertTrue(server_side_aborted)
# Python issue #23197: cancelling a handshake must not raise an
# exception or log an error, even if the handshake failed
self.assertEqual(messages, [])
def test_create_connection_ssl_slow_handshake(self):
client_sslctx = test_utils.simple_client_sslcontext()
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
def server(sock):
try:
sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
pass
finally:
sock.close()
async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
loop=self.loop,
ssl_handshake_timeout=1.0)
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
with self.assertRaisesRegex(
ConnectionAbortedError,
r'SSL handshake.*is taking longer'):
self.loop.run_until_complete(client(srv.addr))
self.assertEqual(messages, [])
def test_create_connection_ssl_failed_certificate(self):
self.loop.set_exception_handler(lambda loop, ctx: None)
sslctx = test_utils.simple_server_sslcontext()
client_sslctx = test_utils.simple_client_sslcontext(
disable_verify=False)
def server(sock):
try:
sock.start_tls(
sslctx,
server_side=True)
except ssl.SSLError:
pass
finally:
sock.close()
async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
loop=self.loop,
ssl_handshake_timeout=1.0)
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
with self.assertRaises(ssl.SSLCertVerificationError):
self.loop.run_until_complete(client(srv.addr))
def test_start_tls_client_corrupted_ssl(self):
self.loop.set_exception_handler(lambda loop, ctx: None)
sslctx = test_utils.simple_server_sslcontext()
client_sslctx = test_utils.simple_client_sslcontext()
def server(sock):
orig_sock = sock.dup()
try:
sock.start_tls(
sslctx,
server_side=True)
sock.sendall(b'A\n')
sock.recv_all(1)
orig_sock.send(b'please corrupt the SSL connection')
except ssl.SSLError:
pass
finally:
sock.close()
async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
loop=self.loop)
self.assertEqual(await reader.readline(), b'A\n')
writer.write(b'B')
with self.assertRaises(ssl.SSLError):
await reader.readline()
return 'OK'
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
res = self.loop.run_until_complete(client(srv.addr))
self.assertEqual(res, 'OK')
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):

View File

@ -77,10 +77,11 @@ def simple_server_sslcontext():
return server_context
def simple_client_sslcontext():
def simple_client_sslcontext(*, disable_verify=True):
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.check_hostname = False
client_context.verify_mode = ssl.CERT_NONE
if disable_verify:
client_context.verify_mode = ssl.CERT_NONE
return client_context

View File

@ -0,0 +1 @@
asyncio/ssl: Fix AttributeError, increase default handshake timeout