"""
Test suite for socketserver.
"""

import contextlib
import io
import os
import select
import signal
import socket
import tempfile
import unittest
import socketserver

import test.support
from test.support import reap_children, reap_threads, verbose
try:
    import threading
except ImportError:
    threading = None

test.support.requires("network")

# Payload echoed back by the test handlers; a single newline-terminated line.
TEST_STR = b"hello world\n"
HOST = test.support.HOST

HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
HAVE_FORKING = hasattr(os, "fork")
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')


def signal_alarm(n):
    """Call signal.alarm when it exists (i.e. not on Windows)."""
    if hasattr(signal, 'alarm'):
        signal.alarm(n)


# Remember real select() to avoid interferences with mocking
_real_select = select.select


def receive(sock, n, timeout=20):
    """Return up to *n* bytes read from *sock*.

    Waits (using the saved, un-mocked select()) at most *timeout* seconds
    for the socket to become readable and raises RuntimeError otherwise.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock in readable:
        return sock.recv(n)
    raise RuntimeError("timed out on %r" % (sock,))


if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        pass

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        pass


@contextlib.contextmanager
def simple_subprocess(testcase):
    """Tests that a custom child process is not waited on (Issue 1540386)"""
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    yield None
    # The child must still be waitable here, i.e. nobody else reaped it
    # while the body of the with-statement was running.
    pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    testcase.assertEqual(72 << 8, status)


def close_server(server):
    """Close *server* and wait for any child processes it still tracks."""
    server.server_close()

    # ForkingMixIn: Manually reap all child processes, since server_close()
    # calls waitpid() in non-blocking mode using the WNOHANG flag.
    # Be defensive: the attribute may be absent, None or empty, in which
    # case there is nothing to reap and .copy() must not be called.
    children = getattr(server, 'active_children', None)
    if children:
        for pid in children.copy():
            try:
                os.waitpid(pid, 0)
            except ChildProcessError:
                # Already reaped elsewhere; nothing left to wait for.
                pass
        children.clear()


@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        # Remove the Unix socket files handed out by pickaddr().
        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return an address suitable for binding a socket of family *proto*.

        AF_INET gets (HOST, 0) so the OS picks the port; any other family
        gets a fresh temporary file name, which is recorded in
        self.test_files so that tearDown() removes it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create an echo server of class *svrcls* bound to *addr*.

        The handler (derived from *hdlrbase*) reads one line and writes it
        back.  Handler errors are not swallowed: the request is closed and
        the exception being handled is re-raised.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose:
            print("creating server")
        server = MyServer(addr, MyHandler)
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve with *svrcls* in a thread and run *testfunc* against it.

        *testfunc* is called three times as testfunc(address_family, addr).
        The server is shut down and closed afterwards, even when one of the
        client checks raises.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval': 0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        try:
            if verbose:
                print("server running")
            for i in range(3):
                if verbose:
                    print("test client", i)
                testfunc(svrcls.address_family, addr)
            if verbose:
                print("waiting for server")
        finally:
            # Stop the serving thread and release the sockets on every
            # path, so a failing client check does not leave them behind.
            server.shutdown()
            t.join()
            close_server(server)
        self.assertEqual(-1, server.socket.fileno())
        if verbose:
            print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and expect it echoed back."""
        # The with-statement closes the socket even if the check fails.
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            # Keep reading until a full line arrived or the peer closed.
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and expect it echoed back."""
        # The with-statement closes the socket even if the check fails.
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # Bind the client to its own address before sending.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        self.run_server(ForkingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval': 0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        # Start each thread and immediately ask its server to shut down.
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            close_server(s)

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        # Leaving the with-block must close the server socket.
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())


class ErrorHandlerTest(unittest.TestCase):
    """Test that the servers pass normal exceptions from the handler to
    handle_error(), and that exiting exceptions like SystemExit and
    KeyboardInterrupt are not passed."""

    def tearDown(self):
        test.support.unlink(test.support.TESTFN)
        reap_children()

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_threading_not_handled(self):
        ThreadingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)

    def check_result(self, handled):
        """Compare the TESTFN log with what the handler should have written.

        'Handler called' is always expected; 'Error handled' is expected
        only when *handled* is true.
        """
        with open(test.support.TESTFN) as log:
            # Multiplying by the boolean keeps or drops the second line.
            expected = 'Handler called\n' + 'Error handled\n' * handled
            self.assertEqual(log.read(), expected)


class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server whose constructor runs one request through BadHandler.

    The handler raises an instance of *exception*; handle_error() records
    in the TESTFN log file that it was invoked.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        # Connect and disconnect right away so one request is pending.
        with socket.create_connection(self.server_address):
            pass
        try:
            self.handle_request()
        finally:
            close_server(self)
        self.wait_done()

    def handle_error(self, request, client_address):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses that finish the request asynchronously."""
        pass
class BadHandler(socketserver.BaseRequestHandler):
    """Handler that logs its invocation and then raises the server's
    configured exception type."""

    def handle(self):
        with open(test.support.TESTFN, 'a') as log:
            log.write('Handler called\n')
        raise self.server.exception('Test error')


class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded variant: wait_done() blocks until the request was shut down."""

    def __init__(self, *pos, **kw):
        # The event must exist before the base constructor runs, because
        # that constructor already handles a request and calls wait_done().
        self.done = threading.Event()
        super().__init__(*pos, **kw)

    def shutdown_request(self, *pos, **kw):
        super().shutdown_request(*pos, **kw)
        self.done.set()

    def wait_done(self):
        self.done.wait()


if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn,
                                 BaseErrorTestServer):
        pass


class SocketWriterTest(unittest.TestCase):
    def test_basics(self):
        # The handler's wfile must be a buffered I/O object that shares its
        # file descriptor with the request socket.
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(close_server, server)
        s = socket.socket(
            server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
        with s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    @unittest.skipUnless(threading, 'Threading required for this test.')
    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(close_server, server)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        # Install the handler and make sure the previous one is restored.
        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            # Results are handed back to the test through the enclosing
            # scope.
            nonlocal response1, received2
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                              socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                response1 = reader.readline()
                s.sendall(b'client response\n')

                # Consume only the first 100 bytes of the big chunk.
                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=1.0):
                        break
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        # The first 100 bytes were read separately before the final read().
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)


class MiscTestCase(unittest.TestCase):

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if not name.startswith('_'):
                mod_object = getattr(socketserver, name)
                if getattr(mod_object, '__module__', None) == 'socketserver':
                    expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0

            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Register the cleanup right away so the server is closed even when
        # the assertion below fails.
        self.addCleanup(close_server, server)
        # Connect and disconnect; the with-statement closes the client
        # socket on every path.
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)


if __name__ == "__main__":
    unittest.main()