bpo-35998: Avoid TimeoutError in test_asyncio: test_start_tls_server_1() (GH-14080)
This commit is contained in:
parent
431478d5d7
commit
f0749da9a5
|
@ -494,17 +494,14 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
||||||
|
|
||||||
def test_start_tls_server_1(self):
|
def test_start_tls_server_1(self):
|
||||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||||
|
ANSWER = b'answer'
|
||||||
|
|
||||||
server_context = test_utils.simple_server_sslcontext()
|
server_context = test_utils.simple_server_sslcontext()
|
||||||
client_context = test_utils.simple_client_sslcontext()
|
client_context = test_utils.simple_client_sslcontext()
|
||||||
if sys.platform.startswith('freebsd') or sys.platform.startswith('win'):
|
answer = None
|
||||||
# bpo-35031: Some FreeBSD and Windows buildbots fail to run this test
|
|
||||||
# as the eof was not being received by the server if the payload
|
|
||||||
# size is not big enough. This behaviour only appears if the
|
|
||||||
# client is using TLS1.3.
|
|
||||||
client_context.options |= ssl.OP_NO_TLSv1_3
|
|
||||||
|
|
||||||
def client(sock, addr):
|
def client(sock, addr):
|
||||||
|
nonlocal answer
|
||||||
sock.settimeout(self.TIMEOUT)
|
sock.settimeout(self.TIMEOUT)
|
||||||
|
|
||||||
sock.connect(addr)
|
sock.connect(addr)
|
||||||
|
@ -513,33 +510,36 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
||||||
|
|
||||||
sock.start_tls(client_context)
|
sock.start_tls(client_context)
|
||||||
sock.sendall(HELLO_MSG)
|
sock.sendall(HELLO_MSG)
|
||||||
|
answer = sock.recv_all(len(ANSWER))
|
||||||
sock.shutdown(socket.SHUT_RDWR)
|
|
||||||
sock.close()
|
sock.close()
|
||||||
|
|
||||||
class ServerProto(asyncio.Protocol):
|
class ServerProto(asyncio.Protocol):
|
||||||
def __init__(self, on_con, on_eof, on_con_lost):
|
def __init__(self, on_con, on_con_lost):
|
||||||
self.on_con = on_con
|
self.on_con = on_con
|
||||||
self.on_eof = on_eof
|
|
||||||
self.on_con_lost = on_con_lost
|
self.on_con_lost = on_con_lost
|
||||||
self.data = b''
|
self.data = b''
|
||||||
|
self.transport = None
|
||||||
|
|
||||||
def connection_made(self, tr):
|
def connection_made(self, tr):
|
||||||
|
self.transport = tr
|
||||||
self.on_con.set_result(tr)
|
self.on_con.set_result(tr)
|
||||||
|
|
||||||
|
def replace_transport(self, tr):
|
||||||
|
self.transport = tr
|
||||||
|
|
||||||
def data_received(self, data):
|
def data_received(self, data):
|
||||||
self.data += data
|
self.data += data
|
||||||
|
if len(self.data) >= len(HELLO_MSG):
|
||||||
def eof_received(self):
|
self.transport.write(ANSWER)
|
||||||
self.on_eof.set_result(1)
|
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
|
self.transport = None
|
||||||
if exc is None:
|
if exc is None:
|
||||||
self.on_con_lost.set_result(None)
|
self.on_con_lost.set_result(None)
|
||||||
else:
|
else:
|
||||||
self.on_con_lost.set_exception(exc)
|
self.on_con_lost.set_exception(exc)
|
||||||
|
|
||||||
async def main(proto, on_con, on_eof, on_con_lost):
|
async def main(proto, on_con, on_con_lost):
|
||||||
tr = await on_con
|
tr = await on_con
|
||||||
tr.write(HELLO_MSG)
|
tr.write(HELLO_MSG)
|
||||||
|
|
||||||
|
@ -550,16 +550,16 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
||||||
server_side=True,
|
server_side=True,
|
||||||
ssl_handshake_timeout=self.TIMEOUT)
|
ssl_handshake_timeout=self.TIMEOUT)
|
||||||
|
|
||||||
await on_eof
|
proto.replace_transport(new_tr)
|
||||||
|
|
||||||
await on_con_lost
|
await on_con_lost
|
||||||
self.assertEqual(proto.data, HELLO_MSG)
|
self.assertEqual(proto.data, HELLO_MSG)
|
||||||
new_tr.close()
|
new_tr.close()
|
||||||
|
|
||||||
async def run_main():
|
async def run_main():
|
||||||
on_con = self.loop.create_future()
|
on_con = self.loop.create_future()
|
||||||
on_eof = self.loop.create_future()
|
|
||||||
on_con_lost = self.loop.create_future()
|
on_con_lost = self.loop.create_future()
|
||||||
proto = ServerProto(on_con, on_eof, on_con_lost)
|
proto = ServerProto(on_con, on_con_lost)
|
||||||
|
|
||||||
server = await self.loop.create_server(
|
server = await self.loop.create_server(
|
||||||
lambda: proto, '127.0.0.1', 0)
|
lambda: proto, '127.0.0.1', 0)
|
||||||
|
@ -568,11 +568,12 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
||||||
with self.tcp_client(lambda sock: client(sock, addr),
|
with self.tcp_client(lambda sock: client(sock, addr),
|
||||||
timeout=self.TIMEOUT):
|
timeout=self.TIMEOUT):
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
main(proto, on_con, on_eof, on_con_lost),
|
main(proto, on_con, on_con_lost),
|
||||||
timeout=self.TIMEOUT)
|
timeout=self.TIMEOUT)
|
||||||
|
|
||||||
server.close()
|
server.close()
|
||||||
await server.wait_closed()
|
await server.wait_closed()
|
||||||
|
self.assertEqual(answer, ANSWER)
|
||||||
|
|
||||||
self.loop.run_until_complete(run_main())
|
self.loop.run_until_complete(run_main())
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Avoid TimeoutError in test_asyncio: test_start_tls_server_1()
|
Loading…
Reference in New Issue