""" Test suite for socketserver. """ import contextlib import imp import os import select import signal import socket import tempfile import unittest import socketserver import test.support from test.support import reap_children, reap_threads, verbose try: import threading except ImportError: threading = None test.support.requires("network") TEST_STR = b"hello world\n" HOST = test.support.HOST HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") HAVE_FORKING = hasattr(os, "fork") and os.name != "os2" def signal_alarm(n): """Call signal.alarm when it exists (i.e. not on Windows).""" if hasattr(signal, 'alarm'): signal.alarm(n) def receive(sock, n, timeout=20): r, w, x = select.select([sock], [], [], timeout) if sock in r: return sock.recv(n) else: raise RuntimeError("timed out on %r" % (sock,)) if HAVE_UNIX_SOCKETS: class ForkingUnixStreamServer(socketserver.ForkingMixIn, socketserver.UnixStreamServer): pass class ForkingUnixDatagramServer(socketserver.ForkingMixIn, socketserver.UnixDatagramServer): pass @contextlib.contextmanager def simple_subprocess(testcase): pid = os.fork() if pid == 0: # Don't throw an exception; it would be caught by the test harness. os._exit(72) yield None pid2, status = os.waitpid(pid, 0) testcase.assertEquals(pid2, pid) testcase.assertEquals(72 << 8, status) class SocketServerTest(unittest.TestCase): """Test all socket servers.""" def setUp(self): signal_alarm(20) # Kill deadlocks after 20 seconds. self.port_seed = 0 self.test_files = [] def tearDown(self): signal_alarm(0) # Didn't deadlock. reap_children() for fn in self.test_files: try: os.remove(fn) except os.error: pass self.test_files[:] = [] def pickaddr(self, proto): if proto == socket.AF_INET: return (HOST, 0) else: # XXX: We need a way to tell AF_UNIX to pick its own name # like AF_INET provides port==0. dir = None if os.name == 'os2': dir = '\socket' fn = tempfile.mktemp(prefix='unix_socket.', dir=dir) if os.name == 'os2': # AF_UNIX socket names on OS/2 require a specific prefix # which can't include a drive letter and must also use # backslashes as directory separators if fn[1] == ':': fn = fn[2:] if fn[0] in (os.sep, os.altsep): fn = fn[1:] if os.sep == '/': fn = fn.replace(os.sep, os.altsep) else: fn = fn.replace(os.altsep, os.sep) self.test_files.append(fn) return fn def make_server(self, addr, svrcls, hdlrbase): class MyServer(svrcls): def handle_error(self, request, client_address): self.close_request(request) self.server_close() raise class MyHandler(hdlrbase): def handle(self): line = self.rfile.readline() self.wfile.write(line) if verbose: print("creating server") server = MyServer(addr, MyHandler) self.assertEquals(server.server_address, server.socket.getsockname()) return server @unittest.skipUnless(threading, 'Threading required for this test.') @reap_threads def run_server(self, svrcls, hdlrbase, testfunc): server = self.make_server(self.pickaddr(svrcls.address_family), svrcls, hdlrbase) # We had the OS pick a port, so pull the real address out of # the server. addr = server.server_address if verbose: print("ADDR =", addr) print("CLASS =", svrcls) t = threading.Thread( name='%s serving' % svrcls, target=server.serve_forever, # Short poll interval to make the test finish quickly. # Time between requests is short enough that we won't wake # up spuriously too many times. kwargs={'poll_interval':0.01}) t.daemon = True # In case this function raises. t.start() if verbose: print("server running") for i in range(3): if verbose: print("test client", i) testfunc(svrcls.address_family, addr) if verbose: print("waiting for server") server.shutdown() t.join() if verbose: print("done") def stream_examine(self, proto, addr): s = socket.socket(proto, socket.SOCK_STREAM) s.connect(addr) s.sendall(TEST_STR) buf = data = receive(s, 100) while data and b'\n' not in buf: data = receive(s, 100) buf += data self.assertEquals(buf, TEST_STR) s.close() def dgram_examine(self, proto, addr): s = socket.socket(proto, socket.SOCK_DGRAM) s.sendto(TEST_STR, addr) buf = data = receive(s, 100) while data and b'\n' not in buf: data = receive(s, 100) buf += data self.assertEquals(buf, TEST_STR) s.close() def test_TCPServer(self): self.run_server(socketserver.TCPServer, socketserver.StreamRequestHandler, self.stream_examine) def test_ThreadingTCPServer(self): self.run_server(socketserver.ThreadingTCPServer, socketserver.StreamRequestHandler, self.stream_examine) if HAVE_FORKING: def test_ForkingTCPServer(self): with simple_subprocess(self): self.run_server(socketserver.ForkingTCPServer, socketserver.StreamRequestHandler, self.stream_examine) if HAVE_UNIX_SOCKETS: def test_UnixStreamServer(self): self.run_server(socketserver.UnixStreamServer, socketserver.StreamRequestHandler, self.stream_examine) def test_ThreadingUnixStreamServer(self): self.run_server(socketserver.ThreadingUnixStreamServer, socketserver.StreamRequestHandler, self.stream_examine) if HAVE_FORKING: def test_ForkingUnixStreamServer(self): with simple_subprocess(self): self.run_server(ForkingUnixStreamServer, socketserver.StreamRequestHandler, self.stream_examine) def test_UDPServer(self): self.run_server(socketserver.UDPServer, socketserver.DatagramRequestHandler, self.dgram_examine) def test_ThreadingUDPServer(self): self.run_server(socketserver.ThreadingUDPServer, socketserver.DatagramRequestHandler, self.dgram_examine) if HAVE_FORKING: def test_ForkingUDPServer(self): with simple_subprocess(self): self.run_server(socketserver.ForkingUDPServer, socketserver.DatagramRequestHandler, self.dgram_examine) # Alas, on Linux (at least) recvfrom() doesn't return a meaningful # client address so this cannot work: # if HAVE_UNIX_SOCKETS: # def test_UnixDatagramServer(self): # self.run_server(socketserver.UnixDatagramServer, # socketserver.DatagramRequestHandler, # self.dgram_examine) # # def test_ThreadingUnixDatagramServer(self): # self.run_server(socketserver.ThreadingUnixDatagramServer, # socketserver.DatagramRequestHandler, # self.dgram_examine) # # if HAVE_FORKING: # def test_ForkingUnixDatagramServer(self): # self.run_server(socketserver.ForkingUnixDatagramServer, # socketserver.DatagramRequestHandler, # self.dgram_examine) @reap_threads def test_shutdown(self): # Issue #2302: shutdown() should always succeed in making an # other thread leave serve_forever(). class MyServer(socketserver.TCPServer): pass class MyHandler(socketserver.StreamRequestHandler): pass threads = [] for i in range(20): s = MyServer((HOST, 0), MyHandler) t = threading.Thread( name='MyServer serving', target=s.serve_forever, kwargs={'poll_interval':0.01}) t.daemon = True # In case this function raises. threads.append((t, s)) for t, s in threads: t.start() s.shutdown() for t, s in threads: t.join() def test_main(): if imp.lock_held(): # If the import lock is held, the threads will hang raise unittest.SkipTest("can't run when import lock is held") test.support.run_unittest(SocketServerTest) if __name__ == "__main__": test_main() signal_alarm(3) # Shutdown shouldn't take more than 3 seconds.