bpo-37193: remove thread objects which finished process its request (GH-13893)

* bpo-37193: remove the thread which finished process request from threads list

* rename variable t to thread.

* don't remove thread from list if it is daemon.

* use lock to protect self._threads.

* use finally block in case of exception from shutdown_request().

* check "not thread.daemon" before lock to avoid holding the lock if it's unnecessary.

* fix the place of _threads_lock.

* separate code to remove a current thread into a function.

* check ValueError when removing thread.

* fix wrong code which all instance shared same lock.

* Extract thread management into a _Threads class to encapsulate atomic operations and separate concerns.

* Replace multiple references of 'block_on_close' with one, avoiding the possibility that 'block_on_close' could change during the course of processing requests. Now, there's exactly one _threads object with behavior fixed for the duration.

* Add docstrings to private classes.

* Add test to ensure that a ThreadingTCPServer can be closed without serving any requests.

* Use _NoThreads as the default value. Fixes AttributeError when server is closed without serving any requests.

* Add blurb

* Add test capturing failure.

Co-authored-by: Jason R. Coombs <jaraco@jaraco.com>
This commit is contained in:
MARUYAMA Norihiro 2020-11-02 08:51:04 +09:00 committed by GitHub
parent e662c398d8
commit c415590212
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 13 deletions

View File

@ -128,6 +128,7 @@ import selectors
import os
import sys
import threading
import contextlib
from io import BufferedIOBase
from time import monotonic as time
@ -628,6 +629,55 @@ if hasattr(os, "fork"):
self.collect_children(blocking=self.block_on_close)
class _Threads(list):
"""
Joinable list of all non-daemon threads.
"""
def __init__(self):
self._lock = threading.Lock()
def append(self, thread):
if thread.daemon:
return
with self._lock:
super().append(thread)
def remove(self, thread):
with self._lock:
# should not happen, but safe to ignore
with contextlib.suppress(ValueError):
super().remove(thread)
def remove_current(self):
"""Remove a current non-daemon thread."""
thread = threading.current_thread()
if not thread.daemon:
self.remove(thread)
def pop_all(self):
with self._lock:
self[:], result = [], self[:]
return result
def join(self):
for thread in self.pop_all():
thread.join()
class _NoThreads:
"""
Degenerate version of _Threads.
"""
def append(self, thread):
pass
def join(self):
pass
def remove_current(self):
pass
class ThreadingMixIn:
"""Mix-in class to handle each request in a new thread."""
@ -636,9 +686,9 @@ class ThreadingMixIn:
daemon_threads = False
# If true, server_close() waits until all non-daemonic threads terminate.
block_on_close = True
# For non-daemonic threads, list of threading.Threading objects
# Threads object
# used by server_close() to wait for all threads completion.
_threads = None
_threads = _NoThreads()
def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread.
@ -651,27 +701,24 @@ class ThreadingMixIn:
except Exception:
self.handle_error(request, client_address)
finally:
try:
self.shutdown_request(request)
finally:
self._threads.remove_current()
def process_request(self, request, client_address):
"""Start a new thread to process the request."""
if self.block_on_close:
vars(self).setdefault('_threads', _Threads())
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
t.daemon = self.daemon_threads
if not t.daemon and self.block_on_close:
if self._threads is None:
self._threads = []
self._threads.append(t)
t.start()
def server_close(self):
super().server_close()
if self.block_on_close:
threads = self._threads
self._threads = None
if threads:
for thread in threads:
thread.join()
self._threads.join()
if hasattr(os, "fork"):

View File

@ -277,6 +277,13 @@ class SocketServerTest(unittest.TestCase):
t.join()
s.server_close()
def test_close_immediately(self):
class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
pass
server = MyServer((HOST, 0), lambda: None)
server.server_close()
def test_tcpserver_bind_leak(self):
# Issue #22435: the server socket wouldn't be closed if bind()/listen()
# failed.
@ -491,6 +498,23 @@ class MiscTestCase(unittest.TestCase):
self.assertEqual(server.shutdown_called, 1)
server.server_close()
def test_threads_reaped(self):
"""
In #37193, users reported a memory leak
due to the saving of every request thread. Ensure that the
threads are cleaned up after the requests complete.
"""
class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
pass
server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
for n in range(10):
with socket.create_connection(server.server_address):
server.handle_request()
[thread.join() for thread in server._threads]
self.assertEqual(len(server._threads), 0)
server.server_close()
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,2 @@
Fixed memory leak in ``socketserver.ThreadingMixIn`` introduced in Python
3.7.