gh-89967: make WeakKeyDictionary and WeakValueDictionary thread safe (#125325)

Make `WeakKeyDictionary` and `WeakValueDictionary` thread safe by copying the underlying the dict before iterating over it.
This commit is contained in:
Kumar Aditya 2024-10-13 21:05:05 +05:30 committed by GitHub
parent 08489325d1
commit cd0f9d111a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 174 deletions

View File

@ -8,31 +8,6 @@ from types import GenericAlias
__all__ = ['WeakSet']
class _IterationGuard:
# This context manager registers itself in the current iterators of the
# weak container, such as to delay all removals until the context manager
# exits.
# This technique should be relatively thread-safe (since sets are).
def __init__(self, weakcontainer):
# Don't create cycles
self.weakcontainer = ref(weakcontainer)
def __enter__(self):
w = self.weakcontainer()
if w is not None:
w._iterating.add(self)
return self
def __exit__(self, e, t, b):
w = self.weakcontainer()
if w is not None:
s = w._iterating
s.remove(self)
if not s:
w._commit_removals()
class WeakSet:
def __init__(self, data=None):
self.data = set()

View File

@ -19,7 +19,7 @@ from _weakref import (
ReferenceType,
_remove_dead_weakref)
from _weakrefset import WeakSet, _IterationGuard
from _weakrefset import WeakSet
import _collections_abc # Import after _weakref to avoid circular import.
import sys
@ -105,34 +105,14 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
self = selfref()
if self is not None:
if self._iterating:
self._pending_removals.append(wr.key)
else:
# Atomic removal is necessary since this function
# can be called asynchronously by the GC
_atomic_removal(self.data, wr.key)
# Atomic removal is necessary since this function
# can be called asynchronously by the GC
_atomic_removal(self.data, wr.key)
self._remove = remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
self.data = {}
self.update(other, **kw)
def _commit_removals(self, _atomic_removal=_remove_dead_weakref):
pop = self._pending_removals.pop
d = self.data
# We shouldn't encounter any KeyError, because this method should
# always be called *before* mutating the dict.
while True:
try:
key = pop()
except IndexError:
return
_atomic_removal(d, key)
def __getitem__(self, key):
if self._pending_removals:
self._commit_removals()
o = self.data[key]()
if o is None:
raise KeyError(key)
@ -140,18 +120,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
return o
def __delitem__(self, key):
if self._pending_removals:
self._commit_removals()
del self.data[key]
def __len__(self):
if self._pending_removals:
self._commit_removals()
return len(self.data)
def __contains__(self, key):
if self._pending_removals:
self._commit_removals()
try:
o = self.data[key]()
except KeyError:
@ -162,38 +136,28 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
return "<%s at %#x>" % (self.__class__.__name__, id(self))
def __setitem__(self, key, value):
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(value, self._remove, key)
def copy(self):
if self._pending_removals:
self._commit_removals()
new = WeakValueDictionary()
with _IterationGuard(self):
for key, wr in self.data.items():
o = wr()
if o is not None:
new[key] = o
for key, wr in self.data.copy().items():
o = wr()
if o is not None:
new[key] = o
return new
__copy__ = copy
def __deepcopy__(self, memo):
from copy import deepcopy
if self._pending_removals:
self._commit_removals()
new = self.__class__()
with _IterationGuard(self):
for key, wr in self.data.items():
o = wr()
if o is not None:
new[deepcopy(key, memo)] = o
for key, wr in self.data.copy().items():
o = wr()
if o is not None:
new[deepcopy(key, memo)] = o
return new
def get(self, key, default=None):
if self._pending_removals:
self._commit_removals()
try:
wr = self.data[key]
except KeyError:
@ -207,21 +171,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
return o
def items(self):
if self._pending_removals:
self._commit_removals()
with _IterationGuard(self):
for k, wr in self.data.items():
v = wr()
if v is not None:
yield k, v
for k, wr in self.data.copy().items():
v = wr()
if v is not None:
yield k, v
def keys(self):
if self._pending_removals:
self._commit_removals()
with _IterationGuard(self):
for k, wr in self.data.items():
if wr() is not None:
yield k
for k, wr in self.data.copy().items():
if wr() is not None:
yield k
__iter__ = keys
@ -235,23 +193,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
keep the values around longer than needed.
"""
if self._pending_removals:
self._commit_removals()
with _IterationGuard(self):
yield from self.data.values()
yield from self.data.copy().values()
def values(self):
if self._pending_removals:
self._commit_removals()
with _IterationGuard(self):
for wr in self.data.values():
obj = wr()
if obj is not None:
yield obj
for wr in self.data.copy().values():
obj = wr()
if obj is not None:
yield obj
def popitem(self):
if self._pending_removals:
self._commit_removals()
while True:
key, wr = self.data.popitem()
o = wr()
@ -259,8 +209,6 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
return key, o
def pop(self, key, *args):
if self._pending_removals:
self._commit_removals()
try:
o = self.data.pop(key)()
except KeyError:
@ -279,16 +227,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
except KeyError:
o = None
if o is None:
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(default, self._remove, key)
return default
else:
return o
def update(self, other=None, /, **kwargs):
if self._pending_removals:
self._commit_removals()
d = self.data
if other is not None:
if not hasattr(other, "items"):
@ -308,9 +252,7 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
keep the values around longer than needed.
"""
if self._pending_removals:
self._commit_removals()
return list(self.data.values())
return list(self.data.copy().values())
def __ior__(self, other):
self.update(other)
@ -369,57 +311,22 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
def remove(k, selfref=ref(self)):
self = selfref()
if self is not None:
if self._iterating:
self._pending_removals.append(k)
else:
try:
del self.data[k]
except KeyError:
pass
try:
del self.data[k]
except KeyError:
pass
self._remove = remove
# A list of dead weakrefs (keys to be removed)
self._pending_removals = []
self._iterating = set()
self._dirty_len = False
if dict is not None:
self.update(dict)
def _commit_removals(self):
# NOTE: We don't need to call this method before mutating the dict,
# because a dead weakref never compares equal to a live weakref,
# even if they happened to refer to equal objects.
# However, it means keys may already have been removed.
pop = self._pending_removals.pop
d = self.data
while True:
try:
key = pop()
except IndexError:
return
try:
del d[key]
except KeyError:
pass
def _scrub_removals(self):
d = self.data
self._pending_removals = [k for k in self._pending_removals if k in d]
self._dirty_len = False
def __delitem__(self, key):
self._dirty_len = True
del self.data[ref(key)]
def __getitem__(self, key):
return self.data[ref(key)]
def __len__(self):
if self._dirty_len and self._pending_removals:
# self._pending_removals may still contain keys which were
# explicitly removed, we have to scrub them (see issue #21173).
self._scrub_removals()
return len(self.data) - len(self._pending_removals)
return len(self.data)
def __repr__(self):
return "<%s at %#x>" % (self.__class__.__name__, id(self))
@ -429,11 +336,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
def copy(self):
new = WeakKeyDictionary()
with _IterationGuard(self):
for key, value in self.data.items():
o = key()
if o is not None:
new[o] = value
for key, value in self.data.copy().items():
o = key()
if o is not None:
new[o] = value
return new
__copy__ = copy
@ -441,11 +347,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
def __deepcopy__(self, memo):
from copy import deepcopy
new = self.__class__()
with _IterationGuard(self):
for key, value in self.data.items():
o = key()
if o is not None:
new[o] = deepcopy(value, memo)
for key, value in self.data.copy().items():
o = key()
if o is not None:
new[o] = deepcopy(value, memo)
return new
def get(self, key, default=None):
@ -459,26 +364,23 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
return wr in self.data
def items(self):
with _IterationGuard(self):
for wr, value in self.data.items():
key = wr()
if key is not None:
yield key, value
for wr, value in self.data.copy().items():
key = wr()
if key is not None:
yield key, value
def keys(self):
with _IterationGuard(self):
for wr in self.data:
obj = wr()
if obj is not None:
yield obj
for wr in self.data.copy():
obj = wr()
if obj is not None:
yield obj
__iter__ = keys
def values(self):
with _IterationGuard(self):
for wr, value in self.data.items():
if wr() is not None:
yield value
for wr, value in self.data.copy().items():
if wr() is not None:
yield value
def keyrefs(self):
"""Return a list of weak references to the keys.
@ -493,7 +395,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
return list(self.data)
def popitem(self):
self._dirty_len = True
while True:
key, value = self.data.popitem()
o = key()
@ -501,7 +402,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
return o, value
def pop(self, key, *args):
self._dirty_len = True
return self.data.pop(ref(key), *args)
def setdefault(self, key, default=None):

View File

@ -0,0 +1 @@
Make :class:`~weakref.WeakKeyDictionary` and :class:`~weakref.WeakValueDictionary` safe against concurrent mutations from other threads. Patch by Kumar Aditya.