import asyncio
import os
import socket
import time
import threading
import unittest

from test.support import socket_helper
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests


def tearDownModule():
    # Passing None restores the default event loop policy, so loops/policies
    # installed while running this module do not leak into other modules.
    asyncio.set_event_loop_policy(None)


# Shared TCP start_server() scenario.  Concrete subclasses select the event
# loop implementation by overriding new_loop(); self.loop and
# self.tcp_client() are expected to come from the functional mixin (defined
# outside this file).
class BaseStartServer(func_tests.FunctionalTestCaseMixin):

    def new_loop(self):
        raise NotImplementedError

    def test_start_server_1(self):
        # A server created with start_serving=False is started by
        # serve_forever(); the connection handler cancels the main task and
        # the server must be fully closed once "async with srv" is left.
        HELLO_MSG = b'1' * 1024 * 5 + b'\n'

        def client(sock, addr):
            # Poll (at most 10 times, 0.2 s apart) until the server reports
            # that it is serving; give up with RuntimeError otherwise.
            for _ in range(10):
                time.sleep(0.2)
                if srv.is_serving():
                    break
            else:
                raise RuntimeError

            sock.settimeout(2)
            sock.connect(addr)
            sock.send(HELLO_MSG)
            # NOTE(review): recv_all() is not a plain socket method; it is
            # presumably provided by the wrapper tcp_client() passes in.
            sock.recv_all(1)
            sock.close()

        async def serve(reader, writer):
            await reader.readline()
            # Cancelling the main task is what ends serve_forever() below.
            main_task.cancel()
            writer.write(b'1')
            writer.close()
            await writer.wait_closed()

        async def main(srv):
            async with srv:
                await srv.serve_forever()

        srv = self.loop.run_until_complete(asyncio.start_server(
            serve, socket_helper.HOSTv4, 0, start_serving=False))

        self.assertFalse(srv.is_serving())

        main_task = self.loop.create_task(main(srv))

        addr = srv.sockets[0].getsockname()
        with self.assertRaises(asyncio.CancelledError):
            with self.tcp_client(lambda sock: client(sock, addr)):
                self.loop.run_until_complete(main_task)

        self.assertEqual(srv.sockets, ())

        self.assertIsNone(srv._sockets)
        self.assertIsNone(srv._waiters)
        self.assertFalse(srv.is_serving())

        # A closed server cannot be restarted.
        with self.assertRaisesRegex(RuntimeError, r'is closed'):
            self.loop.run_until_complete(srv.serve_forever())


class SelectorStartServerTests(BaseStartServer, unittest.TestCase):

    def new_loop(self):
        return asyncio.SelectorEventLoop()

    @socket_helper.skip_unless_bind_unix_socket
    def test_start_unix_server_1(self):
        # Unix-socket variant of test_start_server_1.  Here the server is
        # started explicitly with start_serving() and the client is released
        # through a threading.Event instead of polling is_serving().
        HELLO_MSG = b'1' * 1024 * 5 + b'\n'
        started = threading.Event()

        def client(sock, addr):
            sock.settimeout(2)
            # Wait (up to 5 s) for main() to signal that the server is up.
            started.wait(5)
            sock.connect(addr)
            sock.send(HELLO_MSG)
            sock.recv_all(1)
            sock.close()

        async def serve(reader, writer):
            await reader.readline()
            # Cancelling the main task is what ends serve_forever() below.
            main_task.cancel()
            writer.write(b'1')
            writer.close()
            await writer.wait_closed()

        async def main(srv):
            async with srv:
                self.assertFalse(srv.is_serving())
                await srv.start_serving()
                self.assertTrue(srv.is_serving())
                started.set()
                await srv.serve_forever()

        with test_utils.unix_socket_path() as addr:
            srv = self.loop.run_until_complete(asyncio.start_unix_server(
                serve, addr, start_serving=False))

            main_task = self.loop.create_task(main(srv))

            with self.assertRaises(asyncio.CancelledError):
                with self.unix_client(lambda sock: client(sock, addr)):
                    self.loop.run_until_complete(main_task)

            self.assertEqual(srv.sockets, ())

            self.assertIsNone(srv._sockets)
            self.assertIsNone(srv._waiters)
            self.assertFalse(srv.is_serving())

            # A closed server cannot be restarted.
            with self.assertRaisesRegex(RuntimeError, r'is closed'):
                self.loop.run_until_complete(srv.serve_forever())


# Tests for Server.wait_closed(), Server.close_clients() and
# Server.abort_clients(), run inside the loop provided by
# IsolatedAsyncioTestCase.
class TestServer2(unittest.IsolatedAsyncioTestCase):

    async def test_wait_closed_basic(self):
        # wait_closed() must block until the server is closed *and* no
        # connection is active any more.
        async def serve(rd, wr):
            try:
                await rd.read()
            finally:
                wr.close()
                await wr.wait_closed()

        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
        self.addCleanup(srv.close)

        # active count = 0, not closed: should block
        task1 = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task1.done())

        # active count != 0, not closed: should block
        addr = srv.sockets[0].getsockname()
        _rd, wr = await asyncio.open_connection(addr[0], addr[1])
        task2 = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task1.done())
        self.assertFalse(task2.done())

        srv.close()
        await asyncio.sleep(0)
        # active count != 0, closed: should block
        task3 = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task1.done())
        self.assertFalse(task2.done())
        self.assertFalse(task3.done())

        wr.close()
        await wr.wait_closed()
        # active count == 0, closed: should unblock
        await task1
        await task2
        await task3
        await srv.wait_closed()  # Return immediately

    async def test_wait_closed_race(self):
        # Test a regression in 3.12.0, should be fixed in 3.12.1
        async def serve(rd, wr):
            try:
                await rd.read()
            finally:
                wr.close()
                await wr.wait_closed()

        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
        self.addCleanup(srv.close)

        task = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task.done())
        addr = srv.sockets[0].getsockname()
        _rd, wr = await asyncio.open_connection(addr[0], addr[1])
        loop = asyncio.get_running_loop()
        # Schedule both closes back to back so that they run while this
        # coroutine is already waiting in wait_closed().
        loop.call_soon(srv.close)
        loop.call_soon(wr.close)
        await srv.wait_closed()

    async def test_close_clients(self):
        # close() followed by close_clients() must let wait_closed() finish
        # even though the client never closed its side first.
        async def serve(rd, wr):
            try:
                await rd.read()
            finally:
                wr.close()
                await wr.wait_closed()

        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
        self.addCleanup(srv.close)

        addr = srv.sockets[0].getsockname()
        _rd, wr = await asyncio.open_connection(addr[0], addr[1])
        self.addCleanup(wr.close)

        task = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task.done())

        srv.close()
        srv.close_clients()
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        self.assertTrue(task.done())

    async def test_abort_clients(self):
        # close() followed by abort_clients() must let wait_closed() finish
        # even while the server side still has unsent data buffered.
        async def serve(rd, wr):
            # Hand the server-side streams over to the test body.
            fut.set_result((rd, wr))
            await wr.wait_closed()

        fut = asyncio.Future()
        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
        self.addCleanup(srv.close)

        addr = srv.sockets[0].getsockname()
        _c_rd, c_wr = await asyncio.open_connection(
            addr[0], addr[1], limit=4096)
        self.addCleanup(c_wr.close)

        _s_rd, s_wr = await fut

        # Limit the socket buffers so we can more reliably overfill them
        s_sock = s_wr.get_extra_info('socket')
        s_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
        c_sock = c_wr.get_extra_info('socket')
        c_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)

        # Get the reader into a paused state by sending more than twice
        # the configured limit
        s_wr.write(b'a' * 4096)
        s_wr.write(b'a' * 4096)
        s_wr.write(b'a' * 4096)
        while c_wr.transport.is_reading():
            await asyncio.sleep(0)

        # Get the writer in a waiting state by sending data until the
        # kernel stops accepting more data in the send buffer.
        # gh-122136: getsockopt() does not reliably report the buffer size
        # available for message content.
        # We loop until we start filling up the asyncio buffer.
        # To avoid an infinite loop we cap at 10 times the expected value
        c_bufsize = c_sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF)
        s_bufsize = s_sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
        for _ in range(10):
            s_wr.write(b'a' * c_bufsize)
            s_wr.write(b'a' * s_bufsize)
            if s_wr.transport.get_write_buffer_size() > 0:
                break
        self.assertNotEqual(s_wr.transport.get_write_buffer_size(), 0)

        task = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task.done())

        srv.close()
        srv.abort_clients()
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        self.assertTrue(task.done())


# Test the various corner cases of Unix server socket removal
class UnixServerCleanupTests(unittest.IsolatedAsyncioTestCase):
    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_addr_cleanup(self):
        # Default scenario
        with test_utils.unix_socket_path() as addr:
            async def serve(*args):
                pass

            srv = await asyncio.start_unix_server(serve, addr)

            srv.close()
            self.assertFalse(os.path.exists(addr))

    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_sock_cleanup(self):
        # Using already bound socket
        with test_utils.unix_socket_path() as addr:
            async def serve(*args):
                pass

            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                sock.bind(addr)

                srv = await asyncio.start_unix_server(serve, sock=sock)

                srv.close()
                self.assertFalse(os.path.exists(addr))

    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_cleanup_gone(self):
        # Someone else has already cleaned up the socket
        with test_utils.unix_socket_path() as addr:
            async def serve(*args):
                pass

            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                sock.bind(addr)

                srv = await asyncio.start_unix_server(serve, sock=sock)

                os.unlink(addr)

                # No assertion: close() simply must not raise when the
                # socket path is already gone.
                srv.close()

    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_cleanup_replaced(self):
        # Someone else has replaced the socket with their own
        with test_utils.unix_socket_path() as addr:
            async def serve(*args):
                pass

            srv = await asyncio.start_unix_server(serve, addr)

            os.unlink(addr)
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                sock.bind(addr)

                # The path now belongs to the other socket and must survive
                # the server being closed.
                srv.close()
                self.assertTrue(os.path.exists(addr))

    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_cleanup_prevented(self):
        # Automatic cleanup explicitly disabled
        with test_utils.unix_socket_path() as addr:
            async def serve(*args):
                pass

            srv = await asyncio.start_unix_server(
                serve, addr, cleanup_socket=False)

            srv.close()
            self.assertTrue(os.path.exists(addr))


@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartServerTests(BaseStartServer, unittest.TestCase):

    def new_loop(self):
        return asyncio.ProactorEventLoop()


if __name__ == '__main__':
    unittest.main()