#!/usr/bin/env python
# Unit tests for the socket module (Python 2 test-suite style).
#
# NOTE(review): this file was recovered from a copy whose line structure had
# been collapsed and whose middle section was truncated; see the NOTE(review)
# block following GeneralModuleTests for what is still missing.

import unittest
from test import test_support

import errno
import socket
import select
import time
import traceback
import Queue
import sys
import os
import array
import contextlib
from weakref import proxy
import signal
import math


def try_address(host, port=0, family=socket.AF_INET):
    """Try to bind a socket on the given host:port and return True
    if that has been possible."""
    try:
        sock = socket.socket(family, socket.SOCK_STREAM)
    except (socket.error, socket.gaierror):
        return False
    try:
        sock.bind((host, port))
    except (socket.error, socket.gaierror):
        return False
    finally:
        # Close the probe socket on both the success and the failure path.
        sock.close()
    return True


HOST = test_support.HOST
MSG = b'Michael Gilfix was here\n'
SUPPORTS_IPV6 = socket.has_ipv6 and try_address('::1', family=socket.AF_INET6)

try:
    import thread
    import threading
except ImportError:
    # Threaded test classes below are skipped when these are None.
    thread = None
    threading = None


class SocketTCPTest(unittest.TestCase):
    """Fixture providing a listening TCP server socket (self.serv, self.port)."""

    def setUp(self):
        self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.port = test_support.bind_port(self.serv)
        self.serv.listen(1)

    def tearDown(self):
        self.serv.close()
        self.serv = None


class SocketUDPTest(unittest.TestCase):
    """Fixture providing a bound UDP server socket (self.serv, self.port)."""

    def setUp(self):
        self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.port = test_support.bind_port(self.serv)

    def tearDown(self):
        self.serv.close()
        self.serv = None


class ThreadableTest:
    """Threadable Test class

    The ThreadableTest class makes it easy to create a threaded
    client/server pair from an existing unit test. To create a
    new threaded class from an existing unit test, use multiple
    inheritance:

        class NewClass (OldClass, ThreadableTest):
            pass

    This class defines two new fixture functions with obvious
    purposes for overriding:

        clientSetUp ()
        clientTearDown ()

    Any new test functions within the class must then define
    tests in pairs, where the test name is preceded with a
    '_' to indicate the client portion of the test. Ex:

        def testFoo(self):
            # Server portion

        def _testFoo(self):
            # Client portion

    Any exceptions raised by the clients during their tests
    are caught and transferred to the main thread to alert
    the testing framework.

    Note, the server setup function cannot call any blocking
    functions that rely on the client thread during setup,
    unless serverExplicitReady() is called just before
    the blocking call (such as in setting up a client/server
    connection and performing the accept() in setUp()).
    """

    def __init__(self):
        # Swap the true setup function
        self.__setUp = self.setUp
        self.__tearDown = self.tearDown
        self.setUp = self._setUp
        self.tearDown = self._tearDown

    def serverExplicitReady(self):
        """This method allows the server to explicitly indicate that
        it wants the client thread to proceed. This is useful if the
        server is about to execute a blocking routine that is
        dependent upon the client thread during its setup routine."""
        self.server_ready.set()

    def _setUp(self):
        """Start the client thread, run the real setUp(), then wait until
        the client has finished clientSetUp()."""
        self.server_ready = threading.Event()
        self.client_ready = threading.Event()
        self.done = threading.Event()
        self.queue = Queue.Queue(1)

        # Do some munging to start the client test: the client half of
        # test method "testFoo" is the method named "_testFoo".
        methodname = self.id().rsplit('.', 1)[-1]
        test_method = getattr(self, '_' + methodname)
        self.client_thread = thread.start_new_thread(
            self.clientRun, (test_method,))

        self.__setUp()
        if not self.server_ready.is_set():
            self.server_ready.set()
        self.client_ready.wait()

    def _tearDown(self):
        """Run the real tearDown(), wait for the client thread to finish and
        fail the test if the client queued an exception."""
        self.__tearDown()
        self.done.wait()

        if not self.queue.empty():
            msg = self.queue.get()
            self.fail(msg)

    def clientRun(self, test_func):
        """Body of the client thread: wait for the server, set up, run
        test_func and hand any exception it raises to the main thread."""
        self.server_ready.wait()
        self.clientSetUp()
        self.client_ready.set()
        if not callable(test_func):
            raise TypeError("test_func must be a callable function.")
        try:
            test_func()
        except Exception as strerror:
            self.queue.put(strerror)
        self.clientTearDown()

    def clientSetUp(self):
        raise NotImplementedError("clientSetUp must be implemented.")

    def clientTearDown(self):
        # Signal _tearDown() that the client is finished, then end the thread.
        self.done.set()
        thread.exit()


class ThreadedTCPSocketTest(SocketTCPTest, ThreadableTest):
    """TCP server fixture plus a client thread owning a TCP socket (self.cli)."""

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def clientTearDown(self):
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)


class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
    """UDP server fixture plus a client thread owning a UDP socket (self.cli)."""

    def __init__(self, methodName='runTest'):
        SocketUDPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

    def clientTearDown(self):
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)


class SocketConnectedTest(ThreadedTCPSocketTest):
    """Threaded TCP fixture whose client is already connected to the server.

    The server side of the connection is self.cli_conn; the client side is
    self.serv_conn (the same object as self.cli).
    """

    def __init__(self, methodName='runTest'):
        ThreadedTCPSocketTest.__init__(self, methodName=methodName)

    def setUp(self):
        ThreadedTCPSocketTest.setUp(self)
        # Indicate explicitly we're ready for the client thread to
        # proceed and then perform the blocking call to accept
        self.serverExplicitReady()
        conn, addr = self.serv.accept()
        self.cli_conn = conn

    def tearDown(self):
        self.cli_conn.close()
        self.cli_conn = None
        ThreadedTCPSocketTest.tearDown(self)

    def clientSetUp(self):
        ThreadedTCPSocketTest.clientSetUp(self)
        self.cli.connect((HOST, self.port))
        self.serv_conn = self.cli

    def clientTearDown(self):
        self.serv_conn.close()
        self.serv_conn = None
        ThreadedTCPSocketTest.clientTearDown(self)


class SocketPairTest(unittest.TestCase, ThreadableTest):
    """Threaded fixture built on socket.socketpair() (self.serv / self.cli)."""

    def __init__(self, methodName='runTest'):
        unittest.TestCase.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        self.serv, self.cli = socket.socketpair()

    def tearDown(self):
        self.serv.close()
        self.serv = None

    def clientSetUp(self):
        pass

    def clientTearDown(self):
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)


#######################################################################
## Begin Tests

class GeneralModuleTests(unittest.TestCase):

    def test_weakref(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        p = proxy(s)
        self.assertEqual(p.fileno(), s.fileno())
        s.close()
        s = None
        try:
            p.fileno()
        except ReferenceError:
            pass
        else:
            self.fail('Socket proxy still exists')

    def testSocketError(self):
        # Testing socket module exceptions
        def raise_error(*args, **kwargs):
            raise socket.error

        def raise_herror(*args, **kwargs):
            raise socket.herror

        def raise_gaierror(*args, **kwargs):
            raise socket.gaierror

        self.assertRaises(socket.error, raise_error,
                          "Error raising socket exception.")
        self.assertRaises(socket.error, raise_herror,
                          "Error raising socket exception.")
        self.assertRaises(socket.error, raise_gaierror,
                          "Error raising socket exception.")

    def testSendtoErrors(self):
        # Testing that sendto doesn't mask failures. See #10169.
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.addCleanup(s.close)
        s.bind(('', 0))
        sockname = s.getsockname()
        # 2 args
        with self.assertRaises(UnicodeEncodeError):
            s.sendto(u'\u2620', sockname)
        with self.assertRaises(TypeError) as cm:
            s.sendto(5j, sockname)
        self.assertIn('not complex', str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            s.sendto('foo', None)
        self.assertIn('not NoneType', str(cm.exception))
        # 3 args
        with self.assertRaises(UnicodeEncodeError):
            s.sendto(u'\u2620', 0, sockname)
        with self.assertRaises(TypeError) as cm:
            s.sendto(5j, 0, sockname)
        self.assertIn('not complex', str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            s.sendto('foo', 0, None)
        self.assertIn('not NoneType', str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            s.sendto('foo', 'bar', sockname)
        self.assertIn('an integer is required', str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            s.sendto('foo', None, None)
        self.assertIn('an integer is required', str(cm.exception))
        # wrong number of args
        with self.assertRaises(TypeError) as cm:
            s.sendto('foo')
        self.assertIn('(1 given)', str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            s.sendto('foo', 0, sockname, 4)
        self.assertIn('(4 given)', str(cm.exception))

    def testCrucialConstants(self):
        # Testing for mission critical constants; a missing one raises
        # AttributeError on the bare attribute access.
        socket.AF_INET
        socket.SOCK_STREAM
        socket.SOCK_DGRAM
        socket.SOCK_RAW
        socket.SOCK_RDM
        socket.SOCK_SEQPACKET
        socket.SOL_SOCKET
        socket.SO_REUSEADDR

    def testHostnameRes(self):
        # Testing hostname resolution mechanisms
        hostname = socket.gethostname()
        try:
            ip = socket.gethostbyname(hostname)
        except socket.error:
            # Probably name lookup wasn't set up right; skip this test
            self.skipTest('name lookup failure')
        self.assertIn('.', ip, "Error resolving host to ip.")
        try:
            hname, aliases, ipaddrs = socket.gethostbyaddr(ip)
        except socket.error:
            # Probably a similar problem as above; skip this test
            self.skipTest('address lookup failure')
        all_host_names = [hostname, hname] + aliases
        fqhn = socket.getfqdn(ip)
        if fqhn not in all_host_names:
            self.fail("Error testing host resolution mechanisms. "
                      "(fqdn: %s, all: %s)" % (fqhn, repr(all_host_names)))

    def testRefCountGetNameInfo(self):
        # Testing reference count for getnameinfo
        if hasattr(sys, "getrefcount"):
            try:
                # On some versions, this loses a reference
                orig = sys.getrefcount(__name__)
                socket.getnameinfo(__name__, 0)
            except TypeError:
                self.assertEqual(sys.getrefcount(__name__), orig,
                                 "socket.getnameinfo loses a reference")

    def testInterpreterCrash(self):
        # Making sure getnameinfo doesn't crash the interpreter
        try:
            # On some versions, this crashes the interpreter.
            socket.getnameinfo(('x', 0, 0, 0), 0)
        except socket.error:
            pass


# ----------------------------------------------------------------------
# NOTE(review): the copy this file was recovered from is TRUNCATED here.
# Text is missing from the middle of GeneralModuleTests.testNtoH up to the
# tail of a create_connection() errno test.  test_main() below still
# references these names, none of which is defined in what survives:
#   BasicTCPTest, TCPCloserTest, BasicTCPTest2, BasicUDPTest,
#   NonBlockingTCPTests, FileObjectClassTestCase,
#   FileObjectInterruptedTestCase, UnbufferedFileObjectClassTestCase,
#   LineBufferedFileObjectClassTestCase,
#   SmallBufferedFileObjectClassTestCase, NetworkConnectionNoServer,
#   BasicSocketPairTest
# Restore the missing section from the original file before running.
# The surviving fragments are kept verbatim below as comments because
# they are not valid code on their own (the enclosing class header and
# helpers such as mocked_socket_module() are not visible).
#
#     def testNtoH(self):
#         # This just checks that htons etc. are their own inverse,
#         # when looking at the lower 16 or 32 bits.
#         sizes = {socket.htonl: 32, socket.ntohl: 32,
#                  socket.htons: 16, socket.ntohs: 16}
#         for func, size in sizes.items():
#             mask = (1L<            [... text lost ...]
#
#         [... text lost ...]
#         # >>> socket.getaddrinfo('localhost', port, 0, SOCK_STREAM)
#         # >>> [(2,  2, 0, '', ('127.0.0.1', 41230)),
#         #      (26, 2, 0, '', ('::1', 41230, 0, 0))]
#         #
#         # create_connection() enumerates through all the addresses returned
#         # and if it doesn't successfully bind to any of them, it propagates
#         # the last exception it encountered.
#         #
#         # On Solaris, ENETUNREACH is returned in this circumstance instead
#         # of ECONNREFUSED.  So, if that errno exists, add it to our list of
#         # expected errnos.
#         expected_errnos = [ errno.ECONNREFUSED, ]
#         if hasattr(errno, 'ENETUNREACH'):
#             expected_errnos.append(errno.ENETUNREACH)
#
#         self.assertIn(cm.exception.errno, expected_errnos)
#
#     def test_create_connection_timeout(self):
#         # Issue #9792: create_connection() should not recast timeout errors
#         # as generic socket errors.
#         with self.mocked_socket_module():
#             with self.assertRaises(socket.timeout):
#                 socket.create_connection((HOST, 1234))
# ----------------------------------------------------------------------


@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
    """Check attributes of sockets returned by socket.create_connection()."""

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        self.source_port = test_support.find_unused_port()

    def clientTearDown(self):
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def _justAccept(self):
        # Server half shared by every test here: accept and drop the client.
        conn, addr = self.serv.accept()
        conn.close()

    testFamily = _justAccept
    def _testFamily(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.family, socket.AF_INET)

    testSourceAddress = _justAccept
    def _testSourceAddress(self):
        self.cli = socket.create_connection(
            (HOST, self.port), timeout=30,
            source_address=('', self.source_port))
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.getsockname()[1], self.source_port)
        # The port number being used is sufficient to show that the bind()
        # call happened.

    testTimeoutDefault = _justAccept
    def _testTimeoutDefault(self):
        # passing no explicit timeout uses socket's global default
        self.assertTrue(socket.getdefaulttimeout() is None)
        socket.setdefaulttimeout(42)
        try:
            self.cli = socket.create_connection((HOST, self.port))
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), 42)

    testTimeoutNone = _justAccept
    def _testTimeoutNone(self):
        # None timeout means the same as sock.settimeout(None)
        self.assertTrue(socket.getdefaulttimeout() is None)
        socket.setdefaulttimeout(30)
        try:
            self.cli = socket.create_connection((HOST, self.port),
                                                timeout=None)
            self.addCleanup(self.cli.close)
        finally:
            socket.setdefaulttimeout(None)
        self.assertEqual(self.cli.gettimeout(), None)

    testTimeoutValueNamed = _justAccept
    def _testTimeoutValueNamed(self):
        self.cli = socket.create_connection((HOST, self.port), timeout=30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)

    testTimeoutValueNonamed = _justAccept
    def _testTimeoutValueNonamed(self):
        self.cli = socket.create_connection((HOST, self.port), 30)
        self.addCleanup(self.cli.close)
        self.assertEqual(self.cli.gettimeout(), 30)


@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
    """Check recv() timeout behaviour of create_connection() sockets."""

    def __init__(self, methodName='runTest'):
        SocketTCPTest.__init__(self, methodName=methodName)
        ThreadableTest.__init__(self)

    def clientSetUp(self):
        pass

    def clientTearDown(self):
        self.cli.close()
        self.cli = None
        ThreadableTest.clientTearDown(self)

    def testInsideTimeout(self):
        # Server half: reply only after a 3 second delay.
        conn, addr = self.serv.accept()
        self.addCleanup(conn.close)
        time.sleep(3)
        conn.send("done!")
    testOutsideTimeout = testInsideTimeout

    def _testInsideTimeout(self):
        # No timeout requested, so the delayed reply must be received.
        self.cli = sock = socket.create_connection((HOST, self.port))
        data = sock.recv(5)
        self.assertEqual(data, "done!")

    def _testOutsideTimeout(self):
        # 1 second timeout is shorter than the server's 3 second delay.
        self.cli = sock = socket.create_connection((HOST, self.port),
                                                   timeout=1)
        self.assertRaises(socket.timeout, sock.recv, 5)


class Urllib2FileobjectTest(unittest.TestCase):

    # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
    # it close the socket if the close c'tor argument is true

    def testClose(self):
        class MockSocket:
            closed = False
            def flush(self): pass
            def close(self): self.closed = True

        # must not close unless we request it: the original use of _fileobject
        # by module socket requires that the underlying socket not be closed
        # until the _socketobject that created the _fileobject is closed
        s = MockSocket()
        f = socket._fileobject(s)
        f.close()
        self.assertFalse(s.closed)

        s = MockSocket()
        f = socket._fileobject(s, close=True)
        f.close()
        self.assertTrue(s.closed)


class TCPTimeoutTest(SocketTCPTest):

    def testTCPTimeout(self):
        def raise_timeout(*args, **kwargs):
            self.serv.settimeout(1.0)
            self.serv.accept()
        self.assertRaises(socket.timeout, raise_timeout,
                          "Error generating a timeout exception (TCP)")

    def testTimeoutZero(self):
        ok = False
        try:
            self.serv.settimeout(0.0)
            self.serv.accept()
        except socket.timeout:
            self.fail("caught timeout instead of error (TCP)")
        except socket.error:
            ok = True
        except Exception:
            self.fail("caught unexpected exception (TCP)")
        if not ok:
            self.fail("accept() returned success when we did not expect it")

    # XXX I don't know how to do this test on MSWindows or any other
    # platform that doesn't support signal.alarm() or os.kill(), though
    # the bug should have existed on all platforms.
    @unittest.skipUnless(hasattr(signal, 'alarm'),
                         'test needs signal.alarm()')
    def testInterruptedTimeout(self):
        self.serv.settimeout(5.0)   # must be longer than alarm

        class Alarm(Exception):
            pass

        def alarm_handler(signum, frame):
            raise Alarm

        old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
        try:
            signal.alarm(2)    # POSIX allows alarm to be up to 1 second early
            try:
                self.serv.accept()
            except socket.timeout:
                self.fail("caught timeout instead of Alarm")
            except Alarm:
                pass
            except Exception:
                self.fail("caught other exception instead of Alarm:"
                          " %s(%s):\n%s" %
                          (sys.exc_info()[:2] + (traceback.format_exc(),)))
            else:
                self.fail("nothing caught")
            finally:
                signal.alarm(0)         # shut off alarm
        except Alarm:
            self.fail("got Alarm in wrong place")
        finally:
            # no alarm can be pending.  Safe to restore old handler.
            signal.signal(signal.SIGALRM, old_alarm)


class UDPTimeoutTest(SocketUDPTest):

    def testUDPTimeout(self):
        def raise_timeout(*args, **kwargs):
            self.serv.settimeout(1.0)
            self.serv.recv(1024)
        self.assertRaises(socket.timeout, raise_timeout,
                          "Error generating a timeout exception (UDP)")

    def testTimeoutZero(self):
        ok = False
        try:
            self.serv.settimeout(0.0)
            self.serv.recv(1024)
        except socket.timeout:
            self.fail("caught timeout instead of error (UDP)")
        except socket.error:
            ok = True
        except Exception:
            self.fail("caught unexpected exception (UDP)")
        if not ok:
            self.fail("recv() returned success when we did not expect it")


class TestExceptions(unittest.TestCase):

    def testExceptionTree(self):
        self.assertTrue(issubclass(socket.error, Exception))
        self.assertTrue(issubclass(socket.herror, socket.error))
        self.assertTrue(issubclass(socket.gaierror, socket.error))
        self.assertTrue(issubclass(socket.timeout, socket.error))


class TestLinuxAbstractNamespace(unittest.TestCase):
    """AF_UNIX addresses starting with a NUL byte (only run on Linux, see
    test_main())."""

    UNIX_PATH_MAX = 108

    def testLinuxAbstractNamespace(self):
        address = "\x00python-test-hello\x00\xff"
        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.addCleanup(s1.close)
        s1.bind(address)
        s1.listen(1)
        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.addCleanup(s2.close)
        s2.connect(s1.getsockname())
        conn, _ = s1.accept()
        self.addCleanup(conn.close)
        self.assertEqual(s1.getsockname(), address)
        self.assertEqual(s2.getpeername(), address)

    def testMaxName(self):
        address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1)
        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.addCleanup(s.close)
        s.bind(address)
        self.assertEqual(s.getsockname(), address)

    def testNameOverflow(self):
        address = "\x00" + "h" * self.UNIX_PATH_MAX
        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.addCleanup(s.close)
        self.assertRaises(socket.error, s.bind, address)


@unittest.skipUnless(thread, 'Threading required for this test.')
class BufferIOTest(SocketConnectedTest):
    """
    Test the buffer versions of socket.recv() and socket.send().
    """
    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecvIntoArray(self):
        buf = array.array('c', ' '*1024)
        nbytes = self.cli_conn.recv_into(buf)
        self.assertEqual(nbytes, len(MSG))
        msg = buf.tostring()[:len(MSG)]
        self.assertEqual(msg, MSG)

    def _testRecvIntoArray(self):
        with test_support.check_py3k_warnings():
            buf = buffer(MSG)
        self.serv_conn.send(buf)

    def testRecvIntoBytearray(self):
        buf = bytearray(1024)
        nbytes = self.cli_conn.recv_into(buf)
        self.assertEqual(nbytes, len(MSG))
        msg = buf[:len(MSG)]
        self.assertEqual(msg, MSG)

    _testRecvIntoBytearray = _testRecvIntoArray

    def testRecvIntoMemoryview(self):
        buf = bytearray(1024)
        nbytes = self.cli_conn.recv_into(memoryview(buf))
        self.assertEqual(nbytes, len(MSG))
        msg = buf[:len(MSG)]
        self.assertEqual(msg, MSG)

    _testRecvIntoMemoryview = _testRecvIntoArray

    def testRecvFromIntoArray(self):
        buf = array.array('c', ' '*1024)
        nbytes, addr = self.cli_conn.recvfrom_into(buf)
        self.assertEqual(nbytes, len(MSG))
        msg = buf.tostring()[:len(MSG)]
        self.assertEqual(msg, MSG)

    def _testRecvFromIntoArray(self):
        with test_support.check_py3k_warnings():
            buf = buffer(MSG)
        self.serv_conn.send(buf)

    def testRecvFromIntoBytearray(self):
        buf = bytearray(1024)
        nbytes, addr = self.cli_conn.recvfrom_into(buf)
        self.assertEqual(nbytes, len(MSG))
        msg = buf[:len(MSG)]
        self.assertEqual(msg, MSG)

    _testRecvFromIntoBytearray = _testRecvFromIntoArray

    def testRecvFromIntoMemoryview(self):
        buf = bytearray(1024)
        nbytes, addr = self.cli_conn.recvfrom_into(memoryview(buf))
        self.assertEqual(nbytes, len(MSG))
        msg = buf[:len(MSG)]
        self.assertEqual(msg, MSG)

    _testRecvFromIntoMemoryview = _testRecvFromIntoArray


# Name-sequence parameters shared by the TIPC tests below.
TIPC_STYPE = 2000
TIPC_LOWER = 200
TIPC_UPPER = 210


def isTipcAvailable():
    """Check if the TIPC module is loaded

    The TIPC module is not loaded automatically on Ubuntu and probably
    other Linux distros.
    """
    if not hasattr(socket, "AF_TIPC"):
        return False
    if not os.path.isfile("/proc/modules"):
        return False
    with open("/proc/modules") as f:
        for line in f:
            if line.startswith("tipc "):
                return True
    if test_support.verbose:
        print("TIPC module is not loaded, please 'sudo modprobe tipc'")
    return False


class TIPCTest (unittest.TestCase):
    def testRDM(self):
        srv = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(srv.close)
        cli = socket.socket(socket.AF_TIPC, socket.SOCK_RDM)
        self.addCleanup(cli.close)

        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                   TIPC_LOWER, TIPC_UPPER)
        srv.bind(srvaddr)

        # Address an instance in the middle of the bound name sequence.
        sendaddr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                    TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        cli.sendto(MSG, sendaddr)

        msg, recvaddr = srv.recvfrom(1024)

        self.assertEqual(cli.getsockname(), recvaddr)
        self.assertEqual(msg, MSG)


class TIPCThreadableTest (unittest.TestCase, ThreadableTest):
    def __init__(self, methodName = 'runTest'):
        unittest.TestCase.__init__(self, methodName = methodName)
        ThreadableTest.__init__(self)

    def setUp(self):
        self.srv = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        self.addCleanup(self.srv.close)
        self.srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,
                   TIPC_LOWER, TIPC_UPPER)
        self.srv.bind(srvaddr)
        self.srv.listen(5)
        self.serverExplicitReady()
        self.conn, self.connaddr = self.srv.accept()
        self.addCleanup(self.conn.close)

    def clientSetUp(self):
        # There is a hittable race between serverExplicitReady() and the
        # accept() call; sleep a little while to avoid it, otherwise
        # we could get an exception
        time.sleep(0.1)
        self.cli = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM)
        addr = (socket.TIPC_ADDR_NAME, TIPC_STYPE,
                TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) // 2, 0)
        self.cli.connect(addr)
        self.cliaddr = self.cli.getsockname()

    def testStream(self):
        msg = self.conn.recv(1024)
        self.assertEqual(msg, MSG)
        self.assertEqual(self.cliaddr, self.connaddr)

    def _testStream(self):
        self.cli.send(MSG)
        self.cli.close()


def test_main():
    """Collect the test classes that apply to this platform and run them."""
    tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
             TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest,
             UDPTimeoutTest]

    tests.extend([
        NonBlockingTCPTests,
        FileObjectClassTestCase,
        FileObjectInterruptedTestCase,
        UnbufferedFileObjectClassTestCase,
        LineBufferedFileObjectClassTestCase,
        SmallBufferedFileObjectClassTestCase,
        Urllib2FileobjectTest,
        NetworkConnectionNoServer,
        NetworkConnectionAttributesTest,
        NetworkConnectionBehaviourTest,
    ])
    if hasattr(socket, "socketpair"):
        tests.append(BasicSocketPairTest)
    # sys.platform may be 'linux2' or 'linux3' depending on the build.
    if sys.platform.startswith('linux'):
        tests.append(TestLinuxAbstractNamespace)
    if isTipcAvailable():
        tests.append(TIPCTest)
        tests.append(TIPCThreadableTest)

    thread_info = test_support.threading_setup()
    try:
        test_support.run_unittest(*tests)
    finally:
        # Run the thread cleanup check even if the test run itself raised.
        test_support.threading_cleanup(*thread_info)


if __name__ == "__main__":
    test_main()