# cpython/Lib/test/test_asyncio/test_server.py
#
# (Listing metadata from the original viewer: 345 lines, 11 KiB, Python.)

import asyncio
import os
import socket
import time
import threading
import unittest
from test.support import socket_helper
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests
def tearDownModule():
    """Module-level unittest teardown hook.

    Calls ``asyncio.set_event_loop_policy(None)``; per the asyncio docs,
    passing ``None`` restores the default event loop policy, so whatever
    these tests installed does not leak into later test modules.
    """
    asyncio.set_event_loop_policy(None)
class BaseStartServer(func_tests.FunctionalTestCaseMixin):
    """Loop-agnostic ``asyncio.start_server()`` test; subclasses pick the loop.

    ``self.loop`` and ``self.tcp_client`` are not defined here; they are
    presumably supplied by ``FunctionalTestCaseMixin`` (confirm there).
    """

    def new_loop(self):
        # Concrete subclasses return the event loop implementation to test.
        raise NotImplementedError

    def test_start_server_1(self):
        # Create a TCP server with start_serving=False, run serve_forever()
        # in a task, let one client talk to it, cancel the task from inside
        # the connection handler, and verify the server ends up fully closed.
        payload = b'1' * 1024 * 5 + b'\n'

        def wait_until_serving():
            # Poll is_serving() up to ten times, sleeping 0.2 s before each
            # check; give up with RuntimeError if the server never starts.
            for _ in range(10):
                time.sleep(0.2)
                if srv.is_serving():
                    return
            raise RuntimeError

        def client(sock, addr):
            # Blocking client: send one line, read the one-byte reply.
            wait_until_serving()
            sock.settimeout(2)
            sock.connect(addr)
            sock.send(payload)
            sock.recv_all(1)
            sock.close()

        async def serve(reader, writer):
            await reader.readline()
            # Cancel serve_forever() while this connection is still open.
            main_task.cancel()
            writer.write(b'1')
            writer.close()
            await writer.wait_closed()

        async def main(server):
            async with server:
                await server.serve_forever()

        run = self.loop.run_until_complete

        srv = run(asyncio.start_server(
            serve, socket_helper.HOSTv4, 0, start_serving=False))
        self.assertFalse(srv.is_serving())

        main_task = self.loop.create_task(main(srv))
        addr = srv.sockets[0].getsockname()

        with (
            self.assertRaises(asyncio.CancelledError),
            self.tcp_client(lambda sock: client(sock, addr)),
        ):
            run(main_task)

        # After cancellation the server must be closed and stripped of its
        # listening sockets and waiters.
        self.assertEqual(srv.sockets, ())
        self.assertIsNone(srv._sockets)
        self.assertIsNone(srv._waiters)
        self.assertFalse(srv.is_serving())

        # A closed server cannot be served again.
        with self.assertRaisesRegex(RuntimeError, r'is closed'):
            run(srv.serve_forever())
class SelectorStartServerTests(BaseStartServer, unittest.TestCase):
    """Run the shared start-server test on a selector loop, plus a Unix-socket variant."""

    def new_loop(self):
        return asyncio.SelectorEventLoop()

    @socket_helper.skip_unless_bind_unix_socket
    def test_start_unix_server_1(self):
        # Unix-socket counterpart of test_start_server_1: the server is
        # created with start_serving=False and started explicitly inside
        # main(), which then signals the client through an Event.
        payload = b'1' * 1024 * 5 + b'\n'
        serving = threading.Event()

        def client(sock, addr):
            sock.settimeout(2)
            # Wait (at most 5 s) until main() reports the server is serving.
            serving.wait(5)
            sock.connect(addr)
            sock.send(payload)
            sock.recv_all(1)
            sock.close()

        async def serve(reader, writer):
            await reader.readline()
            # Cancel serve_forever() while this connection is still open.
            main_task.cancel()
            writer.write(b'1')
            writer.close()
            await writer.wait_closed()

        async def main(server):
            async with server:
                self.assertFalse(server.is_serving())
                await server.start_serving()
                self.assertTrue(server.is_serving())
                serving.set()
                await server.serve_forever()

        run = self.loop.run_until_complete

        with test_utils.unix_socket_path() as addr:
            srv = run(asyncio.start_unix_server(
                serve, addr, start_serving=False))
            main_task = self.loop.create_task(main(srv))

            with (
                self.assertRaises(asyncio.CancelledError),
                self.unix_client(lambda sock: client(sock, addr)),
            ):
                run(main_task)

            # After cancellation the server must be closed and stripped of
            # its listening sockets and waiters.
            self.assertEqual(srv.sockets, ())
            self.assertIsNone(srv._sockets)
            self.assertIsNone(srv._waiters)
            self.assertFalse(srv.is_serving())

            # A closed server cannot be served again.
            with self.assertRaisesRegex(RuntimeError, r'is closed'):
                run(srv.serve_forever())
class TestServer2(unittest.IsolatedAsyncioTestCase):
    """Tests for ``Server.wait_closed()``, ``close_clients()`` and ``abort_clients()``."""

    async def test_wait_closed_basic(self):
        # Walk wait_closed() through every combination of "server closed?"
        # and "connections still active?".  It may only complete once the
        # server is closed AND no connection remains.
        async def serve(rd, wr):
            # Keep the connection alive until the peer sends EOF.
            try:
                await rd.read()
            finally:
                wr.close()
                await wr.wait_closed()

        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
        self.addCleanup(srv.close)

        # active count = 0, not closed: should block
        task1 = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task1.done())

        # active count != 0, not closed: should block
        addr = srv.sockets[0].getsockname()
        (rd, wr) = await asyncio.open_connection(addr[0], addr[1])
        task2 = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task1.done())
        self.assertFalse(task2.done())

        srv.close()
        await asyncio.sleep(0)
        # active count != 0, closed: should block
        task3 = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task1.done())
        self.assertFalse(task2.done())
        self.assertFalse(task3.done())

        wr.close()
        await wr.wait_closed()
        # active count == 0, closed: should unblock
        await task1
        await task2
        await task3
        await srv.wait_closed()  # Return immediately

    async def test_wait_closed_race(self):
        # Test a regression in 3.12.0, should be fixed in 3.12.1
        async def serve(rd, wr):
            # Keep the connection alive until the peer sends EOF.
            try:
                await rd.read()
            finally:
                wr.close()
                await wr.wait_closed()

        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
        self.addCleanup(srv.close)

        task = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task.done())
        addr = srv.sockets[0].getsockname()
        (rd, wr) = await asyncio.open_connection(addr[0], addr[1])
        loop = asyncio.get_running_loop()
        # Schedule the server close and the client close back-to-back;
        # wait_closed() must still return.
        loop.call_soon(srv.close)
        loop.call_soon(wr.close)
        await srv.wait_closed()

    async def test_close_clients(self):
        # close() followed by close_clients() must let a pending
        # wait_closed() finish even though the client never closed its side.
        async def serve(rd, wr):
            # Keep the connection alive until the peer sends EOF.
            try:
                await rd.read()
            finally:
                wr.close()
                await wr.wait_closed()

        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
        self.addCleanup(srv.close)

        addr = srv.sockets[0].getsockname()
        (rd, wr) = await asyncio.open_connection(addr[0], addr[1])
        self.addCleanup(wr.close)

        task = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task.done())

        srv.close()
        srv.close_clients()
        # Give the loop two iterations before checking the waiter.
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        self.assertTrue(task.done())

    async def test_abort_clients(self):
        # close() followed by abort_clients() must let a pending
        # wait_closed() finish even while the server-side transport still
        # holds unsent data for a client that has stopped reading.
        async def serve(rd, wr):
            # Hand the server-side streams to the test body, then stay
            # alive until the connection is closed.
            fut.set_result((rd, wr))
            await wr.wait_closed()

        fut = asyncio.Future()
        srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
        self.addCleanup(srv.close)

        addr = srv.sockets[0].getsockname()
        (c_rd, c_wr) = await asyncio.open_connection(addr[0], addr[1], limit=4096)
        self.addCleanup(c_wr.close)

        (s_rd, s_wr) = await fut

        # Limit the socket buffers so we can more reliably overfill them
        s_sock = s_wr.get_extra_info('socket')
        s_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
        c_sock = c_wr.get_extra_info('socket')
        c_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)

        # Get the reader in to a paused state by sending more than twice
        # the configured limit
        s_wr.write(b'a' * 4096)
        s_wr.write(b'a' * 4096)
        s_wr.write(b'a' * 4096)
        while c_wr.transport.is_reading():
            await asyncio.sleep(0)

        # Get the writer in a waiting state by sending data until the
        # kernel stops accepting more and asyncio has to buffer the rest.
        # The sizes reported by getsockopt() are not a reliable measure of
        # how much payload the kernel will actually accept (gh-122136), so
        # a single round of writes may be absorbed completely.  Keep
        # writing until the transport's own write buffer is non-empty,
        # capped at ten rounds so a misbehaving platform fails the
        # assertion below instead of looping forever.
        c_bufsize = c_sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF)
        s_bufsize = s_sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
        for _ in range(10):
            s_wr.write(b'a' * c_bufsize)
            s_wr.write(b'a' * s_bufsize)
            if s_wr.transport.get_write_buffer_size() > 0:
                break
        self.assertNotEqual(s_wr.transport.get_write_buffer_size(), 0)

        task = asyncio.create_task(srv.wait_closed())
        await asyncio.sleep(0)
        self.assertFalse(task.done())

        srv.close()
        srv.abort_clients()
        # Give the loop two iterations before checking the waiter.
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        self.assertTrue(task.done())
# Test the various corner cases of Unix server socket removal
class UnixServerCleanupTests(unittest.IsolatedAsyncioTestCase):
    """Checks what ``Server.close()`` does to a Unix server's socket file."""

    @staticmethod
    async def _ignore_client(*args):
        # Connection callback that does nothing; these tests never connect.
        pass

    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_addr_cleanup(self):
        # Default scenario: server created from a path; the path must be
        # gone after close().
        with test_utils.unix_socket_path() as path:
            server = await asyncio.start_unix_server(self._ignore_client, path)

            server.close()
            self.assertFalse(os.path.exists(path))

    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_sock_cleanup(self):
        # Using already bound socket: the path must be gone after close().
        with (
            test_utils.unix_socket_path() as path,
            socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as bound,
        ):
            bound.bind(path)
            server = await asyncio.start_unix_server(
                self._ignore_client, sock=bound)

            server.close()
            self.assertFalse(os.path.exists(path))

    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_cleanup_gone(self):
        # Someone else has already cleaned up the socket: close() must
        # simply not raise.
        with (
            test_utils.unix_socket_path() as path,
            socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as bound,
        ):
            bound.bind(path)
            server = await asyncio.start_unix_server(
                self._ignore_client, sock=bound)

            os.unlink(path)

            server.close()

    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_cleanup_replaced(self):
        # Someone else has replaced the socket with their own: close() must
        # leave the replacement in place.
        with test_utils.unix_socket_path() as path:
            server = await asyncio.start_unix_server(self._ignore_client, path)

            os.unlink(path)
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as intruder:
                intruder.bind(path)

                server.close()
                self.assertTrue(os.path.exists(path))

    @socket_helper.skip_unless_bind_unix_socket
    async def test_unix_server_cleanup_prevented(self):
        # Automatic cleanup explicitly disabled: the path must survive
        # close().
        with test_utils.unix_socket_path() as path:
            server = await asyncio.start_unix_server(
                self._ignore_client, path, cleanup_socket=False)

            server.close()
            self.assertTrue(os.path.exists(path))
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartServerTests(BaseStartServer, unittest.TestCase):
    """Run the ``BaseStartServer`` tests on ``asyncio.ProactorEventLoop``.

    Skipped whenever ``asyncio`` does not expose ``ProactorEventLoop``.
    """

    def new_loop(self):
        return asyncio.ProactorEventLoop()
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()