Reap dead threads when opening a new one.

This commit is contained in:
Jason R. Coombs 2020-11-03 20:31:32 -05:00
parent 365ef2f12d
commit 97c70544c9
2 changed files with 14 additions and 21 deletions

View File

@ -642,18 +642,6 @@ class _Threads(list):
with self._lock:
super().append(thread)
def remove(self, thread):
with self._lock:
# should not happen, but safe to ignore
with contextlib.suppress(ValueError):
super().remove(thread)
def remove_current(self):
"""Remove a current non-daemon thread."""
thread = threading.current_thread()
if not thread.daemon:
self.remove(thread)
def pop_all(self):
with self._lock:
self[:], result = [], self[:]
@ -663,6 +651,14 @@ class _Threads(list):
for thread in self.pop_all():
thread.join()
def reap(self):
with self._lock:
dead = [thread for thread in self if not thread.is_alive()]
for thread in dead:
# should not happen, but safe to ignore
with contextlib.suppress(ValueError):
self.remove(thread)
class _NoThreads:
"""
@ -674,7 +670,7 @@ class _NoThreads:
def join(self):
pass
def remove_current(self):
def reap(self):
pass
@ -701,15 +697,13 @@ class ThreadingMixIn:
except Exception:
self.handle_error(request, client_address)
finally:
try:
self.shutdown_request(request)
finally:
self._threads.remove_current()
self.shutdown_request(request)
def process_request(self, request, client_address):
"""Start a new thread to process the request."""
if self.block_on_close:
vars(self).setdefault('_threads', _Threads())
self._threads.reap()
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
t.daemon = self.daemon_threads

View File

@ -501,8 +501,8 @@ class MiscTestCase(unittest.TestCase):
def test_threads_reaped(self):
"""
In #37193, users reported a memory leak
due to the saving of every request thread. Ensure that the
threads are cleaned up after the requests complete.
due to the saving of every request thread. Ensure that
not all threads are kept forever.
"""
class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
pass
@ -511,8 +511,7 @@ class MiscTestCase(unittest.TestCase):
for n in range(10):
with socket.create_connection(server.server_address):
server.handle_request()
[thread.join() for thread in server._threads]
self.assertEqual(len(server._threads), 0)
self.assertLess(len(server._threads), 10)
server.server_close()