#!/usr/bin/env python import unittest import test_support import socket import select import time import thread, threading import Queue PORT = 50007 HOST = 'localhost' MSG = 'Michael Gilfix was here\n' class SocketTCPTest(unittest.TestCase): def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.serv.bind((HOST, PORT)) self.serv.listen(1) def tearDown(self): self.serv.close() self.serv = None class SocketUDPTest(unittest.TestCase): def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.serv.bind((HOST, PORT)) def tearDown(self): self.serv.close() self.serv = None class ThreadableTest: def __init__(self): # Swap the true setup function self.__setUp = self.setUp self.__tearDown = self.tearDown self.setUp = self._setUp self.tearDown = self._tearDown def _setUp(self): self.ready = threading.Event() self.done = threading.Event() self.queue = Queue.Queue(1) # Do some munging to start the client test. test_method = getattr(self, ''.join(('_', self._TestCase__testMethodName))) self.client_thread = thread.start_new_thread(self.clientRun, (test_method, )) self.__setUp() self.ready.wait() def _tearDown(self): self.__tearDown() self.done.wait() if not self.queue.empty(): msg = self.queue.get() self.fail(msg) def clientRun(self, test_func): self.ready.set() self.clientSetUp() if not callable(test_func): raise TypeError, "test_func must be a callable function" try: test_func() except Exception, strerror: self.queue.put(strerror) self.clientTearDown() def clientSetUp(self): raise NotImplementedError, "clientSetUp must be implemented." def clientTearDown(self): self.done.set() thread.exit() class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest): def __init__(self, methodName='runTest'): SocketTCPTest.__init__(self, methodName=methodName) ThreadableTest.__init__(self) def clientSetUp(self): self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) def clientTearDown(self): self.cli.close() self.cli = None ThreadableTest.clientTearDown(self) class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest): def __init__(self, methodName='runTest'): SocketUDPTest.__init__(self, methodName=methodName) ThreadableTest.__init__(self) def clientSetUp(self): self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) class SocketConnectedTest(ThreadedTCPSocketTest): def __init__(self, methodName='runTest'): ThreadedTCPSocketTest.__init__(self, methodName=methodName) def setUp(self): ThreadedTCPSocketTest.setUp(self) conn, addr = self.serv.accept() self.cli_conn = conn def tearDown(self): self.cli_conn.close() self.cli_conn = None ThreadedTCPSocketTest.tearDown(self) def clientSetUp(self): ThreadedTCPSocketTest.clientSetUp(self) self.cli.connect((HOST, PORT)) self.serv_conn = self.cli def clientTearDown(self): self.serv_conn.close() self.serv_conn = None ThreadedTCPSocketTest.clientTearDown(self) ####################################################################### ## Begin Tests class GeneralModuleTests(unittest.TestCase): def testSocketError(self): """Testing that socket module exceptions.""" def raise_error(*args, **kwargs): raise socket.error def raise_herror(*args, **kwargs): raise socket.herror def raise_gaierror(*args, **kwargs): raise socket.gaierror self.failUnlessRaises(socket.error, raise_error, "Error raising socket exception.") self.failUnlessRaises(socket.error, raise_herror, "Error raising socket exception.") self.failUnlessRaises(socket.error, raise_gaierror, "Error raising socket exception.") def testCrucialConstants(self): """Testing for mission critical constants.""" socket.AF_INET socket.SOCK_STREAM socket.SOCK_DGRAM socket.SOCK_RAW socket.SOCK_RDM socket.SOCK_SEQPACKET socket.SOL_SOCKET socket.SO_REUSEADDR def testNonCrucialConstants(self): """Testing for existance of non-crucial constants.""" for const in ( "AF_UNIX", "SO_DEBUG", "SO_ACCEPTCONN", "SO_REUSEADDR", "SO_KEEPALIVE", "SO_DONTROUTE", "SO_BROADCAST", "SO_USELOOPBACK", "SO_LINGER", "SO_OOBINLINE", "SO_REUSEPORT", "SO_SNDBUF", "SO_RCVBUF", "SO_SNDLOWAT", "SO_RCVLOWAT", "SO_SNDTIMEO", "SO_RCVTIMEO", "SO_ERROR", "SO_TYPE", "SOMAXCONN", "MSG_OOB", "MSG_PEEK", "MSG_DONTROUTE", "MSG_EOR", "MSG_TRUNC", "MSG_CTRUNC", "MSG_WAITALL", "MSG_BTAG", "MSG_ETAG", "SOL_SOCKET", "IPPROTO_IP", "IPPROTO_ICMP", "IPPROTO_IGMP", "IPPROTO_GGP", "IPPROTO_TCP", "IPPROTO_EGP", "IPPROTO_PUP", "IPPROTO_UDP", "IPPROTO_IDP", "IPPROTO_HELLO", "IPPROTO_ND", "IPPROTO_TP", "IPPROTO_XTP", "IPPROTO_EON", "IPPROTO_BIP", "IPPROTO_RAW", "IPPROTO_MAX", "IPPORT_RESERVED", "IPPORT_USERRESERVED", "INADDR_ANY", "INADDR_BROADCAST", "INADDR_LOOPBACK", "INADDR_UNSPEC_GROUP", "INADDR_ALLHOSTS_GROUP", "INADDR_MAX_LOCAL_GROUP", "INADDR_NONE", "IP_OPTIONS", "IP_HDRINCL", "IP_TOS", "IP_TTL", "IP_RECVOPTS", "IP_RECVRETOPTS", "IP_RECVDSTADDR", "IP_RETOPTS", "IP_MULTICAST_IF", "IP_MULTICAST_TTL", "IP_MULTICAST_LOOP", "IP_ADD_MEMBERSHIP", "IP_DROP_MEMBERSHIP", ): try: getattr(socket, const) except AttributeError: pass def testHostnameRes(self): """Testing hostname resolution mechanisms.""" hostname = socket.gethostname() ip = socket.gethostbyname(hostname) self.assert_(ip.find('.') >= 0, "Error resolving host to ip.") hname, aliases, ipaddrs = socket.gethostbyaddr(ip) all_host_names = [hname] + aliases fqhn = socket.getfqdn() if not fqhn in all_host_names: self.fail("Error testing host resolution mechanisms.") def testRefCountGetNameInfo(self): """Testing reference count for getnameinfo.""" import sys if hasattr(sys, "getrefcount"): try: # On some versions, this loses a reference orig = sys.getrefcount(__name__) socket.getnameinfo(__name__,0) except SystemError: if sys.getrefcount(__name__) <> orig: self.fail("socket.getnameinfo loses a reference") def testInterpreterCrash(self): """Making sure getnameinfo doesn't crash the interpreter.""" try: # On some versions, this crashes the interpreter. socket.getnameinfo(('x', 0, 0, 0), 0) except socket.error: pass def testGetServByName(self): """Testing getservbyname().""" if hasattr(socket, 'getservbyname'): socket.getservbyname('telnet', 'tcp') try: socket.getservbyname('telnet', 'udp') except socket.error: pass def testSockName(self): """Testing getsockname().""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.bind(("0.0.0.0", PORT+1)) name = sock.getsockname() self.assertEqual(name, ("0.0.0.0", PORT+1)) def testGetSockOpt(self): """Testing getsockopt().""" # We know a socket should start without reuse==0 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) self.failIf(reuse != 0, "initial mode is reuse") def testSetSockOpt(self): """Testing setsockopt().""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) self.failIf(reuse == 0, "failed to set reuse mode") class BasicTCPTest(SocketConnectedTest): def __init__(self, methodName='runTest'): SocketConnectedTest.__init__(self, methodName=methodName) def testRecv(self): """Testing large receive over TCP.""" msg = self.cli_conn.recv(1024) self.assertEqual(msg, MSG) def _testRecv(self): self.serv_conn.send(MSG) def testOverFlowRecv(self): """Testing receive in chunks over TCP.""" seg1 = self.cli_conn.recv(len(MSG) - 3) seg2 = self.cli_conn.recv(1024) msg = ''.join ((seg1, seg2)) self.assertEqual(msg, MSG) def _testOverFlowRecv(self): self.serv_conn.send(MSG) def testRecvFrom(self): """Testing large recvfrom() over TCP.""" msg, addr = self.cli_conn.recvfrom(1024) hostname, port = addr ##self.assertEqual(hostname, socket.gethostbyname('localhost')) self.assertEqual(msg, MSG) def _testRecvFrom(self): self.serv_conn.send(MSG) def testOverFlowRecvFrom(self): """Testing recvfrom() in chunks over TCP.""" seg1, addr = self.cli_conn.recvfrom(len(MSG)-3) seg2, addr = self.cli_conn.recvfrom(1024) msg = ''.join((seg1, seg2)) hostname, port = addr ##self.assertEqual(hostname, socket.gethostbyname('localhost')) self.assertEqual(msg, MSG) def _testOverFlowRecvFrom(self): self.serv_conn.send(MSG) def testSendAll(self): """Testing sendall() with a 2048 byte string over TCP.""" while 1: read = self.cli_conn.recv(1024) if not read: break self.assert_(len(read) == 1024, "Error performing sendall.") read = filter(lambda x: x == 'f', read) self.assert_(len(read) == 1024, "Error performing sendall.") def _testSendAll(self): big_chunk = ''.join([ 'f' ] * 2048) self.serv_conn.sendall(big_chunk) def testFromFd(self): """Testing fromfd().""" if not hasattr(socket, "fromfd"): return # On Windows, this doesn't exist fd = self.cli_conn.fileno() sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) msg = sock.recv(1024) self.assertEqual(msg, MSG) def _testFromFd(self): self.serv_conn.send(MSG) def testShutdown(self): """Testing shutdown().""" msg = self.cli_conn.recv(1024) self.assertEqual(msg, MSG) def _testShutdown(self): self.serv_conn.send(MSG) self.serv_conn.shutdown(2) class BasicUDPTest(ThreadedUDPSocketTest): def __init__(self, methodName='runTest'): ThreadedUDPSocketTest.__init__(self, methodName=methodName) def testSendtoAndRecv(self): """Testing sendto() and Recv() over UDP.""" msg = self.serv.recv(len(MSG)) self.assertEqual(msg, MSG) def _testSendtoAndRecv(self): self.cli.sendto(MSG, 0, (HOST, PORT)) def testRecvFrom(self): """Testing recfrom() over UDP.""" msg, addr = self.serv.recvfrom(len(MSG)) hostname, port = addr ##self.assertEqual(hostname, socket.gethostbyname('localhost')) self.assertEqual(msg, MSG) def _testRecvFrom(self): self.cli.sendto(MSG, 0, (HOST, PORT)) class NonBlockingTCPTests(ThreadedTCPSocketTest): def __init__(self, methodName='runTest'): ThreadedTCPSocketTest.__init__(self, methodName=methodName) def testSetBlocking(self): """Testing whether set blocking works.""" self.serv.setblocking(0) start = time.time() try: self.serv.accept() except socket.error: pass end = time.time() self.assert_((end - start) < 1.0, "Error setting non-blocking mode.") def _testSetBlocking(self): pass def testAccept(self): """Testing non-blocking accept.""" self.serv.setblocking(0) try: conn, addr = self.serv.accept() except socket.error: pass else: self.fail("Error trying to do non-blocking accept.") read, write, err = select.select([self.serv], [], []) if self.serv in read: conn, addr = self.serv.accept() else: self.fail("Error trying to do accept after select.") def _testAccept(self): time.sleep(1) self.cli.connect((HOST, PORT)) def testConnect(self): """Testing non-blocking connect.""" time.sleep(1) conn, addr = self.serv.accept() def _testConnect(self): self.cli.setblocking(0) try: self.cli.connect((HOST, PORT)) except socket.error: pass else: self.fail("Error trying to do non-blocking connect.") read, write, err = select.select([self.cli], [], []) if self.cli in read: self.cli.connect((HOST, PORT)) else: self.fail("Error trying to do connect after select.") def testRecv(self): """Testing non-blocking recv.""" conn, addr = self.serv.accept() conn.setblocking(0) try: msg = conn.recv(len(MSG)) except socket.error: pass else: self.fail("Error trying to do non-blocking recv.") read, write, err = select.select([conn], [], []) if conn in read: msg = conn.recv(len(MSG)) self.assertEqual(msg, MSG) else: self.fail("Error during select call to non-blocking socket.") def _testRecv(self): self.cli.connect((HOST, PORT)) time.sleep(1) self.cli.send(MSG) class FileObjectClassTestCase(SocketConnectedTest): def __init__(self, methodName='runTest'): SocketConnectedTest.__init__(self, methodName=methodName) def setUp(self): SocketConnectedTest.setUp(self) self.serv_file = socket._fileobject(self.cli_conn, 'rb', 8192) def tearDown(self): self.serv_file.close() self.serv_file = None SocketConnectedTest.tearDown(self) def clientSetUp(self): SocketConnectedTest.clientSetUp(self) self.cli_file = socket._fileobject(self.serv_conn, 'rb', 8192) def clientTearDown(self): self.cli_file.close() self.cli_file = None SocketConnectedTest.clientTearDown(self) def testSmallRead(self): """Performing small file read test.""" first_seg = self.serv_file.read(len(MSG)-3) second_seg = self.serv_file.read(3) msg = ''.join((first_seg, second_seg)) self.assertEqual(msg, MSG) def _testSmallRead(self): self.cli_file.write(MSG) self.cli_file.flush() def testUnbufferedRead(self): """Performing unbuffered file read test.""" buf = '' while 1: char = self.serv_file.read(1) self.failIf(not char) buf += char if buf == MSG: break def _testUnbufferedRead(self): self.cli_file.write(MSG) self.cli_file.flush() def testReadline(self): """Performing file readline test.""" line = self.serv_file.readline() self.assertEqual(line, MSG) def _testReadline(self): self.cli_file.write(MSG) self.cli_file.flush() def test_main(): suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(GeneralModuleTests)) suite.addTest(unittest.makeSuite(BasicTCPTest)) suite.addTest(unittest.makeSuite(BasicUDPTest)) suite.addTest(unittest.makeSuite(NonBlockingTCPTests)) suite.addTest(unittest.makeSuite(FileObjectClassTestCase)) test_support.run_suite(suite) if __name__ == "__main__": test_main()