mirror of https://github.com/python/cpython
asyncio: Fix SSLProtocol.eof_received()
Wake-up the waiter if it is not done yet.
This commit is contained in:
parent
a89f5f0492
commit
b507cbaac5
|
@ -489,6 +489,10 @@ class SSLProtocol(protocols.Protocol):
|
|||
try:
|
||||
if self._loop.get_debug():
|
||||
logger.debug("%r received EOF", self)
|
||||
|
||||
if self._waiter is not None and not self._waiter.done():
|
||||
self._waiter.set_exception(ConnectionResetError())
|
||||
|
||||
if not self._in_handshake:
|
||||
keep_open = self._app_protocol.eof_received()
|
||||
if keep_open:
|
||||
|
|
|
@ -12,21 +12,36 @@ from asyncio import sslproto
|
|||
from asyncio import test_utils
|
||||
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
class SslProtoHandshakeTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
def ssl_protocol(self, waiter=None):
|
||||
sslcontext = test_utils.dummy_ssl_context()
|
||||
app_proto = asyncio.Protocol()
|
||||
return sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter)
|
||||
|
||||
def connection_made(self, ssl_proto, do_handshake=None):
|
||||
transport = mock.Mock()
|
||||
sslpipe = mock.Mock()
|
||||
sslpipe.shutdown.return_value = b''
|
||||
if do_handshake:
|
||||
sslpipe.do_handshake.side_effect = do_handshake
|
||||
else:
|
||||
def mock_handshake(callback):
|
||||
return []
|
||||
sslpipe.do_handshake.side_effect = mock_handshake
|
||||
with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
|
||||
ssl_proto.connection_made(transport)
|
||||
|
||||
def test_cancel_handshake(self):
|
||||
# Python issue #23197: cancelling an handshake must not raise an
|
||||
# exception or log an error, even if the handshake failed
|
||||
sslcontext = test_utils.dummy_ssl_context()
|
||||
app_proto = asyncio.Protocol()
|
||||
waiter = asyncio.Future(loop=self.loop)
|
||||
ssl_proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext,
|
||||
waiter)
|
||||
ssl_proto = self.ssl_protocol(waiter)
|
||||
handshake_fut = asyncio.Future(loop=self.loop)
|
||||
|
||||
def do_handshake(callback):
|
||||
|
@ -36,12 +51,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
return []
|
||||
|
||||
waiter.cancel()
|
||||
transport = mock.Mock()
|
||||
sslpipe = mock.Mock()
|
||||
sslpipe.shutdown.return_value = b''
|
||||
sslpipe.do_handshake.side_effect = do_handshake
|
||||
with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
|
||||
ssl_proto.connection_made(transport)
|
||||
self.connection_made(ssl_proto, do_handshake)
|
||||
|
||||
with test_utils.disable_logger():
|
||||
self.loop.run_until_complete(handshake_fut)
|
||||
|
@ -49,6 +59,14 @@ class SslProtoHandshakeTests(test_utils.TestCase):
|
|||
# Close the transport
|
||||
ssl_proto._app_transport.close()
|
||||
|
||||
def test_eof_received_waiter(self):
|
||||
waiter = asyncio.Future(loop=self.loop)
|
||||
ssl_proto = self.ssl_protocol(waiter)
|
||||
self.connection_made(ssl_proto)
|
||||
ssl_proto.eof_received()
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertIsInstance(waiter.exception(), ConnectionResetError)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue