bpo-31632: fix set_protocol() in _SSLProtocolTransport (#3817) (#3817)

This commit is contained in:
jlacoline 2017-10-19 19:49:57 +02:00 committed by Yury Selivanov
parent ce9e625445
commit ea2ef5d0ca
4 changed files with 15 additions and 6 deletions

View File

@ -293,11 +293,10 @@ class _SSLPipe(object):
class _SSLProtocolTransport(transports._FlowControlMixin, class _SSLProtocolTransport(transports._FlowControlMixin,
transports.Transport): transports.Transport):
def __init__(self, loop, ssl_protocol, app_protocol): def __init__(self, loop, ssl_protocol):
self._loop = loop self._loop = loop
# SSLProtocol instance # SSLProtocol instance
self._ssl_protocol = ssl_protocol self._ssl_protocol = ssl_protocol
self._app_protocol = app_protocol
self._closed = False self._closed = False
def get_extra_info(self, name, default=None): def get_extra_info(self, name, default=None):
@ -305,10 +304,10 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
return self._ssl_protocol._get_extra_info(name, default) return self._ssl_protocol._get_extra_info(name, default)
def set_protocol(self, protocol): def set_protocol(self, protocol):
self._app_protocol = protocol self._ssl_protocol._app_protocol = protocol
def get_protocol(self): def get_protocol(self):
return self._app_protocol return self._ssl_protocol._app_protocol
def is_closing(self): def is_closing(self):
return self._closed return self._closed
@ -431,8 +430,7 @@ class SSLProtocol(protocols.Protocol):
self._waiter = waiter self._waiter = waiter
self._loop = loop self._loop = loop
self._app_protocol = app_protocol self._app_protocol = app_protocol
self._app_transport = _SSLProtocolTransport(self._loop, self._app_transport = _SSLProtocolTransport(self._loop, self)
self, self._app_protocol)
# _SSLPipe instance (None until the connection is made) # _SSLPipe instance (None until the connection is made)
self._sslpipe = None self._sslpipe = None
self._session_established = False self._session_established = False

View File

@ -121,6 +121,14 @@ class SslProtoHandshakeTests(test_utils.TestCase):
ssl_proto.connection_lost(None) ssl_proto.connection_lost(None)
self.assertIsNone(ssl_proto._get_extra_info('socket')) self.assertIsNone(ssl_proto._get_extra_info('socket'))
def test_set_new_app_protocol(self):
waiter = asyncio.Future(loop=self.loop)
ssl_proto = self.ssl_protocol(waiter)
new_app_proto = asyncio.Protocol()
ssl_proto._app_transport.set_protocol(new_app_proto)
self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
self.assertIs(ssl_proto._app_protocol, new_app_proto)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -858,6 +858,7 @@ Vladimir Kushnir
Erno Kuusela Erno Kuusela
Ross Lagerwall Ross Lagerwall
Cameron Laird Cameron Laird
Loïc Lajeanne
David Lam David Lam
Thomas Lamb Thomas Lamb
Valerie Lambert Valerie Lambert

View File

@ -0,0 +1,2 @@
Fix method set_protocol() of class _SSLProtocolTransport in asyncio module.
This method was previously modifying a wrong reference to the protocol.