bpo-35998: Avoid TimeoutError in test_asyncio: test_start_tls_server_1() (GH-14080)

This commit is contained in:
Andrew Svetlov 2019-06-14 18:26:24 +03:00 committed by Victor Stinner
parent 431478d5d7
commit f0749da9a5
2 changed files with 20 additions and 18 deletions

View File

@ -494,17 +494,14 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
def test_start_tls_server_1(self): def test_start_tls_server_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE HELLO_MSG = b'1' * self.PAYLOAD_SIZE
ANSWER = b'answer'
server_context = test_utils.simple_server_sslcontext() server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext() client_context = test_utils.simple_client_sslcontext()
if sys.platform.startswith('freebsd') or sys.platform.startswith('win'): answer = None
# bpo-35031: Some FreeBSD and Windows buildbots fail to run this test
# as the eof was not being received by the server if the payload
# size is not big enough. This behaviour only appears if the
# client is using TLS1.3.
client_context.options |= ssl.OP_NO_TLSv1_3
def client(sock, addr): def client(sock, addr):
nonlocal answer
sock.settimeout(self.TIMEOUT) sock.settimeout(self.TIMEOUT)
sock.connect(addr) sock.connect(addr)
@ -513,33 +510,36 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
sock.start_tls(client_context) sock.start_tls(client_context)
sock.sendall(HELLO_MSG) sock.sendall(HELLO_MSG)
answer = sock.recv_all(len(ANSWER))
sock.shutdown(socket.SHUT_RDWR)
sock.close() sock.close()
class ServerProto(asyncio.Protocol): class ServerProto(asyncio.Protocol):
def __init__(self, on_con, on_eof, on_con_lost): def __init__(self, on_con, on_con_lost):
self.on_con = on_con self.on_con = on_con
self.on_eof = on_eof
self.on_con_lost = on_con_lost self.on_con_lost = on_con_lost
self.data = b'' self.data = b''
self.transport = None
def connection_made(self, tr): def connection_made(self, tr):
self.transport = tr
self.on_con.set_result(tr) self.on_con.set_result(tr)
def replace_transport(self, tr):
self.transport = tr
def data_received(self, data): def data_received(self, data):
self.data += data self.data += data
if len(self.data) >= len(HELLO_MSG):
def eof_received(self): self.transport.write(ANSWER)
self.on_eof.set_result(1)
def connection_lost(self, exc): def connection_lost(self, exc):
self.transport = None
if exc is None: if exc is None:
self.on_con_lost.set_result(None) self.on_con_lost.set_result(None)
else: else:
self.on_con_lost.set_exception(exc) self.on_con_lost.set_exception(exc)
async def main(proto, on_con, on_eof, on_con_lost): async def main(proto, on_con, on_con_lost):
tr = await on_con tr = await on_con
tr.write(HELLO_MSG) tr.write(HELLO_MSG)
@ -550,16 +550,16 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
server_side=True, server_side=True,
ssl_handshake_timeout=self.TIMEOUT) ssl_handshake_timeout=self.TIMEOUT)
await on_eof proto.replace_transport(new_tr)
await on_con_lost await on_con_lost
self.assertEqual(proto.data, HELLO_MSG) self.assertEqual(proto.data, HELLO_MSG)
new_tr.close() new_tr.close()
async def run_main(): async def run_main():
on_con = self.loop.create_future() on_con = self.loop.create_future()
on_eof = self.loop.create_future()
on_con_lost = self.loop.create_future() on_con_lost = self.loop.create_future()
proto = ServerProto(on_con, on_eof, on_con_lost) proto = ServerProto(on_con, on_con_lost)
server = await self.loop.create_server( server = await self.loop.create_server(
lambda: proto, '127.0.0.1', 0) lambda: proto, '127.0.0.1', 0)
@ -568,11 +568,12 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
with self.tcp_client(lambda sock: client(sock, addr), with self.tcp_client(lambda sock: client(sock, addr),
timeout=self.TIMEOUT): timeout=self.TIMEOUT):
await asyncio.wait_for( await asyncio.wait_for(
main(proto, on_con, on_eof, on_con_lost), main(proto, on_con, on_con_lost),
timeout=self.TIMEOUT) timeout=self.TIMEOUT)
server.close() server.close()
await server.wait_closed() await server.wait_closed()
self.assertEqual(answer, ANSWER)
self.loop.run_until_complete(run_main()) self.loop.run_until_complete(run_main())

View File

@ -0,0 +1 @@
Avoid TimeoutError in test_asyncio: test_start_tls_server_1()