bpo-33037: Skip sending/receiving after SSL transport closing (GH-6044) (GH-6057)
* Skip write()/data_received() if sslpipe is destroyed
(cherry picked from commit 5e80a71ab6
)
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
This commit is contained in:
parent
89090789de
commit
bf0d116517
|
@ -504,6 +504,10 @@ class SSLProtocol(protocols.Protocol):
|
||||||
|
|
||||||
The argument is a bytes object.
|
The argument is a bytes object.
|
||||||
"""
|
"""
|
||||||
|
if self._sslpipe is None:
|
||||||
|
# transport closing, sslpipe is destroyed
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ssldata, appdata = self._sslpipe.feed_ssldata(data)
|
ssldata, appdata = self._sslpipe.feed_ssldata(data)
|
||||||
except ssl.SSLError as e:
|
except ssl.SSLError as e:
|
||||||
|
@ -636,7 +640,7 @@ class SSLProtocol(protocols.Protocol):
|
||||||
|
|
||||||
def _process_write_backlog(self):
|
def _process_write_backlog(self):
|
||||||
# Try to make progress on the write backlog.
|
# Try to make progress on the write backlog.
|
||||||
if self._transport is None:
|
if self._transport is None or self._sslpipe is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -26,16 +26,17 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
self.loop = asyncio.new_event_loop()
|
self.loop = asyncio.new_event_loop()
|
||||||
self.set_event_loop(self.loop)
|
self.set_event_loop(self.loop)
|
||||||
|
|
||||||
def ssl_protocol(self, waiter=None):
|
def ssl_protocol(self, *, waiter=None, proto=None):
|
||||||
sslcontext = test_utils.dummy_ssl_context()
|
sslcontext = test_utils.dummy_ssl_context()
|
||||||
app_proto = asyncio.Protocol()
|
if proto is None: # app protocol
|
||||||
proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
|
proto = asyncio.Protocol()
|
||||||
ssl_handshake_timeout=0.1)
|
ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
|
||||||
self.assertIs(proto._app_transport.get_protocol(), app_proto)
|
ssl_handshake_timeout=0.1)
|
||||||
self.addCleanup(proto._app_transport.close)
|
self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
|
||||||
return proto
|
self.addCleanup(ssl_proto._app_transport.close)
|
||||||
|
return ssl_proto
|
||||||
|
|
||||||
def connection_made(self, ssl_proto, do_handshake=None):
|
def connection_made(self, ssl_proto, *, do_handshake=None):
|
||||||
transport = mock.Mock()
|
transport = mock.Mock()
|
||||||
sslpipe = mock.Mock()
|
sslpipe = mock.Mock()
|
||||||
sslpipe.shutdown.return_value = b''
|
sslpipe.shutdown.return_value = b''
|
||||||
|
@ -53,7 +54,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
# Python issue #23197: cancelling a handshake must not raise an
|
# Python issue #23197: cancelling a handshake must not raise an
|
||||||
# exception or log an error, even if the handshake failed
|
# exception or log an error, even if the handshake failed
|
||||||
waiter = asyncio.Future(loop=self.loop)
|
waiter = asyncio.Future(loop=self.loop)
|
||||||
ssl_proto = self.ssl_protocol(waiter)
|
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||||
handshake_fut = asyncio.Future(loop=self.loop)
|
handshake_fut = asyncio.Future(loop=self.loop)
|
||||||
|
|
||||||
def do_handshake(callback):
|
def do_handshake(callback):
|
||||||
|
@ -63,7 +64,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
waiter.cancel()
|
waiter.cancel()
|
||||||
self.connection_made(ssl_proto, do_handshake)
|
self.connection_made(ssl_proto, do_handshake=do_handshake)
|
||||||
|
|
||||||
with test_utils.disable_logger():
|
with test_utils.disable_logger():
|
||||||
self.loop.run_until_complete(handshake_fut)
|
self.loop.run_until_complete(handshake_fut)
|
||||||
|
@ -96,7 +97,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
|
|
||||||
def test_eof_received_waiter(self):
|
def test_eof_received_waiter(self):
|
||||||
waiter = asyncio.Future(loop=self.loop)
|
waiter = asyncio.Future(loop=self.loop)
|
||||||
ssl_proto = self.ssl_protocol(waiter)
|
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||||
self.connection_made(ssl_proto)
|
self.connection_made(ssl_proto)
|
||||||
ssl_proto.eof_received()
|
ssl_proto.eof_received()
|
||||||
test_utils.run_briefly(self.loop)
|
test_utils.run_briefly(self.loop)
|
||||||
|
@ -107,7 +108,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
# _fatal_error() generates a NameError if sslproto.py
|
# _fatal_error() generates a NameError if sslproto.py
|
||||||
# does not import base_events.
|
# does not import base_events.
|
||||||
waiter = asyncio.Future(loop=self.loop)
|
waiter = asyncio.Future(loop=self.loop)
|
||||||
ssl_proto = self.ssl_protocol(waiter)
|
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||||
# Temporarily turn off error logging so as not to spoil test output.
|
# Temporarily turn off error logging so as not to spoil test output.
|
||||||
log_level = log.logger.getEffectiveLevel()
|
log_level = log.logger.getEffectiveLevel()
|
||||||
log.logger.setLevel(logging.FATAL)
|
log.logger.setLevel(logging.FATAL)
|
||||||
|
@ -121,7 +122,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
# From issue #472.
|
# From issue #472.
|
||||||
# yield from waiter hang if lost_connection was called.
|
# yield from waiter hang if lost_connection was called.
|
||||||
waiter = asyncio.Future(loop=self.loop)
|
waiter = asyncio.Future(loop=self.loop)
|
||||||
ssl_proto = self.ssl_protocol(waiter)
|
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||||
self.connection_made(ssl_proto)
|
self.connection_made(ssl_proto)
|
||||||
ssl_proto.connection_lost(ConnectionAbortedError)
|
ssl_proto.connection_lost(ConnectionAbortedError)
|
||||||
test_utils.run_briefly(self.loop)
|
test_utils.run_briefly(self.loop)
|
||||||
|
@ -130,10 +131,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
def test_close_during_handshake(self):
|
def test_close_during_handshake(self):
|
||||||
# bpo-29743 Closing transport during handshake process leaks socket
|
# bpo-29743 Closing transport during handshake process leaks socket
|
||||||
waiter = asyncio.Future(loop=self.loop)
|
waiter = asyncio.Future(loop=self.loop)
|
||||||
ssl_proto = self.ssl_protocol(waiter)
|
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||||
|
|
||||||
def do_handshake(callback):
|
|
||||||
return []
|
|
||||||
|
|
||||||
transport = self.connection_made(ssl_proto)
|
transport = self.connection_made(ssl_proto)
|
||||||
test_utils.run_briefly(self.loop)
|
test_utils.run_briefly(self.loop)
|
||||||
|
@ -143,7 +141,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
|
|
||||||
def test_get_extra_info_on_closed_connection(self):
|
def test_get_extra_info_on_closed_connection(self):
|
||||||
waiter = asyncio.Future(loop=self.loop)
|
waiter = asyncio.Future(loop=self.loop)
|
||||||
ssl_proto = self.ssl_protocol(waiter)
|
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||||
self.assertIsNone(ssl_proto._get_extra_info('socket'))
|
self.assertIsNone(ssl_proto._get_extra_info('socket'))
|
||||||
default = object()
|
default = object()
|
||||||
self.assertIs(ssl_proto._get_extra_info('socket', default), default)
|
self.assertIs(ssl_proto._get_extra_info('socket', default), default)
|
||||||
|
@ -154,12 +152,31 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
||||||
|
|
||||||
def test_set_new_app_protocol(self):
|
def test_set_new_app_protocol(self):
|
||||||
waiter = asyncio.Future(loop=self.loop)
|
waiter = asyncio.Future(loop=self.loop)
|
||||||
ssl_proto = self.ssl_protocol(waiter)
|
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||||
new_app_proto = asyncio.Protocol()
|
new_app_proto = asyncio.Protocol()
|
||||||
ssl_proto._app_transport.set_protocol(new_app_proto)
|
ssl_proto._app_transport.set_protocol(new_app_proto)
|
||||||
self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
|
self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
|
||||||
self.assertIs(ssl_proto._app_protocol, new_app_proto)
|
self.assertIs(ssl_proto._app_protocol, new_app_proto)
|
||||||
|
|
||||||
|
def test_data_received_after_closing(self):
|
||||||
|
ssl_proto = self.ssl_protocol()
|
||||||
|
self.connection_made(ssl_proto)
|
||||||
|
transp = ssl_proto._app_transport
|
||||||
|
|
||||||
|
transp.close()
|
||||||
|
|
||||||
|
# should not raise
|
||||||
|
self.assertIsNone(ssl_proto.data_received(b'data'))
|
||||||
|
|
||||||
|
def test_write_after_closing(self):
|
||||||
|
ssl_proto = self.ssl_protocol()
|
||||||
|
self.connection_made(ssl_proto)
|
||||||
|
transp = ssl_proto._app_transport
|
||||||
|
transp.close()
|
||||||
|
|
||||||
|
# should not raise
|
||||||
|
self.assertIsNone(transp.write(b'data'))
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# Start TLS Tests
|
# Start TLS Tests
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Skip sending/receiving data after SSL transport closing.
|
Loading…
Reference in New Issue