diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index ececebfd1b1..a2c047daa3a 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -34,7 +34,8 @@ except ImportError: fcntl = None HOST = support.HOST -MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string and carriage return +# test unicode string and carriage return +MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') MAIN_TIMEOUT = 60.0 VSOCKPORT = 1234 @@ -111,9 +112,14 @@ def _have_socket_vsock(): return ret -def _is_fd_in_blocking_mode(sock): - return not bool( - fcntl.fcntl(sock, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK) +@contextlib.contextmanager +def socket_setdefaulttimeout(timeout): + old_timeout = socket.getdefaulttimeout() + try: + socket.setdefaulttimeout(timeout) + yield + finally: + socket.setdefaulttimeout(old_timeout) HAVE_SOCKET_CAN = _have_socket_can() @@ -1069,18 +1075,16 @@ class GeneralModuleTests(unittest.TestCase): s.close() # Set the default timeout to 10, and see if it propagates - socket.setdefaulttimeout(10) - self.assertEqual(socket.getdefaulttimeout(), 10) - s = socket.socket() - self.assertEqual(s.gettimeout(), 10) - s.close() + with socket_setdefaulttimeout(10): + self.assertEqual(socket.getdefaulttimeout(), 10) + with socket.socket() as sock: + self.assertEqual(sock.gettimeout(), 10) - # Reset the default timeout to None, and see if it propagates - socket.setdefaulttimeout(None) - self.assertEqual(socket.getdefaulttimeout(), None) - s = socket.socket() - self.assertEqual(s.gettimeout(), None) - s.close() + # Reset the default timeout to None, and see if it propagates + socket.setdefaulttimeout(None) + self.assertEqual(socket.getdefaulttimeout(), None) + with socket.socket() as sock: + self.assertEqual(sock.gettimeout(), None) # Check that setting it to an invalid value raises ValueError self.assertRaises(ValueError, socket.setdefaulttimeout, -1) @@ -4218,55 +4222,42 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): self.event = threading.Event() ThreadedTCPSocketTest.__init__(self, methodName=methodName) + def assert_sock_timeout(self, sock, timeout): + self.assertEqual(self.serv.gettimeout(), timeout) + + blocking = (timeout != 0.0) + self.assertEqual(sock.getblocking(), blocking) + + if fcntl is not None: + # When a Python socket has a non-zero timeout, it's switched + # internally to a non-blocking mode. Later, sock.sendall(), + # sock.recv(), and other socket operations use a select() call and + # handle EWOULDBLOCK/EGAIN on all socket operations. That's how + # timeouts are enforced. + fd_blocking = (timeout is None) + + flag = fcntl.fcntl(sock, fcntl.F_GETFL, os.O_NONBLOCK) + self.assertEqual(not bool(flag & os.O_NONBLOCK), fd_blocking) + def testSetBlocking(self): - # Testing whether set blocking works + # Test setblocking() and settimeout() methods self.serv.setblocking(True) - self.assertIsNone(self.serv.gettimeout()) - self.assertTrue(self.serv.getblocking()) - if fcntl: - self.assertTrue(_is_fd_in_blocking_mode(self.serv)) + self.assert_sock_timeout(self.serv, None) self.serv.setblocking(False) - self.assertEqual(self.serv.gettimeout(), 0.0) - self.assertFalse(self.serv.getblocking()) - if fcntl: - self.assertFalse(_is_fd_in_blocking_mode(self.serv)) + self.assert_sock_timeout(self.serv, 0.0) self.serv.settimeout(None) - self.assertTrue(self.serv.getblocking()) - if fcntl: - self.assertTrue(_is_fd_in_blocking_mode(self.serv)) + self.assert_sock_timeout(self.serv, None) self.serv.settimeout(0) - self.assertFalse(self.serv.getblocking()) - self.assertEqual(self.serv.gettimeout(), 0) - if fcntl: - self.assertFalse(_is_fd_in_blocking_mode(self.serv)) + self.assert_sock_timeout(self.serv, 0) self.serv.settimeout(10) - self.assertTrue(self.serv.getblocking()) - self.assertEqual(self.serv.gettimeout(), 10) - if fcntl: - # When a Python socket has a non-zero timeout, it's - # switched internally to a non-blocking mode. - # Later, sock.sendall(), sock.recv(), and other socket - # operations use a `select()` call and handle EWOULDBLOCK/EGAIN - # on all socket operations. That's how timeouts are - # enforced. - self.assertFalse(_is_fd_in_blocking_mode(self.serv)) + self.assert_sock_timeout(self.serv, 10) self.serv.settimeout(0) - self.assertFalse(self.serv.getblocking()) - if fcntl: - self.assertFalse(_is_fd_in_blocking_mode(self.serv)) - - start = time.time() - try: - self.serv.accept() - except OSError: - pass - end = time.time() - self.assertTrue((end - start) < 1.0, "Error setting non-blocking mode.") + self.assert_sock_timeout(self.serv, 0) def _testSetBlocking(self): pass @@ -4277,8 +4268,10 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): import _testcapi if _testcapi.UINT_MAX >= _testcapi.ULONG_MAX: self.skipTest('needs UINT_MAX < ULONG_MAX') + self.serv.setblocking(False) self.assertEqual(self.serv.gettimeout(), 0.0) + self.serv.setblocking(_testcapi.UINT_MAX + 1) self.assertIsNone(self.serv.gettimeout()) @@ -4288,50 +4281,51 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): 'test needs socket.SOCK_NONBLOCK') @support.requires_linux_version(2, 6, 28) def testInitNonBlocking(self): - # reinit server socket + # create a socket with SOCK_NONBLOCK self.serv.close() - self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM | - socket.SOCK_NONBLOCK) - self.assertFalse(self.serv.getblocking()) - self.assertEqual(self.serv.gettimeout(), 0) - self.port = support.bind_port(self.serv) - self.serv.listen() - # actual testing - start = time.time() - try: - self.serv.accept() - except OSError: - pass - end = time.time() - self.assertTrue((end - start) < 1.0, "Error creating with non-blocking mode.") + self.serv = socket.socket(socket.AF_INET, + socket.SOCK_STREAM | socket.SOCK_NONBLOCK) + self.assert_sock_timeout(self.serv, 0) def _testInitNonBlocking(self): pass - def testInheritFlags(self): - # Issue #7995: when calling accept() on a listening socket with a - # timeout, the resulting socket should not be non-blocking. - self.serv.settimeout(10) - try: + def testInheritFlagsBlocking(self): + # bpo-7995: accept() on a listening socket with a timeout and the + # default timeout is None, the resulting socket must be blocking. + with socket_setdefaulttimeout(None): + self.serv.settimeout(10) conn, addr = self.serv.accept() - message = conn.recv(len(MSG)) - finally: - conn.close() - self.serv.settimeout(None) + self.addCleanup(conn.close) + self.assertIsNone(conn.gettimeout()) - def _testInheritFlags(self): - time.sleep(0.1) + def _testInheritFlagsBlocking(self): + self.cli.connect((HOST, self.port)) + + def testInheritFlagsTimeout(self): + # bpo-7995: accept() on a listening socket with a timeout and the + # default timeout is None, the resulting socket must inherit + # the default timeout. + default_timeout = 20.0 + with socket_setdefaulttimeout(default_timeout): + self.serv.settimeout(10) + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + self.assertEqual(conn.gettimeout(), default_timeout) + + def _testInheritFlagsTimeout(self): self.cli.connect((HOST, self.port)) - time.sleep(0.5) - self.cli.send(MSG) def testAccept(self): # Testing non-blocking accept self.serv.setblocking(0) # connect() didn't start: non-blocking accept() fails + start_time = time.monotonic() with self.assertRaises(BlockingIOError): conn, addr = self.serv.accept() + dt = time.monotonic() - start_time + self.assertLess(dt, 1.0) self.event.set() @@ -4351,15 +4345,6 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): self.cli.connect((HOST, self.port)) - def testConnect(self): - # Testing non-blocking connect - conn, addr = self.serv.accept() - conn.close() - - def _testConnect(self): - self.cli.settimeout(10) - self.cli.connect((HOST, self.port)) - def testRecv(self): # Testing non-blocking recv conn, addr = self.serv.accept()