""" Test suite for socketserver. """ import _imp as imp import contextlib import os import select import signal import socket import select import errno import tempfile import unittest import socketserver import test.support from test.support import reap_children, reap_threads, verbose try: import threading except ImportError: threading = None test.support.requires("network") TEST_STR = b"hello world\n" HOST = test.support.HOST HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS, 'requires Unix sockets') HAVE_FORKING = hasattr(os, "fork") requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking') def signal_alarm(n): """Call signal.alarm when it exists (i.e. not on Windows).""" if hasattr(signal, 'alarm'): signal.alarm(n) # Remember real select() to avoid interferences with mocking _real_select = select.select def receive(sock, n, timeout=20): r, w, x = _real_select([sock], [], [], timeout) if sock in r: return sock.recv(n) else: raise RuntimeError("timed out on %r" % (sock,)) if HAVE_UNIX_SOCKETS: class ForkingUnixStreamServer(socketserver.ForkingMixIn, socketserver.UnixStreamServer): pass class ForkingUnixDatagramServer(socketserver.ForkingMixIn, socketserver.UnixDatagramServer): pass @contextlib.contextmanager def simple_subprocess(testcase): pid = os.fork() if pid == 0: # Don't raise an exception; it would be caught by the test harness. os._exit(72) yield None pid2, status = os.waitpid(pid, 0) testcase.assertEqual(pid2, pid) testcase.assertEqual(72 << 8, status) @unittest.skipUnless(threading, 'Threading required for this test.') class SocketServerTest(unittest.TestCase): """Test all socket servers.""" def setUp(self): signal_alarm(60) # Kill deadlocks after 60 seconds. self.port_seed = 0 self.test_files = [] def tearDown(self): signal_alarm(0) # Didn't deadlock. reap_children() for fn in self.test_files: try: os.remove(fn) except OSError: pass self.test_files[:] = [] def pickaddr(self, proto): if proto == socket.AF_INET: return (HOST, 0) else: # XXX: We need a way to tell AF_UNIX to pick its own name # like AF_INET provides port==0. dir = None fn = tempfile.mktemp(prefix='unix_socket.', dir=dir) self.test_files.append(fn) return fn def make_server(self, addr, svrcls, hdlrbase): class MyServer(svrcls): def handle_error(self, request, client_address): self.close_request(request) self.server_close() raise class MyHandler(hdlrbase): def handle(self): line = self.rfile.readline() self.wfile.write(line) if verbose: print("creating server") server = MyServer(addr, MyHandler) self.assertEqual(server.server_address, server.socket.getsockname()) return server @reap_threads def run_server(self, svrcls, hdlrbase, testfunc): server = self.make_server(self.pickaddr(svrcls.address_family), svrcls, hdlrbase) # We had the OS pick a port, so pull the real address out of # the server. addr = server.server_address if verbose: print("ADDR =", addr) print("CLASS =", svrcls) t = threading.Thread( name='%s serving' % svrcls, target=server.serve_forever, # Short poll interval to make the test finish quickly. # Time between requests is short enough that we won't wake # up spuriously too many times. kwargs={'poll_interval':0.01}) t.daemon = True # In case this function raises. t.start() if verbose: print("server running") for i in range(3): if verbose: print("test client", i) testfunc(svrcls.address_family, addr) if verbose: print("waiting for server") server.shutdown() t.join() server.server_close() if verbose: print("done") def stream_examine(self, proto, addr): s = socket.socket(proto, socket.SOCK_STREAM) s.connect(addr) s.sendall(TEST_STR) buf = data = receive(s, 100) while data and b'\n' not in buf: data = receive(s, 100) buf += data self.assertEqual(buf, TEST_STR) s.close() def dgram_examine(self, proto, addr): s = socket.socket(proto, socket.SOCK_DGRAM) s.sendto(TEST_STR, addr) buf = data = receive(s, 100) while data and b'\n' not in buf: data = receive(s, 100) buf += data self.assertEqual(buf, TEST_STR) s.close() def test_TCPServer(self): self.run_server(socketserver.TCPServer, socketserver.StreamRequestHandler, self.stream_examine) def test_ThreadingTCPServer(self): self.run_server(socketserver.ThreadingTCPServer, socketserver.StreamRequestHandler, self.stream_examine) @requires_forking def test_ForkingTCPServer(self): with simple_subprocess(self): self.run_server(socketserver.ForkingTCPServer, socketserver.StreamRequestHandler, self.stream_examine) @requires_unix_sockets def test_UnixStreamServer(self): self.run_server(socketserver.UnixStreamServer, socketserver.StreamRequestHandler, self.stream_examine) @requires_unix_sockets def test_ThreadingUnixStreamServer(self): self.run_server(socketserver.ThreadingUnixStreamServer, socketserver.StreamRequestHandler, self.stream_examine) @requires_unix_sockets @requires_forking def test_ForkingUnixStreamServer(self): with simple_subprocess(self): self.run_server(ForkingUnixStreamServer, socketserver.StreamRequestHandler, self.stream_examine) def test_UDPServer(self): self.run_server(socketserver.UDPServer, socketserver.DatagramRequestHandler, self.dgram_examine) def test_ThreadingUDPServer(self): self.run_server(socketserver.ThreadingUDPServer, socketserver.DatagramRequestHandler, self.dgram_examine) @requires_forking def test_ForkingUDPServer(self): with simple_subprocess(self): self.run_server(socketserver.ForkingUDPServer, socketserver.DatagramRequestHandler, self.dgram_examine) @contextlib.contextmanager def mocked_select_module(self): """Mocks the select.select() call to raise EINTR for first call""" old_select = select.select class MockSelect: def __init__(self): self.called = 0 def __call__(self, *args): self.called += 1 if self.called == 1: # raise the exception on first call raise OSError(errno.EINTR, os.strerror(errno.EINTR)) else: # Return real select value for consecutive calls return old_select(*args) select.select = MockSelect() try: yield select.select finally: select.select = old_select def test_InterruptServerSelectCall(self): with self.mocked_select_module() as mock_select: pid = self.run_server(socketserver.TCPServer, socketserver.StreamRequestHandler, self.stream_examine) # Make sure select was called again: self.assertGreater(mock_select.called, 1) # Alas, on Linux (at least) recvfrom() doesn't return a meaningful # client address so this cannot work: # @requires_unix_sockets # def test_UnixDatagramServer(self): # self.run_server(socketserver.UnixDatagramServer, # socketserver.DatagramRequestHandler, # self.dgram_examine) # # @requires_unix_sockets # def test_ThreadingUnixDatagramServer(self): # self.run_server(socketserver.ThreadingUnixDatagramServer, # socketserver.DatagramRequestHandler, # self.dgram_examine) # # @requires_unix_sockets # @requires_forking # def test_ForkingUnixDatagramServer(self): # self.run_server(socketserver.ForkingUnixDatagramServer, # socketserver.DatagramRequestHandler, # self.dgram_examine) @reap_threads def test_shutdown(self): # Issue #2302: shutdown() should always succeed in making an # other thread leave serve_forever(). class MyServer(socketserver.TCPServer): pass class MyHandler(socketserver.StreamRequestHandler): pass threads = [] for i in range(20): s = MyServer((HOST, 0), MyHandler) t = threading.Thread( name='MyServer serving', target=s.serve_forever, kwargs={'poll_interval':0.01}) t.daemon = True # In case this function raises. threads.append((t, s)) for t, s in threads: t.start() s.shutdown() for t, s in threads: t.join() s.server_close() def test_main(): if imp.lock_held(): # If the import lock is held, the threads will hang raise unittest.SkipTest("can't run when import lock is held") test.support.run_unittest(SocketServerTest) if __name__ == "__main__": test_main()