import asyncio
import asyncio.events
import contextlib
import os
import pprint
import select
import socket
import tempfile
import threading


def _validate_socket_timeout(timeout):
    """Raise RuntimeError unless *timeout* is a positive number.

    Shared by the client and server factories so that the check happens
    before any socket is created or bound.
    """
    if timeout is None:
        raise RuntimeError('timeout is required')
    if timeout <= 0:
        raise RuntimeError('only blocking sockets are supported')


class FunctionalTestCaseMixin:
    """Mixin providing an event loop and threaded socket peers for tests.

    The mixin relies on the host class for ``self.fail`` (as provided by
    ``unittest.TestCase``); it is called from ``tearDown`` and from
    ``_abort_socket_test``.
    """

    def new_loop(self):
        """Return the event loop to be used by the test."""
        return asyncio.new_event_loop()

    def run_loop_briefly(self, *, delay=0.01):
        """Run the loop for *delay* seconds."""
        self.loop.run_until_complete(asyncio.sleep(delay))

    def loop_exception_handler(self, loop, context):
        """Record *context* and forward it to the default handler."""
        self.__unhandled_exceptions.append(context)
        self.loop.default_exception_handler(context)

    def setUp(self):
        """Create the loop, install the recording exception handler and
        disable ``asyncio.events._get_running_loop``."""
        self.loop = self.new_loop()
        asyncio.set_event_loop(None)

        # Create the list before installing the handler that appends to it.
        self.__unhandled_exceptions = []
        self.loop.set_exception_handler(self.loop_exception_handler)

        # Disable `_get_running_loop`.
        self._old_get_running_loop = asyncio.events._get_running_loop
        asyncio.events._get_running_loop = lambda: None

    def tearDown(self):
        """Close the loop and fail if the exception handler was called."""
        try:
            self.loop.close()

            if self.__unhandled_exceptions:
                print('Unexpected calls to loop.call_exception_handler():')
                pprint.pprint(self.__unhandled_exceptions)
                self.fail('unexpected calls to loop.call_exception_handler()')

        finally:
            # Always undo the patch applied in setUp.
            asyncio.events._get_running_loop = self._old_get_running_loop
            asyncio.set_event_loop(None)
            self.loop = None

    def tcp_server(self, server_prog, *,
                   family=socket.AF_INET,
                   addr=None,
                   timeout=5,
                   backlog=1,
                   max_clients=10):
        """Return an unstarted TestThreadedServer listening on *addr*.

        *server_prog* is called with a TestSocketWrapper for every accepted
        connection.  When *addr* is None, a temporary path is chosen for
        AF_UNIX and ``('127.0.0.1', 0)`` otherwise.

        Raises RuntimeError if *timeout* is None or not positive.  The
        check runs before the socket is created, so an invalid timeout
        leaves neither an open socket nor a bound AF_UNIX path behind.
        """
        _validate_socket_timeout(timeout)

        if addr is None:
            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
                # Only the unique name is wanted: the file itself is removed
                # when the `with` block exits, freeing the path for bind().
                with tempfile.NamedTemporaryFile() as tmp:
                    addr = tmp.name
            else:
                addr = ('127.0.0.1', 0)

        sock = socket.create_server(addr, family=family, backlog=backlog)
        try:
            sock.settimeout(timeout)
            return TestThreadedServer(
                self, sock, server_prog, timeout, max_clients)
        except BaseException:
            # Do not leak the listening socket if setup fails.
            sock.close()
            raise

    def tcp_client(self, client_prog,
                   family=socket.AF_INET,
                   timeout=10):
        """Return an unstarted TestThreadedClient with a fresh socket.

        *client_prog* is called with a TestSocketWrapper around the
        (unconnected) socket once the thread runs.

        Raises RuntimeError if *timeout* is None or not positive; the check
        runs before the socket is created.
        """
        _validate_socket_timeout(timeout)

        sock = socket.socket(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout)
            return TestThreadedClient(
                self, sock, client_prog, timeout)
        except BaseException:
            # Do not leak the socket if setup fails.
            sock.close()
            raise

    def unix_server(self, *args, **kwargs):
        """Like tcp_server(), with ``family=socket.AF_UNIX``.

        Raises NotImplementedError when the platform lacks AF_UNIX.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)

    def unix_client(self, *args, **kwargs):
        """Like tcp_client(), with ``family=socket.AF_UNIX``.

        Raises NotImplementedError when the platform lacks AF_UNIX.
        """
        if not hasattr(socket, 'AF_UNIX'):
            raise NotImplementedError
        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)

    @contextlib.contextmanager
    def unix_sock_name(self):
        """Yield a socket path inside a fresh temporary directory.

        On exit the path is unlinked (errors ignored) and the directory
        is removed.
        """
        with tempfile.TemporaryDirectory() as td:
            fn = os.path.join(td, 'sock')
            try:
                yield fn
            finally:
                # Best effort: the socket file may never have been created.
                with contextlib.suppress(OSError):
                    os.unlink(fn)

    def _abort_socket_test(self, ex):
        """Stop the loop, then fail the test with *ex*.

        NOTE(review): the socket threads in this module call this from
        their own thread and ``loop.stop()`` is invoked directly -- confirm
        that this is intended.
        """
        try:
            self.loop.stop()
        finally:
            self.fail(ex)


##############################################################################
# Socket Testing Utilities
##############################################################################


class TestSocketWrapper:
    """Thin proxy around a socket adding ``recv_all`` and ``start_tls``.

    Any other attribute access is forwarded to the wrapped socket.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes and return them as ``bytes``.

        Raises ConnectionAbortedError if the peer closes the connection
        (``recv`` returns ``b''``) before *n* bytes have arrived.
        """
        # Collect chunks and join once instead of repeated bytes
        # concatenation, which is quadratic in the number of chunks.
        chunks = []
        remaining = n
        while remaining > 0:
            data = self.recv(remaining)
            if not data:
                raise ConnectionAbortedError
            chunks.append(data)
            remaining -= len(data)
        return b''.join(chunks)

    def start_tls(self, ssl_context, *,
                  server_side=False,
                  server_hostname=None):
        """Wrap the socket with *ssl_context* and perform the handshake.

        If the handshake raises, the SSL socket is closed and the
        exception propagates.  In every case the previously wrapped
        socket object is closed and replaced by the SSL socket.
        """
        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=False)

        try:
            ssl_sock.do_handshake()
        except BaseException:
            # Cleanup only; the exception is always re-raised.
            ssl_sock.close()
            raise
        finally:
            self.__sock.close()
            self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only reached for attributes not found on the wrapper itself.
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)


class SocketThread(threading.Thread):
    """Thread base class usable as a context manager.

    Entering starts the thread; exiting calls ``stop()``.  Subclasses are
    expected to define ``_active``.
    """

    def stop(self):
        """Clear the ``_active`` flag and wait for the thread to finish."""
        self._active = False
        self.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()


class TestThreadedClient(SocketThread):
    """Daemon thread that runs *prog* against a client socket."""

    def __init__(self, test, sock, prog, timeout):
        super().__init__(name='test-client')
        self.daemon = True

        self._timeout = timeout
        self._sock = sock
        self._active = True
        self._prog = prog
        self._test = test

    def run(self):
        """Call the program with the wrapped socket.

        Any ``Exception`` is reported through the test's
        ``_abort_socket_test``.
        """
        try:
            self._prog(TestSocketWrapper(self._sock))
        except Exception as ex:
            self._test._abort_socket_test(ex)


class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections and runs *prog* on each.

    A socket pair (``_s1``/``_s2``) is used to wake the accept loop when
    ``stop()`` is called.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        super().__init__(name='test-server')
        self.daemon = True

        self._clients = 0
        self._finished_clients = 0
        self._max_clients = max_clients
        self._timeout = timeout
        self._sock = sock
        self._active = True

        self._prog = prog

        # _s1 is watched by the accept loop; stop() writes to _s2.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def stop(self):
        """Wake the accept loop through the socket pair, then join."""
        try:
            if self._s2 and self._s2.fileno() != -1:
                try:
                    self._s2.send(b'stop')
                except OSError:
                    # Best effort: the wake-up is skipped if sending fails.
                    pass
        finally:
            super().stop()

    def run(self):
        """Serve clients, closing the listening socket and the socket
        pair when the loop ends."""
        try:
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        """Accept and handle clients one at a time.

        Returns when the thread is deactivated, when ``max_clients``
        connections have been accepted, or when the wake-up socket becomes
        readable.  If handling a client raises an ``Exception``, the server
        is deactivated and the test's ``_abort_socket_test`` is called
        while the exception propagates.
        """
        while self._active:
            if self._clients >= self._max_clients:
                return

            readable, _, _ = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in readable:
                # stop() signalled us.
                return

            if self._sock in readable:
                try:
                    conn, _ = self._sock.accept()
                except BlockingIOError:
                    continue
                except socket.timeout:
                    if not self._active:
                        return
                    else:
                        raise
                else:
                    self._clients += 1
                    conn.settimeout(self._timeout)
                    try:
                        with conn:
                            self._handle_client(conn)
                    except Exception as ex:
                        self._active = False
                        try:
                            raise
                        finally:
                            self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        """Run the server program on one accepted connection."""
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Address the listening socket is bound to (``getsockname()``)."""
        return self._sock.getsockname()