From 39eb8fa0dbbcd6568fceb7ca59220aa3281e0cc4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 16 Nov 2007 01:24:05 +0000 Subject: [PATCH] This is roughly socket2.diff from issue 1378, with a few changes applied to ssl.py (no need to test whether we can dup any more). Regular sockets no longer have a _base, but we still have explicit reference counting of socket objects for the benefit of makefile(); using duplicate sockets won't work for SSLSocket. --- Include/longobject.h | 9 +++ Lib/socket.py | 73 ++++++++++------------ Lib/ssl.py | 21 ++----- Lib/test/test_socket.py | 9 +++ Modules/socketmodule.c | 132 +++++++++++++++++++++------------------- 5 files changed, 126 insertions(+), 118 deletions(-) diff --git a/Include/longobject.h b/Include/longobject.h index 6bf34096112..16abd0e9033 100644 --- a/Include/longobject.h +++ b/Include/longobject.h @@ -26,6 +26,15 @@ PyAPI_FUNC(size_t) PyLong_AsSize_t(PyObject *); PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLong(PyObject *); PyAPI_FUNC(unsigned long) PyLong_AsUnsignedLongMask(PyObject *); +/* Used by socketmodule.c */ +#if SIZEOF_SOCKET_T <= SIZEOF_LONG +#define PyLong_FromSocket_t(fd) PyLong_FromLong((SOCKET_T)(fd)) +#define PyLong_AsSocket_t(fd) (SOCKET_T)PyLong_AsLong(fd) +#else +#define PyLong_FromSocket_t(fd) PyLong_FromLongLong(((SOCKET_T)(fd)); +#define PyLong_AsSocket_t(fd) (SOCKET_T)PyLong_AsLongLong(fd) +#endif + /* For use by intobject.c only */ PyAPI_DATA(int) _PyLong_DigitValue[256]; diff --git a/Lib/socket.py b/Lib/socket.py index 6a9a381b748..62eb82dcd18 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -79,28 +79,14 @@ if sys.platform.lower().startswith("win"): __all__.append("errorTab") -# True if os.dup() can duplicate socket descriptors. -# (On Windows at least, os.dup only works on files) -_can_dup_socket = hasattr(_socket.socket, "dup") - -if _can_dup_socket: - def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0): - nfd = os.dup(fd) - return socket(family, type, proto, fileno=nfd) - class socket(_socket.socket): """A subclass of _socket.socket adding the makefile() method.""" __slots__ = ["__weakref__", "_io_refs", "_closed"] - if not _can_dup_socket: - __slots__.append("_base") def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None): - if fileno is None: - _socket.socket.__init__(self, family, type, proto) - else: - _socket.socket.__init__(self, family, type, proto, fileno) + _socket.socket.__init__(self, family, type, proto, fileno) self._io_refs = 0 self._closed = False @@ -114,23 +100,29 @@ class socket(_socket.socket): s[7:]) return s + def dup(self): + """dup() -> socket object + + Return a new socket object connected to the same system resource. + """ + fd = dup(self.fileno()) + sock = self.__class__(self.family, self.type, self.proto, fileno=fd) + sock.settimeout(self.gettimeout()) + return sock + def accept(self): - """Wrap accept() to give the connection the right type.""" - conn, addr = _socket.socket.accept(self) - fd = conn.fileno() - nfd = fd - if _can_dup_socket: - nfd = os.dup(fd) - wrapper = socket(self.family, self.type, self.proto, fileno=nfd) - if fd == nfd: - wrapper._base = conn # Keep the base alive - else: - conn.close() - return wrapper, addr + """accept() -> (socket object, address info) + + Wait for an incoming connection. Return a new socket + representing the connection, and the address of the client. + For IP sockets, the address info is a pair (hostaddr, port). + """ + fd, addr = self._accept() + return socket(self.family, self.type, self.proto, fileno=fd), addr def makefile(self, mode="r", buffering=None, *, encoding=None, newline=None): - """Return an I/O stream connected to the socket. + """makefile(...) -> an I/O stream connected to the socket The arguments are as for io.open() after the filename, except the only mode characters supported are 'r', 'w' and 'b'. @@ -184,23 +176,20 @@ class socket(_socket.socket): def close(self): self._closed = True - if self._io_refs < 1: - self._real_close() - - # _real_close calls close on the _socket.socket base class. - - if not _can_dup_socket: - def _real_close(self): - _socket.socket.close(self) - base = getattr(self, "_base", None) - if base is not None: - self._base = None - base.close() - else: - def _real_close(self): + if self._io_refs <= 0: _socket.socket.close(self) +def fromfd(fd, family, type, proto=0): + """ fromfd(fd, family, type[, proto]) -> socket object + + Create a socket object from a duplicate of the given file + descriptor. The remaining arguments are the same as for socket(). + """ + nfd = dup(fd) + return socket(family, type, proto, nfd) + + class SocketIO(io.RawIOBase): """Raw I/O implementation for stream sockets. diff --git a/Lib/ssl.py b/Lib/ssl.py index c2cfa31c444..9d63d12ce3b 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -78,8 +78,8 @@ from _ssl import ( from socket import socket, AF_INET, SOCK_STREAM, error from socket import getnameinfo as _getnameinfo from socket import error as socket_error +from socket import dup as _dup import base64 # for DER-to-PEM translation -_can_dup_socket = hasattr(socket, "dup") class SSLSocket(socket): @@ -99,20 +99,11 @@ class SSLSocket(socket): if sock is not None: # copied this code from socket.accept() fd = sock.fileno() - nfd = fd - if _can_dup_socket: - nfd = os.dup(fd) - try: - socket.__init__(self, family=sock.family, type=sock.type, - proto=sock.proto, fileno=nfd) - except: - if nfd != fd: - os.close(nfd) - else: - if fd != nfd: - sock.close() - sock = None - + nfd = _dup(fd) + socket.__init__(self, family=sock.family, type=sock.type, + proto=sock.proto, fileno=nfd) + sock.close() + sock = None elif fileno is not None: socket.__init__(self, fileno=fileno) else: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 82eb6e7e3c3..c01d998ae0b 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -575,6 +575,15 @@ class BasicTCPTest(SocketConnectedTest): def _testFromFd(self): self.serv_conn.send(MSG) + def testDup(self): + # Testing dup() + sock = self.cli_conn.dup() + msg = sock.recv(1024) + self.assertEqual(msg, MSG) + + def _testDup(self): + self.serv_conn.send(MSG) + def testShutdown(self): # Testing shutdown() msg = self.cli_conn.recv(1024) diff --git a/Modules/socketmodule.c b/Modules/socketmodule.c index 30e8d227121..f5ad2928652 100644 --- a/Modules/socketmodule.c +++ b/Modules/socketmodule.c @@ -89,12 +89,12 @@ A socket object represents one endpoint of a network connection.\n\ \n\ Methods of socket objects (keyword arguments not allowed):\n\ \n\ -accept() -- accept a connection, returning new socket and client address\n\ +_accept() -- accept connection, returning new socket fd and client address\n\ bind(addr) -- bind the socket to a local address\n\ close() -- close the socket\n\ connect(addr) -- connect the socket to a remote address\n\ connect_ex(addr) -- connect, return an error code instead of an exception\n\ -dup() -- return a new socket object identical to the current one [*]\n\ +_dup() -- return a new socket fd duplicated from fileno()\n\ fileno() -- return underlying file descriptor\n\ getpeername() -- return remote address [*]\n\ getsockname() -- return local address\n\ @@ -327,10 +327,26 @@ const char *inet_ntop(int af, const void *src, char *dst, socklen_t size); #include "getnameinfo.c" #endif -#if defined(MS_WINDOWS) -/* seem to be a few differences in the API */ +#ifdef MS_WINDOWS +/* On Windows a socket is really a handle not an fd */ +static SOCKET +dup_socket(SOCKET handle) +{ + HANDLE newhandle; + + if (!DuplicateHandle(GetCurrentProcess(), (HANDLE)handle, + GetCurrentProcess(), &newhandle, + 0, FALSE, DUPLICATE_SAME_ACCESS)) + { + WSASetLastError(GetLastError()); + return INVALID_SOCKET; + } + return (SOCKET)newhandle; +} #define SOCKETCLOSE closesocket -#define NO_DUP /* Actually it exists on NT 3.5, but what the heck... */ +#else +/* On Unix we can use dup to duplicate the file descriptor of a socket*/ +#define dup_socket(fd) dup(fd) #endif #ifdef MS_WIN32 @@ -628,7 +644,7 @@ internal_select(PySocketSockObject *s, int writing) pollfd.events = writing ? POLLOUT : POLLIN; /* s->sock_timeout is in seconds, timeout in ms */ - timeout = (int)(s->sock_timeout * 1000 + 0.5); + timeout = (int)(s->sock_timeout * 1000 + 0.5); n = poll(&pollfd, 1, timeout); } #else @@ -648,7 +664,7 @@ internal_select(PySocketSockObject *s, int writing) n = select(s->sock_fd+1, &fds, NULL, NULL, &tv); } #endif - + if (n < 0) return -1; if (n == 0) @@ -1423,7 +1439,7 @@ getsockaddrlen(PySocketSockObject *s, socklen_t *len_ret) } -/* s.accept() method */ +/* s._accept() -> (fd, address) */ static PyObject * sock_accept(PySocketSockObject *s) @@ -1457,17 +1473,12 @@ sock_accept(PySocketSockObject *s) if (newfd == INVALID_SOCKET) return s->errorhandler(); - /* Create the new object with unspecified family, - to avoid calls to bind() etc. on it. */ - sock = (PyObject *) new_sockobject(newfd, - s->sock_family, - s->sock_type, - s->sock_proto); - + sock = PyLong_FromSocket_t(newfd); if (sock == NULL) { SOCKETCLOSE(newfd); goto finally; } + addr = makesockaddr(s->sock_fd, SAS2SA(&addrbuf), addrlen, s->sock_proto); if (addr == NULL) @@ -1482,11 +1493,11 @@ finally: } PyDoc_STRVAR(accept_doc, -"accept() -> (socket object, address info)\n\ +"_accept() -> (integer, address info)\n\ \n\ -Wait for an incoming connection. Return a new socket representing the\n\ -connection, and the address of the client. For IP sockets, the address\n\ -info is a pair (hostaddr, port)."); +Wait for an incoming connection. Return a new socket file descriptor\n\ +representing the connection, and the address of the client.\n\ +For IP sockets, the address info is a pair (hostaddr, port)."); /* s.setblocking(flag) method. Argument: False -- non-blocking mode; same as settimeout(0) @@ -1882,11 +1893,7 @@ instead of raising an exception when an error occurs."); static PyObject * sock_fileno(PySocketSockObject *s) { -#if SIZEOF_SOCKET_T <= SIZEOF_LONG - return PyInt_FromLong((long) s->sock_fd); -#else - return PyLong_FromLongLong((PY_LONG_LONG)s->sock_fd); -#endif + return PyLong_FromSocket_t(s->sock_fd); } PyDoc_STRVAR(fileno_doc, @@ -1895,35 +1902,6 @@ PyDoc_STRVAR(fileno_doc, Return the integer file descriptor of the socket."); -#ifndef NO_DUP -/* s.dup() method */ - -static PyObject * -sock_dup(PySocketSockObject *s) -{ - SOCKET_T newfd; - PyObject *sock; - - newfd = dup(s->sock_fd); - if (newfd < 0) - return s->errorhandler(); - sock = (PyObject *) new_sockobject(newfd, - s->sock_family, - s->sock_type, - s->sock_proto); - if (sock == NULL) - SOCKETCLOSE(newfd); - return sock; -} - -PyDoc_STRVAR(dup_doc, -"dup() -> socket object\n\ -\n\ -Return a new socket object connected to the same system resource."); - -#endif - - /* s.getsockname() method */ static PyObject * @@ -2542,7 +2520,7 @@ of the socket (flag == SHUT_WR), or both ends (flag == SHUT_RDWR)."); /* List of methods for socket objects */ static PyMethodDef sock_methods[] = { - {"accept", (PyCFunction)sock_accept, METH_NOARGS, + {"_accept", (PyCFunction)sock_accept, METH_NOARGS, accept_doc}, {"bind", (PyCFunction)sock_bind, METH_O, bind_doc}, @@ -2552,10 +2530,6 @@ static PyMethodDef sock_methods[] = { connect_doc}, {"connect_ex", (PyCFunction)sock_connect_ex, METH_O, connect_ex_doc}, -#ifndef NO_DUP - {"dup", (PyCFunction)sock_dup, METH_NOARGS, - dup_doc}, -#endif {"fileno", (PyCFunction)sock_fileno, METH_NOARGS, fileno_doc}, #ifdef HAVE_GETPEERNAME @@ -2672,8 +2646,8 @@ sock_initobj(PyObject *self, PyObject *args, PyObject *kwds) &family, &type, &proto, &fdobj)) return -1; - if (fdobj != NULL) { - fd = PyLong_AsLongLong(fdobj); + if (fdobj != NULL && fdobj != Py_None) { + fd = PyLong_AsSocket_t(fdobj); if (fd == (SOCKET_T)(-1) && PyErr_Occurred()) return -1; if (fd == INVALID_SOCKET) { @@ -3172,6 +3146,38 @@ PyDoc_STRVAR(getprotobyname_doc, Return the protocol number for the named protocol. (Rarely used.)"); +#ifndef NO_DUP +/* dup() function for socket fds */ + +static PyObject * +socket_dup(PyObject *self, PyObject *fdobj) +{ + SOCKET_T fd, newfd; + PyObject *newfdobj; + + + fd = PyLong_AsSocket_t(fdobj); + if (fd == (SOCKET_T)(-1) && PyErr_Occurred()) + return NULL; + + newfd = dup_socket(fd); + if (newfd == INVALID_SOCKET) + return set_error(); + + newfdobj = PyLong_FromSocket_t(newfd); + if (newfdobj == NULL) + SOCKETCLOSE(newfd); + return newfdobj; +} + +PyDoc_STRVAR(dup_doc, +"dup(integer) -> integer\n\ +\n\ +Duplicate an integer socket file descriptor. This is like os.dup(), but for\n\ +sockets; on some platforms os.dup() won't work for socket file descriptors."); +#endif + + #ifdef HAVE_SOCKETPAIR /* Create a pair of sockets using the socketpair() function. Arguments as for socket() except the default family is AF_UNIX if @@ -3811,6 +3817,10 @@ static PyMethodDef socket_methods[] = { METH_VARARGS, getservbyport_doc}, {"getprotobyname", socket_getprotobyname, METH_VARARGS, getprotobyname_doc}, +#ifndef NO_DUP + {"dup", socket_dup, + METH_O, dup_doc}, +#endif #ifdef HAVE_SOCKETPAIR {"socketpair", socket_socketpair, METH_VARARGS, socketpair_doc}, @@ -4105,7 +4115,7 @@ init_socket(void) PyModule_AddIntConstant(m, "NETLINK_IP6_FW", NETLINK_IP6_FW); #ifdef NETLINK_DNRTMSG PyModule_AddIntConstant(m, "NETLINK_DNRTMSG", NETLINK_DNRTMSG); -#endif +#endif #ifdef NETLINK_TAPBASE PyModule_AddIntConstant(m, "NETLINK_TAPBASE", NETLINK_TAPBASE); #endif