From e991270363435da12049ecfe70bb69bd9c14b535 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Mon, 17 Dec 2018 22:07:55 +0900 Subject: [PATCH] bpo-35415: validate fileno argument to socket.socket (GH-10917) https://bugs.python.org/issue35415 --- Lib/test/test_socket.py | 49 +++++++++++++-- .../2018-12-06-14-44-21.bpo-35415.-HoK3d.rst | 1 + Modules/socketmodule.c | 61 +++++++++++-------- 3 files changed, 80 insertions(+), 31 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2018-12-06-14-44-21.bpo-35415.-HoK3d.rst diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 626a0779735..bfbd1cc2a18 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1700,7 +1700,6 @@ class GeneralModuleTests(unittest.TestCase): s.setblocking(False) self.assertEqual(s.type, socket.SOCK_STREAM) - @unittest.skipIf(os.name == 'nt', 'Will not work on Windows') def test_unknown_socket_family_repr(self): # Test that when created with a family that's not one of the known # AF_*/SOCK_* constants, socket.family just returns the number. @@ -1708,10 +1707,8 @@ class GeneralModuleTests(unittest.TestCase): # To do this we fool socket.socket into believing it already has an # open fd because on this path it doesn't actually verify the family and # type and populates the socket object. - # - # On Windows this trick won't work, so the test is skipped. - fd, path = tempfile.mkstemp() - self.addCleanup(os.unlink, path) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + fd = sock.detach() unknown_family = max(socket.AddressFamily.__members__.values()) + 1 unknown_type = max( @@ -1785,6 +1782,48 @@ class GeneralModuleTests(unittest.TestCase): s.bind(os.path.join(tmpdir, 'socket')) self._test_socket_fileno(s, socket.AF_UNIX, socket.SOCK_STREAM) + def test_socket_fileno_rejects_float(self): + with self.assertRaisesRegex(TypeError, "integer argument expected"): + socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=42.5) + + def test_socket_fileno_rejects_other_types(self): + with self.assertRaisesRegex(TypeError, "integer is required"): + socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno="foo") + + def test_socket_fileno_rejects_invalid_socket(self): + with self.assertRaisesRegex(ValueError, "negative file descriptor"): + socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=-1) + + @unittest.skipIf(os.name == "nt", "Windows disallows -1 only") + def test_socket_fileno_rejects_negative(self): + with self.assertRaisesRegex(ValueError, "negative file descriptor"): + socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=-42) + + def test_socket_fileno_requires_valid_fd(self): + WSAENOTSOCK = 10038 + with self.assertRaises(OSError) as cm: + socket.socket(fileno=support.make_bad_fd()) + self.assertIn(cm.exception.errno, (errno.EBADF, WSAENOTSOCK)) + + with self.assertRaises(OSError) as cm: + socket.socket( + socket.AF_INET, + socket.SOCK_STREAM, + fileno=support.make_bad_fd()) + self.assertIn(cm.exception.errno, (errno.EBADF, WSAENOTSOCK)) + + def test_socket_fileno_requires_socket_fd(self): + with tempfile.NamedTemporaryFile() as afile: + with self.assertRaises(OSError): + socket.socket(fileno=afile.fileno()) + + with self.assertRaises(OSError) as cm: + socket.socket( + socket.AF_INET, + socket.SOCK_STREAM, + fileno=afile.fileno()) + self.assertEqual(cm.exception.errno, errno.ENOTSOCK) + @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') class BasicCANTest(unittest.TestCase): diff --git a/Misc/NEWS.d/next/Library/2018-12-06-14-44-21.bpo-35415.-HoK3d.rst b/Misc/NEWS.d/next/Library/2018-12-06-14-44-21.bpo-35415.-HoK3d.rst new file mode 100644 index 00000000000..ab053df4f74 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-12-06-14-44-21.bpo-35415.-HoK3d.rst @@ -0,0 +1 @@ +Validate fileno= argument to socket.socket(). diff --git a/Modules/socketmodule.c b/Modules/socketmodule.c index 73d3e1add3e..66e52f84eb9 100644 --- a/Modules/socketmodule.c +++ b/Modules/socketmodule.c @@ -5018,28 +5018,45 @@ sock_initobj(PyObject *self, PyObject *args, PyObject *kwds) else #endif { - fd = PyLong_AsSocket_t(fdobj); - if (fd == (SOCKET_T)(-1) && PyErr_Occurred()) - return -1; - if (fd == INVALID_SOCKET) { - PyErr_SetString(PyExc_ValueError, - "can't use invalid socket value"); + + if (PyFloat_Check(fdobj)) { + PyErr_SetString(PyExc_TypeError, + "integer argument expected, got float"); return -1; } - if (family == -1) { - sock_addr_t addrbuf; - socklen_t addrlen = sizeof(sock_addr_t); - - memset(&addrbuf, 0, addrlen); - if (getsockname(fd, SAS2SA(&addrbuf), &addrlen) == 0) { - family = SAS2SA(&addrbuf)->sa_family; - } else { + fd = PyLong_AsSocket_t(fdobj); + if (fd == (SOCKET_T)(-1) && PyErr_Occurred()) + return -1; #ifdef MS_WINDOWS - PyErr_SetFromWindowsErrWithFilename(0, "family"); + if (fd == INVALID_SOCKET) { #else - PyErr_SetFromErrnoWithFilename(PyExc_OSError, "family"); + if (fd < 0) { #endif + PyErr_SetString(PyExc_ValueError, "negative file descriptor"); + return -1; + } + + /* validate that passed file descriptor is valid and a socket. */ + sock_addr_t addrbuf; + socklen_t addrlen = sizeof(sock_addr_t); + + memset(&addrbuf, 0, addrlen); + if (getsockname(fd, SAS2SA(&addrbuf), &addrlen) == 0) { + if (family == -1) { + family = SAS2SA(&addrbuf)->sa_family; + } + } else { +#ifdef MS_WINDOWS + /* getsockname() on an unbound socket is an error on Windows. + Invalid descriptor and not a socket is same error code. + Error out if family must be resolved, or bad descriptor. */ + if (family == -1 || CHECK_ERRNO(ENOTSOCK)) { +#else + /* getsockname() is not supported for SOL_ALG on Linux. */ + if (family == -1 || CHECK_ERRNO(EBADF) || CHECK_ERRNO(ENOTSOCK)) { +#endif + set_error(); return -1; } } @@ -5052,11 +5069,7 @@ sock_initobj(PyObject *self, PyObject *args, PyObject *kwds) { type = tmp; } else { -#ifdef MS_WINDOWS - PyErr_SetFromWindowsErrWithFilename(0, "type"); -#else - PyErr_SetFromErrnoWithFilename(PyExc_OSError, "type"); -#endif + set_error(); return -1; } } @@ -5072,11 +5085,7 @@ sock_initobj(PyObject *self, PyObject *args, PyObject *kwds) { proto = tmp; } else { -#ifdef MS_WINDOWS - PyErr_SetFromWindowsErrWithFilename(0, "protocol"); -#else - PyErr_SetFromErrnoWithFilename(PyExc_OSError, "protocol"); -#endif + set_error(); return -1; } }