mirror of https://github.com/python/cpython
gh-89967: make WeakKeyDictionary and WeakValueDictionary thread safe (#125325)
Make `WeakKeyDictionary` and `WeakValueDictionary` thread safe by copying the underlying dict before iterating over it.
This commit is contained in:
parent
08489325d1
commit
cd0f9d111a
|
@ -8,31 +8,6 @@ from types import GenericAlias
|
||||||
__all__ = ['WeakSet']
|
__all__ = ['WeakSet']
|
||||||
|
|
||||||
|
|
||||||
class _IterationGuard:
|
|
||||||
# This context manager registers itself in the current iterators of the
|
|
||||||
# weak container, such as to delay all removals until the context manager
|
|
||||||
# exits.
|
|
||||||
# This technique should be relatively thread-safe (since sets are).
|
|
||||||
|
|
||||||
def __init__(self, weakcontainer):
|
|
||||||
# Don't create cycles
|
|
||||||
self.weakcontainer = ref(weakcontainer)
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
w = self.weakcontainer()
|
|
||||||
if w is not None:
|
|
||||||
w._iterating.add(self)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, e, t, b):
|
|
||||||
w = self.weakcontainer()
|
|
||||||
if w is not None:
|
|
||||||
s = w._iterating
|
|
||||||
s.remove(self)
|
|
||||||
if not s:
|
|
||||||
w._commit_removals()
|
|
||||||
|
|
||||||
|
|
||||||
class WeakSet:
|
class WeakSet:
|
||||||
def __init__(self, data=None):
|
def __init__(self, data=None):
|
||||||
self.data = set()
|
self.data = set()
|
||||||
|
|
198
Lib/weakref.py
198
Lib/weakref.py
|
@ -19,7 +19,7 @@ from _weakref import (
|
||||||
ReferenceType,
|
ReferenceType,
|
||||||
_remove_dead_weakref)
|
_remove_dead_weakref)
|
||||||
|
|
||||||
from _weakrefset import WeakSet, _IterationGuard
|
from _weakrefset import WeakSet
|
||||||
|
|
||||||
import _collections_abc # Import after _weakref to avoid circular import.
|
import _collections_abc # Import after _weakref to avoid circular import.
|
||||||
import sys
|
import sys
|
||||||
|
@ -105,34 +105,14 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
||||||
def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
|
def remove(wr, selfref=ref(self), _atomic_removal=_remove_dead_weakref):
|
||||||
self = selfref()
|
self = selfref()
|
||||||
if self is not None:
|
if self is not None:
|
||||||
if self._iterating:
|
# Atomic removal is necessary since this function
|
||||||
self._pending_removals.append(wr.key)
|
# can be called asynchronously by the GC
|
||||||
else:
|
_atomic_removal(self.data, wr.key)
|
||||||
# Atomic removal is necessary since this function
|
|
||||||
# can be called asynchronously by the GC
|
|
||||||
_atomic_removal(self.data, wr.key)
|
|
||||||
self._remove = remove
|
self._remove = remove
|
||||||
# A list of keys to be removed
|
|
||||||
self._pending_removals = []
|
|
||||||
self._iterating = set()
|
|
||||||
self.data = {}
|
self.data = {}
|
||||||
self.update(other, **kw)
|
self.update(other, **kw)
|
||||||
|
|
||||||
def _commit_removals(self, _atomic_removal=_remove_dead_weakref):
|
|
||||||
pop = self._pending_removals.pop
|
|
||||||
d = self.data
|
|
||||||
# We shouldn't encounter any KeyError, because this method should
|
|
||||||
# always be called *before* mutating the dict.
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
key = pop()
|
|
||||||
except IndexError:
|
|
||||||
return
|
|
||||||
_atomic_removal(d, key)
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
o = self.data[key]()
|
o = self.data[key]()
|
||||||
if o is None:
|
if o is None:
|
||||||
raise KeyError(key)
|
raise KeyError(key)
|
||||||
|
@ -140,18 +120,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
del self.data[key]
|
del self.data[key]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
try:
|
try:
|
||||||
o = self.data[key]()
|
o = self.data[key]()
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -162,38 +136,28 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
||||||
return "<%s at %#x>" % (self.__class__.__name__, id(self))
|
return "<%s at %#x>" % (self.__class__.__name__, id(self))
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
self.data[key] = KeyedRef(value, self._remove, key)
|
self.data[key] = KeyedRef(value, self._remove, key)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
new = WeakValueDictionary()
|
new = WeakValueDictionary()
|
||||||
with _IterationGuard(self):
|
for key, wr in self.data.copy().items():
|
||||||
for key, wr in self.data.items():
|
o = wr()
|
||||||
o = wr()
|
if o is not None:
|
||||||
if o is not None:
|
new[key] = o
|
||||||
new[key] = o
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
__copy__ = copy
|
__copy__ = copy
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
new = self.__class__()
|
new = self.__class__()
|
||||||
with _IterationGuard(self):
|
for key, wr in self.data.copy().items():
|
||||||
for key, wr in self.data.items():
|
o = wr()
|
||||||
o = wr()
|
if o is not None:
|
||||||
if o is not None:
|
new[deepcopy(key, memo)] = o
|
||||||
new[deepcopy(key, memo)] = o
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
try:
|
try:
|
||||||
wr = self.data[key]
|
wr = self.data[key]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -207,21 +171,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
if self._pending_removals:
|
for k, wr in self.data.copy().items():
|
||||||
self._commit_removals()
|
v = wr()
|
||||||
with _IterationGuard(self):
|
if v is not None:
|
||||||
for k, wr in self.data.items():
|
yield k, v
|
||||||
v = wr()
|
|
||||||
if v is not None:
|
|
||||||
yield k, v
|
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
if self._pending_removals:
|
for k, wr in self.data.copy().items():
|
||||||
self._commit_removals()
|
if wr() is not None:
|
||||||
with _IterationGuard(self):
|
yield k
|
||||||
for k, wr in self.data.items():
|
|
||||||
if wr() is not None:
|
|
||||||
yield k
|
|
||||||
|
|
||||||
__iter__ = keys
|
__iter__ = keys
|
||||||
|
|
||||||
|
@ -235,23 +193,15 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
||||||
keep the values around longer than needed.
|
keep the values around longer than needed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self._pending_removals:
|
yield from self.data.copy().values()
|
||||||
self._commit_removals()
|
|
||||||
with _IterationGuard(self):
|
|
||||||
yield from self.data.values()
|
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
if self._pending_removals:
|
for wr in self.data.copy().values():
|
||||||
self._commit_removals()
|
obj = wr()
|
||||||
with _IterationGuard(self):
|
if obj is not None:
|
||||||
for wr in self.data.values():
|
yield obj
|
||||||
obj = wr()
|
|
||||||
if obj is not None:
|
|
||||||
yield obj
|
|
||||||
|
|
||||||
def popitem(self):
|
def popitem(self):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
while True:
|
while True:
|
||||||
key, wr = self.data.popitem()
|
key, wr = self.data.popitem()
|
||||||
o = wr()
|
o = wr()
|
||||||
|
@ -259,8 +209,6 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
||||||
return key, o
|
return key, o
|
||||||
|
|
||||||
def pop(self, key, *args):
|
def pop(self, key, *args):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
try:
|
try:
|
||||||
o = self.data.pop(key)()
|
o = self.data.pop(key)()
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -279,16 +227,12 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
o = None
|
o = None
|
||||||
if o is None:
|
if o is None:
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
self.data[key] = KeyedRef(default, self._remove, key)
|
self.data[key] = KeyedRef(default, self._remove, key)
|
||||||
return default
|
return default
|
||||||
else:
|
else:
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def update(self, other=None, /, **kwargs):
|
def update(self, other=None, /, **kwargs):
|
||||||
if self._pending_removals:
|
|
||||||
self._commit_removals()
|
|
||||||
d = self.data
|
d = self.data
|
||||||
if other is not None:
|
if other is not None:
|
||||||
if not hasattr(other, "items"):
|
if not hasattr(other, "items"):
|
||||||
|
@ -308,9 +252,7 @@ class WeakValueDictionary(_collections_abc.MutableMapping):
|
||||||
keep the values around longer than needed.
|
keep the values around longer than needed.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self._pending_removals:
|
return list(self.data.copy().values())
|
||||||
self._commit_removals()
|
|
||||||
return list(self.data.values())
|
|
||||||
|
|
||||||
def __ior__(self, other):
|
def __ior__(self, other):
|
||||||
self.update(other)
|
self.update(other)
|
||||||
|
@ -369,57 +311,22 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
||||||
def remove(k, selfref=ref(self)):
|
def remove(k, selfref=ref(self)):
|
||||||
self = selfref()
|
self = selfref()
|
||||||
if self is not None:
|
if self is not None:
|
||||||
if self._iterating:
|
try:
|
||||||
self._pending_removals.append(k)
|
del self.data[k]
|
||||||
else:
|
except KeyError:
|
||||||
try:
|
pass
|
||||||
del self.data[k]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
self._remove = remove
|
self._remove = remove
|
||||||
# A list of dead weakrefs (keys to be removed)
|
|
||||||
self._pending_removals = []
|
|
||||||
self._iterating = set()
|
|
||||||
self._dirty_len = False
|
|
||||||
if dict is not None:
|
if dict is not None:
|
||||||
self.update(dict)
|
self.update(dict)
|
||||||
|
|
||||||
def _commit_removals(self):
|
|
||||||
# NOTE: We don't need to call this method before mutating the dict,
|
|
||||||
# because a dead weakref never compares equal to a live weakref,
|
|
||||||
# even if they happened to refer to equal objects.
|
|
||||||
# However, it means keys may already have been removed.
|
|
||||||
pop = self._pending_removals.pop
|
|
||||||
d = self.data
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
key = pop()
|
|
||||||
except IndexError:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
del d[key]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _scrub_removals(self):
|
|
||||||
d = self.data
|
|
||||||
self._pending_removals = [k for k in self._pending_removals if k in d]
|
|
||||||
self._dirty_len = False
|
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
self._dirty_len = True
|
|
||||||
del self.data[ref(key)]
|
del self.data[ref(key)]
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return self.data[ref(key)]
|
return self.data[ref(key)]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
if self._dirty_len and self._pending_removals:
|
return len(self.data)
|
||||||
# self._pending_removals may still contain keys which were
|
|
||||||
# explicitly removed, we have to scrub them (see issue #21173).
|
|
||||||
self._scrub_removals()
|
|
||||||
return len(self.data) - len(self._pending_removals)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<%s at %#x>" % (self.__class__.__name__, id(self))
|
return "<%s at %#x>" % (self.__class__.__name__, id(self))
|
||||||
|
@ -429,11 +336,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
new = WeakKeyDictionary()
|
new = WeakKeyDictionary()
|
||||||
with _IterationGuard(self):
|
for key, value in self.data.copy().items():
|
||||||
for key, value in self.data.items():
|
o = key()
|
||||||
o = key()
|
if o is not None:
|
||||||
if o is not None:
|
new[o] = value
|
||||||
new[o] = value
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
__copy__ = copy
|
__copy__ = copy
|
||||||
|
@ -441,11 +347,10 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
new = self.__class__()
|
new = self.__class__()
|
||||||
with _IterationGuard(self):
|
for key, value in self.data.copy().items():
|
||||||
for key, value in self.data.items():
|
o = key()
|
||||||
o = key()
|
if o is not None:
|
||||||
if o is not None:
|
new[o] = deepcopy(value, memo)
|
||||||
new[o] = deepcopy(value, memo)
|
|
||||||
return new
|
return new
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
|
@ -459,26 +364,23 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
||||||
return wr in self.data
|
return wr in self.data
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
with _IterationGuard(self):
|
for wr, value in self.data.copy().items():
|
||||||
for wr, value in self.data.items():
|
key = wr()
|
||||||
key = wr()
|
if key is not None:
|
||||||
if key is not None:
|
yield key, value
|
||||||
yield key, value
|
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
with _IterationGuard(self):
|
for wr in self.data.copy():
|
||||||
for wr in self.data:
|
obj = wr()
|
||||||
obj = wr()
|
if obj is not None:
|
||||||
if obj is not None:
|
yield obj
|
||||||
yield obj
|
|
||||||
|
|
||||||
__iter__ = keys
|
__iter__ = keys
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
with _IterationGuard(self):
|
for wr, value in self.data.copy().items():
|
||||||
for wr, value in self.data.items():
|
if wr() is not None:
|
||||||
if wr() is not None:
|
yield value
|
||||||
yield value
|
|
||||||
|
|
||||||
def keyrefs(self):
|
def keyrefs(self):
|
||||||
"""Return a list of weak references to the keys.
|
"""Return a list of weak references to the keys.
|
||||||
|
@ -493,7 +395,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
||||||
return list(self.data)
|
return list(self.data)
|
||||||
|
|
||||||
def popitem(self):
|
def popitem(self):
|
||||||
self._dirty_len = True
|
|
||||||
while True:
|
while True:
|
||||||
key, value = self.data.popitem()
|
key, value = self.data.popitem()
|
||||||
o = key()
|
o = key()
|
||||||
|
@ -501,7 +402,6 @@ class WeakKeyDictionary(_collections_abc.MutableMapping):
|
||||||
return o, value
|
return o, value
|
||||||
|
|
||||||
def pop(self, key, *args):
|
def pop(self, key, *args):
|
||||||
self._dirty_len = True
|
|
||||||
return self.data.pop(ref(key), *args)
|
return self.data.pop(ref(key), *args)
|
||||||
|
|
||||||
def setdefault(self, key, default=None):
|
def setdefault(self, key, default=None):
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Make :class:`~weakref.WeakKeyDictionary` and :class:`~weakref.WeakValueDictionary` safe against concurrent mutations from other threads. Patch by Kumar Aditya.
|
Loading…
Reference in New Issue