diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 0c49099bc49..d1f8aef4bb9 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -566,46 +566,10 @@ class StreamTests(test_utils.TestCase): test_utils.run_briefly(self.loop) self.assertIs(stream._waiter, None) - def test_start_server(self): - class MyServer: +class NewStreamTests(unittest.IsolatedAsyncioTestCase): - def __init__(self, loop): - self.server = None - self.loop = loop - - async def handle_client(self, client_reader, client_writer): - data = await client_reader.readline() - client_writer.write(data) - await client_writer.drain() - client_writer.close() - await client_writer.wait_closed() - - def start(self): - sock = socket.create_server(('127.0.0.1', 0)) - self.server = self.loop.run_until_complete( - asyncio.start_server(self.handle_client, - sock=sock)) - return sock.getsockname() - - def handle_client_callback(self, client_reader, client_writer): - self.loop.create_task(self.handle_client(client_reader, - client_writer)) - - def start_callback(self): - sock = socket.create_server(('127.0.0.1', 0)) - addr = sock.getsockname() - sock.close() - self.server = self.loop.run_until_complete( - asyncio.start_server(self.handle_client_callback, - host=addr[0], port=addr[1])) - return addr - - def stop(self): - if self.server is not None: - self.server.close() - self.loop.run_until_complete(self.server.wait_closed()) - self.server = None + async def test_start_server(self): async def client(addr): reader, writer = await asyncio.open_connection(*addr) @@ -617,61 +581,43 @@ class StreamTests(test_utils.TestCase): await writer.wait_closed() return msgback - messages = [] - self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + async def handle_client(client_reader, client_writer): + data = await client_reader.readline() + client_writer.write(data) + await client_writer.drain() + client_writer.close() + await client_writer.wait_closed() - # test the server variant with a coroutine as client handler - server = MyServer(self.loop) - addr = server.start() - msg = self.loop.run_until_complete(self.loop.create_task(client(addr))) - server.stop() - self.assertEqual(msg, b"hello world!\n") + with self.subTest(msg="coroutine"): + server = await asyncio.start_server( + handle_client, + host=socket_helper.HOSTv4 + ) + addr = server.sockets[0].getsockname() + msg = await client(addr) + server.close() + await server.wait_closed() + self.assertEqual(msg, b"hello world!\n") - # test the server variant with a callback as client handler - server = MyServer(self.loop) - addr = server.start_callback() - msg = self.loop.run_until_complete(self.loop.create_task(client(addr))) - server.stop() - self.assertEqual(msg, b"hello world!\n") + with self.subTest(msg="callback"): + async def handle_client_callback(client_reader, client_writer): + asyncio.get_running_loop().create_task( + handle_client(client_reader, client_writer) + ) - self.assertEqual(messages, []) + server = await asyncio.start_server( + handle_client_callback, + host=socket_helper.HOSTv4 + ) + addr = server.sockets[0].getsockname() + reader, writer = await asyncio.open_connection(*addr) + msg = await client(addr) + server.close() + await server.wait_closed() + self.assertEqual(msg, b"hello world!\n") @socket_helper.skip_unless_bind_unix_socket - def test_start_unix_server(self): - - class MyServer: - - def __init__(self, loop, path): - self.server = None - self.loop = loop - self.path = path - - async def handle_client(self, client_reader, client_writer): - data = await client_reader.readline() - client_writer.write(data) - await client_writer.drain() - client_writer.close() - await client_writer.wait_closed() - - def start(self): - self.server = self.loop.run_until_complete( - asyncio.start_unix_server(self.handle_client, - path=self.path)) - - def handle_client_callback(self, client_reader, client_writer): - self.loop.create_task(self.handle_client(client_reader, - client_writer)) - - def start_callback(self): - start = asyncio.start_unix_server(self.handle_client_callback, - path=self.path) - self.server = self.loop.run_until_complete(start) - - def stop(self): - if self.server is not None: - self.server.close() - self.loop.run_until_complete(self.server.wait_closed()) - self.server = None + async def test_start_unix_server(self): async def client(path): reader, writer = await asyncio.open_unix_connection(path) @@ -683,64 +629,42 @@ class StreamTests(test_utils.TestCase): await writer.wait_closed() return msgback - messages = [] - self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + async def handle_client(client_reader, client_writer): + data = await client_reader.readline() + client_writer.write(data) + await client_writer.drain() + client_writer.close() + await client_writer.wait_closed() - # test the server variant with a coroutine as client handler - with test_utils.unix_socket_path() as path: - server = MyServer(self.loop, path) - server.start() - msg = self.loop.run_until_complete( - self.loop.create_task(client(path))) - server.stop() - self.assertEqual(msg, b"hello world!\n") + with self.subTest(msg="coroutine"): + with test_utils.unix_socket_path() as path: + server = await asyncio.start_unix_server( + handle_client, + path=path + ) + msg = await client(path) + server.close() + await server.wait_closed() + self.assertEqual(msg, b"hello world!\n") - # test the server variant with a callback as client handler - with test_utils.unix_socket_path() as path: - server = MyServer(self.loop, path) - server.start_callback() - msg = self.loop.run_until_complete( - self.loop.create_task(client(path))) - server.stop() - self.assertEqual(msg, b"hello world!\n") + with self.subTest(msg="callback"): + async def handle_client_callback(client_reader, client_writer): + asyncio.get_running_loop().create_task( + handle_client(client_reader, client_writer) + ) - self.assertEqual(messages, []) + with test_utils.unix_socket_path() as path: + server = await asyncio.start_unix_server( + handle_client_callback, + path=path + ) + msg = await client(path) + server.close() + await server.wait_closed() + self.assertEqual(msg, b"hello world!\n") @unittest.skipIf(ssl is None, 'No ssl module') - def test_start_tls(self): - - class MyServer: - - def __init__(self, loop): - self.server = None - self.loop = loop - - async def handle_client(self, client_reader, client_writer): - data1 = await client_reader.readline() - client_writer.write(data1) - await client_writer.drain() - assert client_writer.get_extra_info('sslcontext') is None - await client_writer.start_tls( - test_utils.simple_server_sslcontext()) - assert client_writer.get_extra_info('sslcontext') is not None - data2 = await client_reader.readline() - client_writer.write(data2) - await client_writer.drain() - client_writer.close() - await client_writer.wait_closed() - - def start(self): - sock = socket.create_server(('127.0.0.1', 0)) - self.server = self.loop.run_until_complete( - asyncio.start_server(self.handle_client, - sock=sock)) - return sock.getsockname() - - def stop(self): - if self.server is not None: - self.server.close() - self.loop.run_until_complete(self.server.wait_closed()) - self.server = None + async def test_start_tls(self): async def client(addr): reader, writer = await asyncio.open_connection(*addr) @@ -757,18 +681,49 @@ class StreamTests(test_utils.TestCase): await writer.wait_closed() return msgback1, msgback2 - messages = [] - self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + async def handle_client(client_reader, client_writer): + data1 = await client_reader.readline() + client_writer.write(data1) + await client_writer.drain() + assert client_writer.get_extra_info('sslcontext') is None + await client_writer.start_tls( + test_utils.simple_server_sslcontext()) + assert client_writer.get_extra_info('sslcontext') is not None - server = MyServer(self.loop) - addr = server.start() - msg1, msg2 = self.loop.run_until_complete(client(addr)) - server.stop() + data2 = await client_reader.readline() + client_writer.write(data2) + await client_writer.drain() + client_writer.close() + await client_writer.wait_closed() - self.assertEqual(messages, []) + server = await asyncio.start_server( + handle_client, + host=socket_helper.HOSTv4 + ) + addr = server.sockets[0].getsockname() + + msg1, msg2 = await client(addr) + server.close() + await server.wait_closed() self.assertEqual(msg1, b"hello world 1!\n") self.assertEqual(msg2, b"hello world 2!\n") + +class StreamTests2(test_utils.TestCase): + + def setUp(self): + super().setUp() + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example @@ -986,22 +941,20 @@ os.close(fd) self.assertEqual(str(e), str(e2)) self.assertEqual(e.consumed, e2.consumed) - def test_wait_closed_on_close(self): - with test_utils.run_test_server() as httpd: + async def test_wait_closed_on_close(self): + async with test_utils.run_test_server() as httpd: rd, wr = self.loop.run_until_complete( asyncio.open_connection(*httpd.address)) wr.write(b'GET / HTTP/1.0\r\n\r\n') - f = rd.readline() - data = self.loop.run_until_complete(f) + data = await rd.readline() self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - f = rd.read() - data = self.loop.run_until_complete(f) + await rd.read() self.assertTrue(data.endswith(b'\r\n\r\nTest message')) self.assertFalse(wr.is_closing()) wr.close() self.assertTrue(wr.is_closing()) - self.loop.run_until_complete(wr.wait_closed()) + await wr.wait_closed() def test_wait_closed_on_close_with_unread_data(self): with test_utils.run_test_server() as httpd: @@ -1057,15 +1010,10 @@ os.close(fd) self.assertEqual(messages, []) - def test_eof_feed_when_closing_writer(self): + async def test_eof_feed_when_closing_writer(self): # See http://bugs.python.org/issue35065 - messages = [] - self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - - with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address)) - + async with test_utils.run_test_server() as httpd: + rd, wr = await asyncio.open_connection(*httpd.address) wr.close() f = wr.wait_closed() self.loop.run_until_complete(f) @@ -1074,8 +1022,6 @@ os.close(fd) data = self.loop.run_until_complete(f) self.assertEqual(data, b'') - self.assertEqual(messages, []) - if __name__ == '__main__': unittest.main()