bpo-37193: remove thread objects which finished process its request (GH-13893)

* bpo-37193: remove the thread which finished process request from threads list

* rename variable t to thread.

* don't remove thread from list if it is daemon.

* use lock to protect self._threads.

* use finally block in case of exception from shutdown_request().

* check "not thread.daemon" before lock to avoid holding the lock if it's unnecessary.

* fix the place of _threads_lock.

* separate code to remove a current thread into a function.

* check ValueError when removing thread.

* fix wrong code which all instance shared same lock.

* Extract thread management into a _Threads class to encapsulate atomic operations and separate concerns.

* Replace multiple references of 'block_on_close' with one, avoiding the possibility that 'block_on_close' could change during the course of processing requests. Now, there's exactly one _threads object with behavior fixed for the duration.

* Add docstrings to private classes.

* Add test to ensure that a ThreadingTCPServer can be closed without serving any requests.

* Use _NoThreads as the default value. Fixes AttributeError when server is closed without serving any requests.

* Add blurb

* Add test capturing failure.

Co-authored-by: Jason R. Coombs <jaraco@jaraco.com>
This commit is contained in:
MARUYAMA Norihiro 2020-11-02 08:51:04 +09:00 committed by GitHub
parent e662c398d8
commit c415590212
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 13 deletions

View File

@ -128,6 +128,7 @@ import selectors
import os import os
import sys import sys
import threading import threading
import contextlib
from io import BufferedIOBase from io import BufferedIOBase
from time import monotonic as time from time import monotonic as time
@ -628,6 +629,55 @@ if hasattr(os, "fork"):
self.collect_children(blocking=self.block_on_close) self.collect_children(blocking=self.block_on_close)
class _Threads(list):
"""
Joinable list of all non-daemon threads.
"""
def __init__(self):
self._lock = threading.Lock()
def append(self, thread):
if thread.daemon:
return
with self._lock:
super().append(thread)
def remove(self, thread):
with self._lock:
# should not happen, but safe to ignore
with contextlib.suppress(ValueError):
super().remove(thread)
def remove_current(self):
"""Remove a current non-daemon thread."""
thread = threading.current_thread()
if not thread.daemon:
self.remove(thread)
def pop_all(self):
with self._lock:
self[:], result = [], self[:]
return result
def join(self):
for thread in self.pop_all():
thread.join()
class _NoThreads:
"""
Degenerate version of _Threads.
"""
def append(self, thread):
pass
def join(self):
pass
def remove_current(self):
pass
class ThreadingMixIn: class ThreadingMixIn:
"""Mix-in class to handle each request in a new thread.""" """Mix-in class to handle each request in a new thread."""
@ -636,9 +686,9 @@ class ThreadingMixIn:
daemon_threads = False daemon_threads = False
# If true, server_close() waits until all non-daemonic threads terminate. # If true, server_close() waits until all non-daemonic threads terminate.
block_on_close = True block_on_close = True
# For non-daemonic threads, list of threading.Threading objects # Threads object
# used by server_close() to wait for all threads completion. # used by server_close() to wait for all threads completion.
_threads = None _threads = _NoThreads()
def process_request_thread(self, request, client_address): def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread. """Same as in BaseServer but as a thread.
@ -651,27 +701,24 @@ class ThreadingMixIn:
except Exception: except Exception:
self.handle_error(request, client_address) self.handle_error(request, client_address)
finally: finally:
self.shutdown_request(request) try:
self.shutdown_request(request)
finally:
self._threads.remove_current()
def process_request(self, request, client_address): def process_request(self, request, client_address):
"""Start a new thread to process the request.""" """Start a new thread to process the request."""
if self.block_on_close:
vars(self).setdefault('_threads', _Threads())
t = threading.Thread(target = self.process_request_thread, t = threading.Thread(target = self.process_request_thread,
args = (request, client_address)) args = (request, client_address))
t.daemon = self.daemon_threads t.daemon = self.daemon_threads
if not t.daemon and self.block_on_close: self._threads.append(t)
if self._threads is None:
self._threads = []
self._threads.append(t)
t.start() t.start()
def server_close(self): def server_close(self):
super().server_close() super().server_close()
if self.block_on_close: self._threads.join()
threads = self._threads
self._threads = None
if threads:
for thread in threads:
thread.join()
if hasattr(os, "fork"): if hasattr(os, "fork"):

View File

@ -277,6 +277,13 @@ class SocketServerTest(unittest.TestCase):
t.join() t.join()
s.server_close() s.server_close()
def test_close_immediately(self):
class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
pass
server = MyServer((HOST, 0), lambda: None)
server.server_close()
def test_tcpserver_bind_leak(self): def test_tcpserver_bind_leak(self):
# Issue #22435: the server socket wouldn't be closed if bind()/listen() # Issue #22435: the server socket wouldn't be closed if bind()/listen()
# failed. # failed.
@ -491,6 +498,23 @@ class MiscTestCase(unittest.TestCase):
self.assertEqual(server.shutdown_called, 1) self.assertEqual(server.shutdown_called, 1)
server.server_close() server.server_close()
def test_threads_reaped(self):
"""
In #37193, users reported a memory leak
due to the saving of every request thread. Ensure that the
threads are cleaned up after the requests complete.
"""
class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
pass
server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
for n in range(10):
with socket.create_connection(server.server_address):
server.handle_request()
[thread.join() for thread in server._threads]
self.assertEqual(len(server._threads), 0)
server.server_close()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -0,0 +1,2 @@
Fixed memory leak in ``socketserver.ThreadingMixIn`` introduced in Python
3.7.