diff --git a/Lib/io.py b/Lib/io.py index 9a4f9561d0e..43695be1f7c 100644 --- a/Lib/io.py +++ b/Lib/io.py @@ -442,34 +442,6 @@ class FileIO(_fileio._FileIO, RawIOBase): return self._mode -class SocketIO(RawIOBase): - - """Raw I/O implementation for stream sockets.""" - - # XXX More docs - - def __init__(self, sock, mode): - assert mode in ("r", "w", "rw") - RawIOBase.__init__(self) - self._sock = sock - self._mode = mode - - def readinto(self, b): - return self._sock.recv_into(b) - - def write(self, b): - return self._sock.send(b) - - def readable(self): - return "r" in self._mode - - def writable(self): - return "w" in self._mode - - def fileno(self): - return self._sock.fileno() - - class BufferedIOBase(IOBase): """Base class for buffered IO objects. diff --git a/Lib/socket.py b/Lib/socket.py index 8d3508a81fd..1b3920adf90 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -89,22 +89,67 @@ if sys.platform.lower().startswith("win"): # True if os.dup() can duplicate socket descriptors. # (On Windows at least, os.dup only works on files) -_can_dup_socket = hasattr(_socket, "dup") +_can_dup_socket = hasattr(_socket.socket, "dup") if _can_dup_socket: def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0): nfd = os.dup(fd) return socket(family, type, proto, fileno=nfd) +class SocketCloser: + + """Helper to manage socket close() logic for makefile(). + + The OS socket should not be closed until the socket and all + of its makefile-children are closed. If the refcount is zero + when socket.close() is called, this is easy: Just close the + socket. If the refcount is non-zero when socket.close() is + called, then the real close should not occur until the last + makefile-child is closed. + """ + + def __init__(self, sock): + self._sock = sock + self._makefile_refs = 0 + # Test whether the socket is open. + try: + sock.fileno() + self._socket_open = True + except error: + self._socket_open = False + + def socket_close(self): + self._socket_open = False + self.close() + + def makefile_open(self): + self._makefile_refs += 1 + + def makefile_close(self): + self._makefile_refs -= 1 + self.close() + + def close(self): + if not (self._socket_open or self._makefile_refs): + self._sock._real_close() + class socket(_socket.socket): """A subclass of _socket.socket adding the makefile() method.""" - __slots__ = ["__weakref__"] + __slots__ = ["__weakref__", "_closer"] if not _can_dup_socket: __slots__.append("_base") + def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None): + if fileno is None: + _socket.socket.__init__(self, family, type, proto) + else: + _socket.socket.__init__(self, family, type, proto, fileno) + # Defer creating a SocketCloser until makefile() is actually called. + self._closer = None + def __repr__(self): """Wrap __repr__() to reveal the real class name.""" s = _socket.socket.__repr__(self) @@ -128,14 +173,6 @@ class socket(_socket.socket): conn.close() return wrapper, addr - if not _can_dup_socket: - def close(self): - """Wrap close() to close the _base as well.""" - _socket.socket.close(self) - base = getattr(self, "_base", None) - if base is not None: - base.close() - def makefile(self, mode="r", buffering=None, *, encoding=None, newline=None): """Return an I/O stream connected to the socket. @@ -156,7 +193,9 @@ class socket(_socket.socket): rawmode += "r" if writing: rawmode += "w" - raw = io.SocketIO(self, rawmode) + if self._closer is None: + self._closer = SocketCloser(self) + raw = SocketIO(self, rawmode, self._closer) if buffering is None: buffering = -1 if buffering < 0: @@ -183,6 +222,65 @@ class socket(_socket.socket): text.mode = mode return text + def close(self): + if self._closer is None: + self._real_close() + else: + self._closer.socket_close() + + # _real_close calls close on the _socket.socket base class. + + if not _can_dup_socket: + def _real_close(self): + _socket.socket.close(self) + base = getattr(self, "_base", None) + if base is not None: + self._base = None + base.close() + else: + def _real_close(self): + _socket.socket.close(self) + + +class SocketIO(io.RawIOBase): + + """Raw I/O implementation for stream sockets. + + This class supports the makefile() method on sockets. It provides + the raw I/O interface on top of a socket object. + """ + + # XXX More docs + + def __init__(self, sock, mode, closer): + assert mode in ("r", "w", "rw") + io.RawIOBase.__init__(self) + self._sock = sock + self._mode = mode + self._closer = closer + closer.makefile_open() + + def readinto(self, b): + return self._sock.recv_into(b) + + def write(self, b): + return self._sock.send(b) + + def readable(self): + return "r" in self._mode + + def writable(self): + return "w" in self._mode + + def fileno(self): + return self._sock.fileno() + + def close(self): + if self.closed: + return + self._closer.makefile_close() + io.RawIOBase.close(self) + def getfqdn(name=''): """Get fully qualified domain name from name. diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index f2b74ee996b..a8b65c410e1 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -163,6 +163,11 @@ class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest): self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) class SocketConnectedTest(ThreadedTCPSocketTest): + """Socket tests for client-server connection. + + self.cli_conn is a client socket connected to the server. The + setUp() method guarantees that it is connected to the server. + """ def __init__(self, methodName='runTest'): ThreadedTCPSocketTest.__init__(self, methodName=methodName) @@ -618,6 +623,10 @@ class TCPCloserTest(ThreadedTCPSocketTest): self.assertEqual(read, [sd]) self.assertEqual(sd.recv(1), b'') + # Calling close() many times should be safe. + conn.close() + conn.close() + def _testClose(self): self.cli.connect((HOST, PORT)) time.sleep(1.0) @@ -710,6 +719,16 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): self.cli.send(MSG) class FileObjectClassTestCase(SocketConnectedTest): + """Unit tests for the object returned by socket.makefile() + + self.serv_file is the io object returned by makefile() on + the client connection. You can read from this file to + get output from the server. + + self.cli_file is the io object returned by makefile() on the + server connection. You can write to this file to send output + to the client. + """ bufsize = -1 # Use default buffer size @@ -779,6 +798,26 @@ class FileObjectClassTestCase(SocketConnectedTest): self.cli_file.write(MSG) self.cli_file.flush() + def testCloseAfterMakefile(self): + # The file returned by makefile should keep the socket open. + self.cli_conn.close() + # read until EOF + msg = self.serv_file.read() + self.assertEqual(msg, MSG) + + def _testCloseAfterMakefile(self): + self.cli_file.write(MSG) + self.cli_file.flush() + + def testMakefileAfterMakefileClose(self): + self.serv_file.close() + msg = self.cli_conn.recv(len(MSG)) + self.assertEqual(msg, MSG) + + def _testMakefileAfterMakefileClose(self): + self.cli_file.write(MSG) + self.cli_file.flush() + def testClosedAttr(self): self.assert_(not self.serv_file.closed)