bpo-33654: Support protocol type switching in SSLTransport.set_protocol() (GH-7194)

(cherry picked from commit 2179022d94)

Co-authored-by: Yury Selivanov <yury@magic.io>
This commit is contained in:
Miss Islington (bot) 2018-05-29 02:19:09 -07:00 committed by GitHub
parent 2641ee5040
commit be5d616e55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 15 deletions

View File

@ -295,7 +295,7 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
return self._ssl_protocol._get_extra_info(name, default) return self._ssl_protocol._get_extra_info(name, default)
def set_protocol(self, protocol): def set_protocol(self, protocol):
self._ssl_protocol._app_protocol = protocol self._ssl_protocol._set_app_protocol(protocol)
def get_protocol(self): def get_protocol(self):
return self._ssl_protocol._app_protocol return self._ssl_protocol._app_protocol
@ -440,9 +440,7 @@ class SSLProtocol(protocols.Protocol):
self._waiter = waiter self._waiter = waiter
self._loop = loop self._loop = loop
self._app_protocol = app_protocol self._set_app_protocol(app_protocol)
self._app_protocol_is_buffer = \
isinstance(app_protocol, protocols.BufferedProtocol)
self._app_transport = _SSLProtocolTransport(self._loop, self) self._app_transport = _SSLProtocolTransport(self._loop, self)
# _SSLPipe instance (None until the connection is made) # _SSLPipe instance (None until the connection is made)
self._sslpipe = None self._sslpipe = None
@ -454,6 +452,11 @@ class SSLProtocol(protocols.Protocol):
self._call_connection_made = call_connection_made self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout self._ssl_handshake_timeout = ssl_handshake_timeout
def _set_app_protocol(self, app_protocol):
self._app_protocol = app_protocol
self._app_protocol_is_buffer = \
isinstance(app_protocol, protocols.BufferedProtocol)
def _wakeup_waiter(self, exc=None): def _wakeup_waiter(self, exc=None):
if self._waiter is None: if self._waiter is None:
return return

View File

@ -302,6 +302,7 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
server_context = test_utils.simple_server_sslcontext() server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext() client_context = test_utils.simple_client_sslcontext()
client_con_made_calls = 0
def serve(sock): def serve(sock):
sock.settimeout(self.TIMEOUT) sock.settimeout(self.TIMEOUT)
@ -315,20 +316,21 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
data = sock.recv_all(len(HELLO_MSG)) data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG))
sock.sendall(b'2')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.shutdown(socket.SHUT_RDWR) sock.shutdown(socket.SHUT_RDWR)
sock.close() sock.close()
class ClientProto(asyncio.BufferedProtocol): class ClientProtoFirst(asyncio.BufferedProtocol):
def __init__(self, on_data, on_eof): def __init__(self, on_data):
self.on_data = on_data self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
self.buf = bytearray(1) self.buf = bytearray(1)
def connection_made(proto, tr): def connection_made(self, tr):
proto.con_made_cnt += 1 nonlocal client_con_made_calls
# Ensure connection_made gets called only once. client_con_made_calls += 1
self.assertEqual(proto.con_made_cnt, 1)
def get_buffer(self, sizehint): def get_buffer(self, sizehint):
return self.buf return self.buf
@ -337,27 +339,50 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
assert nsize == 1 assert nsize == 1
self.on_data.set_result(bytes(self.buf[:nsize])) self.on_data.set_result(bytes(self.buf[:nsize]))
class ClientProtoSecond(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
def connection_made(self, tr):
nonlocal client_con_made_calls
client_con_made_calls += 1
def data_received(self, data):
self.on_data.set_result(data)
def eof_received(self): def eof_received(self):
self.on_eof.set_result(True) self.on_eof.set_result(True)
async def client(addr): async def client(addr):
await asyncio.sleep(0.5, loop=self.loop) await asyncio.sleep(0.5, loop=self.loop)
on_data = self.loop.create_future() on_data1 = self.loop.create_future()
on_data2 = self.loop.create_future()
on_eof = self.loop.create_future() on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection( tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr) lambda: ClientProtoFirst(on_data1), *addr)
tr.write(HELLO_MSG) tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context) new_tr = await self.loop.start_tls(tr, proto, client_context)
self.assertEqual(await on_data, b'O') self.assertEqual(await on_data1, b'O')
new_tr.write(HELLO_MSG)
new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
self.assertEqual(await on_data2, b'2')
new_tr.write(HELLO_MSG) new_tr.write(HELLO_MSG)
await on_eof await on_eof
new_tr.close() new_tr.close()
# connection_made() should be called only once -- when
# we establish connection for the first time. Start TLS
# doesn't call connection_made() on application protocols.
self.assertEqual(client_con_made_calls, 1)
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete( self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr), asyncio.wait_for(client(srv.addr),

View File

@ -0,0 +1 @@
Support protocol type switching in SSLTransport.set_protocol().