diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index c1b9dc95ee6..e484746432a 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1507,10 +1507,14 @@ os.close(fd) def test_stream_server_abort(self): server_stream_aborted = False - fut = self.loop.create_future() + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() async def handle_client(stream): - await fut + data = await stream.readexactly(4) + self.assertEqual(b'data', data) + fut1.set_result(None) + await fut2 self.assertEqual(b'', await stream.readline()) nonlocal server_stream_aborted server_stream_aborted = True @@ -1518,7 +1522,8 @@ os.close(fd) async def client(srv): addr = srv.sockets[0].getsockname() stream = await asyncio.connect(*addr) - fut.set_result(None) + await stream.write(b'data') + await fut2 self.assertEqual(b'', await stream.readline()) await stream.close() @@ -1526,43 +1531,9 @@ os.close(fd) async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: await server.start_serving() task = asyncio.create_task(client(server)) - await fut - await server.abort() - await task - - messages = [] - self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - self.loop.run_until_complete(test()) - self.assertEqual(messages, []) - self.assertTrue(fut.done()) - self.assertTrue(server_stream_aborted) - - def test_stream_shutdown_hung_task(self): - fut1 = self.loop.create_future() - fut2 = self.loop.create_future() - - async def handle_client(stream): - while True: - await asyncio.sleep(0.01) - - async def client(srv): - addr = srv.sockets[0].getsockname() - stream = await asyncio.connect(*addr) - fut1.set_result(None) - await fut2 - self.assertEqual(b'', await stream.readline()) - await stream.close() - - async def test(): - async with asyncio.StreamServer(handle_client, - '127.0.0.1', - 0, - shutdown_timeout=0.3) as server: - await server.start_serving() - task = asyncio.create_task(client(server)) await fut1 - await server.close() fut2.set_result(None) + await server.abort() await task messages = [] @@ -1571,21 +1542,29 @@ os.close(fd) self.assertEqual(messages, []) self.assertTrue(fut1.done()) self.assertTrue(fut2.done()) + self.assertTrue(server_stream_aborted) - def test_stream_shutdown_hung_task_prevents_cancellation(self): + def test_stream_shutdown_hung_task(self): fut1 = self.loop.create_future() fut2 = self.loop.create_future() - do_handle_client = True + cancelled = self.loop.create_future() async def handle_client(stream): - while do_handle_client: - with contextlib.suppress(asyncio.CancelledError): + data = await stream.readexactly(4) + self.assertEqual(b'data', data) + fut1.set_result(None) + await fut2 + try: + while True: await asyncio.sleep(0.01) + except asyncio.CancelledError: + cancelled.set_result(None) + raise async def client(srv): addr = srv.sockets[0].getsockname() stream = await asyncio.connect(*addr) - fut1.set_result(None) + await stream.write(b'data') await fut2 self.assertEqual(b'', await stream.readline()) await stream.close() @@ -1598,11 +1577,57 @@ os.close(fd) await server.start_serving() task = asyncio.create_task(client(server)) await fut1 + fut2.set_result(None) + await server.close() + await task + await cancelled + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut1.done()) + self.assertTrue(fut2.done()) + self.assertTrue(cancelled.done()) + + def test_stream_shutdown_hung_task_prevents_cancellation(self): + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() + cancelled = self.loop.create_future() + do_handle_client = True + + async def handle_client(stream): + data = await stream.readexactly(4) + self.assertEqual(b'data', data) + fut1.set_result(None) + await fut2 + while do_handle_client: + with contextlib.suppress(asyncio.CancelledError): + await asyncio.sleep(0.01) + cancelled.set_result(None) + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + await stream.write(b'data') + await fut2 + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, + '127.0.0.1', + 0, + shutdown_timeout=0.3) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut1 + fut2.set_result(None) await server.close() nonlocal do_handle_client do_handle_client = False - fut2.set_result(None) await task + await cancelled messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1612,6 +1637,7 @@ os.close(fd) "