mirror of https://github.com/python/cpython
gh-89967: make WeakKeyDictionary and WeakValueDictionary thread safe (#125325)
Make `WeakKeyDictionary` and `WeakValueDictionary` thread safe by copying the underlying the dict before iterating over it.
This commit is contained in:
parent
08489325d1
commit
cd0f9d111a
|
@ -8,31 +8,6 @@ from types import GenericAlias
|
|||
__all__ = ['WeakSet']
|
||||
|
||||
|
||||
class _IterationGuard:
|
||||
# This context manager registers itself in the current iterators of the
|
||||
# weak container, such as to delay all removals until the context manager
|
||||
# exits.
|
||||
# This technique should be relatively thread-safe (since sets are).
|
||||
|
||||
def __init__(self, weakcontainer):
|
||||
# Don't create cycles
|
||||
self.weakcontainer = ref(weakcontainer)
|
||||
|
||||
def __enter__(self):
|
||||
w = self.weakcontainer()
|
||||
if w is not None:
|
||||
w._iterating.add(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, e, t, b):
|
||||
w = self.weakcontainer()
|
||||
if w is not None:
|
||||
s = w._iterating
|
||||
s.remove(self)
|
||||
if not s:
|
||||
w._commit_removals()
|
||||
|
||||
|
||||
class WeakSet:
|
||||
def __init__(self, data=None):
|
||||
self.data = set()
|
||||
|
|
198
Lib/weakref.py
198
Lib/weakref.py
|
@ -19,7 +19,7 @@ from _weakref import (
|
|||
ReferenceType,
|
||||
_remove_dead_weakref)
|
||||
|
||||
from _weakrefset import WeakSet, _IterationGuard
|
||||
from _weakrefset import WeakSet
|
||||
|
||||
import _collections_abc # Import after _weakref to avoid circular import.
|
||||
import sys
|
||||
|
@ -105,34 +105,14 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
|||
def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
|
||||
self = selfref()
|
||||
if self is not None:
|
||||
if self._iterating:
|
||||
self._pending_removals.append(wr.key)
|
||||
else:
|
||||
# Atomic removal is necessary since this function
|
||||
# can be called asynchronously by the GC
|
||||
_atomic_removal(self.data, wr.key)
|
||||
# Atomic removal is necessary since this function
|
||||
# can be called asynchronously by the GC
|
||||
_atomic_removal(self.data, wr.key)
|
||||
self._remove = remove
|
||||
# A list of keys to be removed
|
||||
self._pending_removals = []
|
||||
self._iterating = set()
|
||||
self.data = {}
|
||||
self.update(other, **kw)
|
||||
|
||||
def _commit_removals(self, _atomic_removal=_remove_dead_weakref):
|
||||
pop = self._pending_removals.pop
|
||||
d = self.data
|
||||
# We shouldn't encounter any KeyError, because this method should
|
||||
# always be called *before* mutating the dict.
|
||||
while True:
|
||||
try:
|
||||
key = pop()
|
||||
except IndexError:
|
||||
return
|
||||
_atomic_removal(d, key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
o = self.data[key]()
|
||||
if o is None:
|
||||
raise KeyError(key)
|
||||
|
@ -140,18 +120,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
|||
return o
|
||||
|
||||
def __delitem__(self, key):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
del self.data[key]
|
||||
|
||||
def __len__(self):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
return len(self.data)
|
||||
|
||||
def __contains__(self, key):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
try:
|
||||
o = self.data[key]()
|
||||
except KeyError:
|
||||
|
@ -162,38 +136,28 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
|||
return "<%s at %#x>" % (self.__class__.__name__, id(self))
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data[key] = KeyedRef(value, self._remove, key)
|
||||
|
||||
def copy(self):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
new = WeakValueDictionary()
|
||||
with _IterationGuard(self):
|
||||
for key, wr in self.data.items():
|
||||
o = wr()
|
||||
if o is not None:
|
||||
new[key] = o
|
||||
for key, wr in self.data.copy().items():
|
||||
o = wr()
|
||||
if o is not None:
|
||||
new[key] = o
|
||||
return new
|
||||
|
||||
__copy__ = copy
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
from copy import deepcopy
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
new = self.__class__()
|
||||
with _IterationGuard(self):
|
||||
for key, wr in self.data.items():
|
||||
o = wr()
|
||||
if o is not None:
|
||||
new[deepcopy(key, memo)] = o
|
||||
for key, wr in self.data.copy().items():
|
||||
o = wr()
|
||||
if o is not None:
|
||||
new[deepcopy(key, memo)] = o
|
||||
return new
|
||||
|
||||
def get(self, key, default=None):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
try:
|
||||
wr = self.data[key]
|
||||
except KeyError:
|
||||
|
@ -207,21 +171,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
|||
return o
|
||||
|
||||
def items(self):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
with _IterationGuard(self):
|
||||
for k, wr in self.data.items():
|
||||
v = wr()
|
||||
if v is not None:
|
||||
yield k, v
|
||||
for k, wr in self.data.copy().items():
|
||||
v = wr()
|
||||
if v is not None:
|
||||
yield k, v
|
||||
|
||||
def keys(self):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
with _IterationGuard(self):
|
||||
for k, wr in self.data.items():
|
||||
if wr() is not None:
|
||||
yield k
|
||||
for k, wr in self.data.copy().items():
|
||||
if wr() is not None:
|
||||
yield k
|
||||
|
||||
__iter__ = keys
|
||||
|
||||
|
@ -235,23 +193,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
|||
keep the values around longer than needed.
|
||||
|
||||
"""
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
with _IterationGuard(self):
|
||||
yield from self.data.values()
|
||||
yield from self.data.copy().values()
|
||||
|
||||
def values(self):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
with _IterationGuard(self):
|
||||
for wr in self.data.values():
|
||||
obj = wr()
|
||||
if obj is not None:
|
||||
yield obj
|
||||
for wr in self.data.copy().values():
|
||||
obj = wr()
|
||||
if obj is not None:
|
||||
yield obj
|
||||
|
||||
def popitem(self):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
while True:
|
||||
key, wr = self.data.popitem()
|
||||
o = wr()
|
||||
|
@ -259,8 +209,6 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
|||
return key, o
|
||||
|
||||
def pop(self, key, *args):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
try:
|
||||
o = self.data.pop(key)()
|
||||
except KeyError:
|
||||
|
@ -279,16 +227,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
|||
except KeyError:
|
||||
o = None
|
||||
if o is None:
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data[key] = KeyedRef(default, self._remove, key)
|
||||
return default
|
||||
else:
|
||||
return o
|
||||
|
||||
def update(self, other=None, /, **kwargs):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
d = self.data
|
||||
if other is not None:
|
||||
if not hasattr(other, "items"):
|
||||
|
@ -308,9 +252,7 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
|||
keep the values around longer than needed.
|
||||
|
||||
"""
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
return list(self.data.values())
|
||||
return list(self.data.copy().values())
|
||||
|
||||
def __ior__(self, other):
|
||||
self.update(other)
|
||||
|
@ -369,57 +311,22 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
|||
def remove(k, selfref=ref(self)):
|
||||
self = selfref()
|
||||
if self is not None:
|
||||
if self._iterating:
|
||||
self._pending_removals.append(k)
|
||||
else:
|
||||
try:
|
||||
del self.data[k]
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del self.data[k]
|
||||
except KeyError:
|
||||
pass
|
||||
self._remove = remove
|
||||
# A list of dead weakrefs (keys to be removed)
|
||||
self._pending_removals = []
|
||||
self._iterating = set()
|
||||
self._dirty_len = False
|
||||
if dict is not None:
|
||||
self.update(dict)
|
||||
|
||||
def _commit_removals(self):
|
||||
# NOTE: We don't need to call this method before mutating the dict,
|
||||
# because a dead weakref never compares equal to a live weakref,
|
||||
# even if they happened to refer to equal objects.
|
||||
# However, it means keys may already have been removed.
|
||||
pop = self._pending_removals.pop
|
||||
d = self.data
|
||||
while True:
|
||||
try:
|
||||
key = pop()
|
||||
except IndexError:
|
||||
return
|
||||
|
||||
try:
|
||||
del d[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def _scrub_removals(self):
|
||||
d = self.data
|
||||
self._pending_removals = [k for k in self._pending_removals if k in d]
|
||||
self._dirty_len = False
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._dirty_len = True
|
||||
del self.data[ref(key)]
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[ref(key)]
|
||||
|
||||
def __len__(self):
|
||||
if self._dirty_len and self._pending_removals:
|
||||
# self._pending_removals may still contain keys which were
|
||||
# explicitly removed, we have to scrub them (see issue #21173).
|
||||
self._scrub_removals()
|
||||
return len(self.data) - len(self._pending_removals)
|
||||
return len(self.data)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s at %#x>" % (self.__class__.__name__, id(self))
|
||||
|
@ -429,11 +336,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
|||
|
||||
def copy(self):
|
||||
new = WeakKeyDictionary()
|
||||
with _IterationGuard(self):
|
||||
for key, value in self.data.items():
|
||||
o = key()
|
||||
if o is not None:
|
||||
new[o] = value
|
||||
for key, value in self.data.copy().items():
|
||||
o = key()
|
||||
if o is not None:
|
||||
new[o] = value
|
||||
return new
|
||||
|
||||
__copy__ = copy
|
||||
|
@ -441,11 +347,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
|||
def __deepcopy__(self, memo):
|
||||
from copy import deepcopy
|
||||
new = self.__class__()
|
||||
with _IterationGuard(self):
|
||||
for key, value in self.data.items():
|
||||
o = key()
|
||||
if o is not None:
|
||||
new[o] = deepcopy(value, memo)
|
||||
for key, value in self.data.copy().items():
|
||||
o = key()
|
||||
if o is not None:
|
||||
new[o] = deepcopy(value, memo)
|
||||
return new
|
||||
|
||||
def get(self, key, default=None):
|
||||
|
@ -459,26 +364,23 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
|||
return wr in self.data
|
||||
|
||||
def items(self):
|
||||
with _IterationGuard(self):
|
||||
for wr, value in self.data.items():
|
||||
key = wr()
|
||||
if key is not None:
|
||||
yield key, value
|
||||
for wr, value in self.data.copy().items():
|
||||
key = wr()
|
||||
if key is not None:
|
||||
yield key, value
|
||||
|
||||
def keys(self):
|
||||
with _IterationGuard(self):
|
||||
for wr in self.data:
|
||||
obj = wr()
|
||||
if obj is not None:
|
||||
yield obj
|
||||
for wr in self.data.copy():
|
||||
obj = wr()
|
||||
if obj is not None:
|
||||
yield obj
|
||||
|
||||
__iter__ = keys
|
||||
|
||||
def values(self):
|
||||
with _IterationGuard(self):
|
||||
for wr, value in self.data.items():
|
||||
if wr() is not None:
|
||||
yield value
|
||||
for wr, value in self.data.copy().items():
|
||||
if wr() is not None:
|
||||
yield value
|
||||
|
||||
def keyrefs(self):
|
||||
"""Return a list of weak references to the keys.
|
||||
|
@ -493,7 +395,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
|||
return list(self.data)
|
||||
|
||||
def popitem(self):
|
||||
self._dirty_len = True
|
||||
while True:
|
||||
key, value = self.data.popitem()
|
||||
o = key()
|
||||
|
@ -501,7 +402,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
|||
return o, value
|
||||
|
||||
def pop(self, key, *args):
|
||||
self._dirty_len = True
|
||||
return self.data.pop(ref(key), *args)
|
||||
|
||||
def setdefault(self, key, default=None):
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Make :class:`~weakref.WeakKeyDictionary` and :class:`~weakref.WeakValueDictionary` safe against concurrent mutations from other threads. Patch by Kumar Aditya.
|
Loading…
Reference in New Issue