Issue #7105: Make WeakKeyDictionary and WeakValueDictionary robust against
the destruction of weakref'ed objects while iterating.
This commit is contained in:
parent
dc2a61347b
commit
c1baa601e2
|
@ -159,7 +159,7 @@ than needed.
|
|||
|
||||
.. method:: WeakKeyDictionary.keyrefs()
|
||||
|
||||
Return an :term:`iterator` that yields the weak references to the keys.
|
||||
Return an iterable of the weak references to the keys.
|
||||
|
||||
|
||||
.. class:: WeakValueDictionary([dict])
|
||||
|
@ -182,7 +182,7 @@ These method have the same issues as the and :meth:`keyrefs` method of
|
|||
|
||||
.. method:: WeakValueDictionary.valuerefs()
|
||||
|
||||
Return an :term:`iterator` that yields the weak references to the values.
|
||||
Return an iterable of the weak references to the values.
|
||||
|
||||
|
||||
.. class:: WeakSet([elements])
|
||||
|
|
|
@ -6,22 +6,61 @@ from _weakref import ref
|
|||
|
||||
__all__ = ['WeakSet']
|
||||
|
||||
|
||||
class _IterationGuard:
|
||||
# This context manager registers itself in the current iterators of the
|
||||
# weak container, such as to delay all removals until the context manager
|
||||
# exits.
|
||||
# This technique should be relatively thread-safe (since sets are).
|
||||
|
||||
def __init__(self, weakcontainer):
|
||||
# Don't create cycles
|
||||
self.weakcontainer = ref(weakcontainer)
|
||||
|
||||
def __enter__(self):
|
||||
w = self.weakcontainer()
|
||||
if w is not None:
|
||||
w._iterating.add(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, e, t, b):
|
||||
w = self.weakcontainer()
|
||||
if w is not None:
|
||||
s = w._iterating
|
||||
s.remove(self)
|
||||
if not s:
|
||||
w._commit_removals()
|
||||
|
||||
|
||||
class WeakSet:
|
||||
def __init__(self, data=None):
|
||||
self.data = set()
|
||||
def _remove(item, selfref=ref(self)):
|
||||
self = selfref()
|
||||
if self is not None:
|
||||
self.data.discard(item)
|
||||
if self._iterating:
|
||||
self._pending_removals.append(item)
|
||||
else:
|
||||
self.data.discard(item)
|
||||
self._remove = _remove
|
||||
# A list of keys to be removed
|
||||
self._pending_removals = []
|
||||
self._iterating = set()
|
||||
if data is not None:
|
||||
self.update(data)
|
||||
|
||||
def _commit_removals(self):
|
||||
l = self._pending_removals
|
||||
discard = self.data.discard
|
||||
while l:
|
||||
discard(l.pop())
|
||||
|
||||
def __iter__(self):
|
||||
for itemref in self.data:
|
||||
item = itemref()
|
||||
if item is not None:
|
||||
yield item
|
||||
with _IterationGuard(self):
|
||||
for itemref in self.data:
|
||||
item = itemref()
|
||||
if item is not None:
|
||||
yield item
|
||||
|
||||
def __len__(self):
|
||||
return sum(x() is not None for x in self.data)
|
||||
|
@ -34,15 +73,21 @@ class WeakSet:
|
|||
getattr(self, '__dict__', None))
|
||||
|
||||
def add(self, item):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data.add(ref(item, self._remove))
|
||||
|
||||
def clear(self):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data.clear()
|
||||
|
||||
def copy(self):
|
||||
return self.__class__(self)
|
||||
|
||||
def pop(self):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
while True:
|
||||
try:
|
||||
itemref = self.data.pop()
|
||||
|
@ -53,17 +98,24 @@ class WeakSet:
|
|||
return item
|
||||
|
||||
def remove(self, item):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data.remove(ref(item))
|
||||
|
||||
def discard(self, item):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data.discard(ref(item))
|
||||
|
||||
def update(self, other):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
if isinstance(other, self.__class__):
|
||||
self.data.update(other.data)
|
||||
else:
|
||||
for element in other:
|
||||
self.add(element)
|
||||
|
||||
def __ior__(self, other):
|
||||
self.update(other)
|
||||
return self
|
||||
|
@ -82,11 +134,15 @@ class WeakSet:
|
|||
__sub__ = difference
|
||||
|
||||
def difference_update(self, other):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
if self is other:
|
||||
self.data.clear()
|
||||
else:
|
||||
self.data.difference_update(ref(item) for item in other)
|
||||
def __isub__(self, other):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
if self is other:
|
||||
self.data.clear()
|
||||
else:
|
||||
|
@ -98,8 +154,12 @@ class WeakSet:
|
|||
__and__ = intersection
|
||||
|
||||
def intersection_update(self, other):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data.intersection_update(ref(item) for item in other)
|
||||
def __iand__(self, other):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data.intersection_update(ref(item) for item in other)
|
||||
return self
|
||||
|
||||
|
@ -127,11 +187,15 @@ class WeakSet:
|
|||
__xor__ = symmetric_difference
|
||||
|
||||
def symmetric_difference_update(self, other):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
if self is other:
|
||||
self.data.clear()
|
||||
else:
|
||||
self.data.symmetric_difference_update(ref(item) for item in other)
|
||||
def __ixor__(self, other):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
if self is other:
|
||||
self.data.clear()
|
||||
else:
|
||||
|
|
|
@ -4,6 +4,8 @@ import unittest
|
|||
import collections
|
||||
import weakref
|
||||
import operator
|
||||
import contextlib
|
||||
import copy
|
||||
|
||||
from test import support
|
||||
|
||||
|
@ -788,6 +790,10 @@ class Object:
|
|||
self.arg = arg
|
||||
def __repr__(self):
|
||||
return "<Object %r>" % self.arg
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Object):
|
||||
return self.arg == other.arg
|
||||
return NotImplemented
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, Object):
|
||||
return self.arg < other.arg
|
||||
|
@ -935,6 +941,87 @@ class MappingTestCase(TestBase):
|
|||
self.assertFalse(values,
|
||||
"itervalues() did not touch all values")
|
||||
|
||||
def check_weak_destroy_while_iterating(self, dict, objects, iter_name):
|
||||
n = len(dict)
|
||||
it = iter(getattr(dict, iter_name)())
|
||||
next(it) # Trigger internal iteration
|
||||
# Destroy an object
|
||||
del objects[-1]
|
||||
gc.collect() # just in case
|
||||
# We have removed either the first consumed object, or another one
|
||||
self.assertIn(len(list(it)), [len(objects), len(objects) - 1])
|
||||
del it
|
||||
# The removal has been committed
|
||||
self.assertEqual(len(dict), n - 1)
|
||||
|
||||
def check_weak_destroy_and_mutate_while_iterating(self, dict, testcontext):
|
||||
# Check that we can explicitly mutate the weak dict without
|
||||
# interfering with delayed removal.
|
||||
# `testcontext` should create an iterator, destroy one of the
|
||||
# weakref'ed objects and then return a new key/value pair corresponding
|
||||
# to the destroyed object.
|
||||
with testcontext() as (k, v):
|
||||
self.assertFalse(k in dict)
|
||||
with testcontext() as (k, v):
|
||||
self.assertRaises(KeyError, dict.__delitem__, k)
|
||||
self.assertFalse(k in dict)
|
||||
with testcontext() as (k, v):
|
||||
self.assertRaises(KeyError, dict.pop, k)
|
||||
self.assertFalse(k in dict)
|
||||
with testcontext() as (k, v):
|
||||
dict[k] = v
|
||||
self.assertEqual(dict[k], v)
|
||||
ddict = copy.copy(dict)
|
||||
with testcontext() as (k, v):
|
||||
dict.update(ddict)
|
||||
self.assertEqual(dict, ddict)
|
||||
with testcontext() as (k, v):
|
||||
dict.clear()
|
||||
self.assertEqual(len(dict), 0)
|
||||
|
||||
def test_weak_keys_destroy_while_iterating(self):
|
||||
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
|
||||
dict, objects = self.make_weak_keyed_dict()
|
||||
self.check_weak_destroy_while_iterating(dict, objects, 'keys')
|
||||
self.check_weak_destroy_while_iterating(dict, objects, 'items')
|
||||
self.check_weak_destroy_while_iterating(dict, objects, 'values')
|
||||
self.check_weak_destroy_while_iterating(dict, objects, 'keyrefs')
|
||||
dict, objects = self.make_weak_keyed_dict()
|
||||
@contextlib.contextmanager
|
||||
def testcontext():
|
||||
try:
|
||||
it = iter(dict.items())
|
||||
next(it)
|
||||
# Schedule a key/value for removal and recreate it
|
||||
v = objects.pop().arg
|
||||
gc.collect() # just in case
|
||||
yield Object(v), v
|
||||
finally:
|
||||
it = None # should commit all removals
|
||||
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
|
||||
|
||||
def test_weak_values_destroy_while_iterating(self):
|
||||
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
|
||||
dict, objects = self.make_weak_valued_dict()
|
||||
self.check_weak_destroy_while_iterating(dict, objects, 'keys')
|
||||
self.check_weak_destroy_while_iterating(dict, objects, 'items')
|
||||
self.check_weak_destroy_while_iterating(dict, objects, 'values')
|
||||
self.check_weak_destroy_while_iterating(dict, objects, 'itervaluerefs')
|
||||
self.check_weak_destroy_while_iterating(dict, objects, 'valuerefs')
|
||||
dict, objects = self.make_weak_valued_dict()
|
||||
@contextlib.contextmanager
|
||||
def testcontext():
|
||||
try:
|
||||
it = iter(dict.items())
|
||||
next(it)
|
||||
# Schedule a key/value for removal and recreate it
|
||||
k = objects.pop().arg
|
||||
gc.collect() # just in case
|
||||
yield k, Object(k)
|
||||
finally:
|
||||
it = None # should commit all removals
|
||||
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
|
||||
|
||||
def test_make_weak_keyed_dict_from_dict(self):
|
||||
o = Object(3)
|
||||
dict = weakref.WeakKeyDictionary({o:364})
|
||||
|
|
|
@ -10,6 +10,8 @@ import sys
|
|||
import warnings
|
||||
import collections
|
||||
from collections import UserString as ustr
|
||||
import gc
|
||||
import contextlib
|
||||
|
||||
|
||||
class Foo:
|
||||
|
@ -307,6 +309,54 @@ class TestWeakSet(unittest.TestCase):
|
|||
self.assertFalse(self.s == WeakSet([Foo]))
|
||||
self.assertFalse(self.s == 1)
|
||||
|
||||
def test_weak_destroy_while_iterating(self):
|
||||
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
|
||||
# Create new items to be sure no-one else holds a reference
|
||||
items = [ustr(c) for c in ('a', 'b', 'c')]
|
||||
s = WeakSet(items)
|
||||
it = iter(s)
|
||||
next(it) # Trigger internal iteration
|
||||
# Destroy an item
|
||||
del items[-1]
|
||||
gc.collect() # just in case
|
||||
# We have removed either the first consumed items, or another one
|
||||
self.assertIn(len(list(it)), [len(items), len(items) - 1])
|
||||
del it
|
||||
# The removal has been committed
|
||||
self.assertEqual(len(s), len(items))
|
||||
|
||||
def test_weak_destroy_and_mutate_while_iterating(self):
|
||||
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
|
||||
items = [ustr(c) for c in string.ascii_letters]
|
||||
s = WeakSet(items)
|
||||
@contextlib.contextmanager
|
||||
def testcontext():
|
||||
try:
|
||||
it = iter(s)
|
||||
next(it)
|
||||
# Schedule an item for removal and recreate it
|
||||
u = ustr(str(items.pop()))
|
||||
gc.collect() # just in case
|
||||
yield u
|
||||
finally:
|
||||
it = None # should commit all removals
|
||||
|
||||
with testcontext() as u:
|
||||
self.assertFalse(u in s)
|
||||
with testcontext() as u:
|
||||
self.assertRaises(KeyError, s.remove, u)
|
||||
self.assertFalse(u in s)
|
||||
with testcontext() as u:
|
||||
s.add(u)
|
||||
self.assertTrue(u in s)
|
||||
t = s.copy()
|
||||
with testcontext() as u:
|
||||
s.update(t)
|
||||
self.assertEqual(len(s), len(t))
|
||||
with testcontext() as u:
|
||||
s.clear()
|
||||
self.assertEqual(len(s), 0)
|
||||
|
||||
|
||||
def test_main(verbose=None):
|
||||
support.run_unittest(TestWeakSet)
|
||||
|
|
134
Lib/weakref.py
134
Lib/weakref.py
|
@ -18,7 +18,7 @@ from _weakref import (
|
|||
ProxyType,
|
||||
ReferenceType)
|
||||
|
||||
from _weakrefset import WeakSet
|
||||
from _weakrefset import WeakSet, _IterationGuard
|
||||
|
||||
import collections # Import after _weakref to avoid circular import.
|
||||
|
||||
|
@ -46,11 +46,25 @@ class WeakValueDictionary(collections.MutableMapping):
|
|||
def remove(wr, selfref=ref(self)):
|
||||
self = selfref()
|
||||
if self is not None:
|
||||
del self.data[wr.key]
|
||||
if self._iterating:
|
||||
self._pending_removals.append(wr.key)
|
||||
else:
|
||||
del self.data[wr.key]
|
||||
self._remove = remove
|
||||
# A list of keys to be removed
|
||||
self._pending_removals = []
|
||||
self._iterating = set()
|
||||
self.data = d = {}
|
||||
self.update(*args, **kw)
|
||||
|
||||
def _commit_removals(self):
|
||||
l = self._pending_removals
|
||||
d = self.data
|
||||
# We shouldn't encounter any KeyError, because this method should
|
||||
# always be called *before* mutating the dict.
|
||||
while l:
|
||||
del d[l.pop()]
|
||||
|
||||
def __getitem__(self, key):
|
||||
o = self.data[key]()
|
||||
if o is None:
|
||||
|
@ -59,6 +73,8 @@ class WeakValueDictionary(collections.MutableMapping):
|
|||
return o
|
||||
|
||||
def __delitem__(self, key):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
del self.data[key]
|
||||
|
||||
def __len__(self):
|
||||
|
@ -75,6 +91,8 @@ class WeakValueDictionary(collections.MutableMapping):
|
|||
return "<WeakValueDictionary at %s>" % id(self)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data[key] = KeyedRef(value, self._remove, key)
|
||||
|
||||
def copy(self):
|
||||
|
@ -110,24 +128,19 @@ class WeakValueDictionary(collections.MutableMapping):
|
|||
return o
|
||||
|
||||
def items(self):
|
||||
L = []
|
||||
for key, wr in self.data.items():
|
||||
o = wr()
|
||||
if o is not None:
|
||||
L.append((key, o))
|
||||
return L
|
||||
|
||||
def items(self):
|
||||
for wr in self.data.values():
|
||||
value = wr()
|
||||
if value is not None:
|
||||
yield wr.key, value
|
||||
with _IterationGuard(self):
|
||||
for k, wr in self.data.items():
|
||||
v = wr()
|
||||
if v is not None:
|
||||
yield k, v
|
||||
|
||||
def keys(self):
|
||||
return iter(self.data.keys())
|
||||
with _IterationGuard(self):
|
||||
for k, wr in self.data.items():
|
||||
if wr() is not None:
|
||||
yield k
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.data.keys())
|
||||
__iter__ = keys
|
||||
|
||||
def itervaluerefs(self):
|
||||
"""Return an iterator that yields the weak references to the values.
|
||||
|
@ -139,15 +152,20 @@ class WeakValueDictionary(collections.MutableMapping):
|
|||
keep the values around longer than needed.
|
||||
|
||||
"""
|
||||
return self.data.values()
|
||||
with _IterationGuard(self):
|
||||
for wr in self.data.values():
|
||||
yield wr
|
||||
|
||||
def values(self):
|
||||
for wr in self.data.values():
|
||||
obj = wr()
|
||||
if obj is not None:
|
||||
yield obj
|
||||
with _IterationGuard(self):
|
||||
for wr in self.data.values():
|
||||
obj = wr()
|
||||
if obj is not None:
|
||||
yield obj
|
||||
|
||||
def popitem(self):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
while 1:
|
||||
key, wr = self.data.popitem()
|
||||
o = wr()
|
||||
|
@ -155,6 +173,8 @@ class WeakValueDictionary(collections.MutableMapping):
|
|||
return key, o
|
||||
|
||||
def pop(self, key, *args):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
try:
|
||||
o = self.data.pop(key)()
|
||||
except KeyError:
|
||||
|
@ -170,12 +190,16 @@ class WeakValueDictionary(collections.MutableMapping):
|
|||
try:
|
||||
wr = self.data[key]
|
||||
except KeyError:
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
self.data[key] = KeyedRef(default, self._remove, key)
|
||||
return default
|
||||
else:
|
||||
return wr()
|
||||
|
||||
def update(self, dict=None, **kwargs):
|
||||
if self._pending_removals:
|
||||
self._commit_removals()
|
||||
d = self.data
|
||||
if dict is not None:
|
||||
if not hasattr(dict, "items"):
|
||||
|
@ -195,7 +219,7 @@ class WeakValueDictionary(collections.MutableMapping):
|
|||
keep the values around longer than needed.
|
||||
|
||||
"""
|
||||
return self.data.values()
|
||||
return list(self.data.values())
|
||||
|
||||
|
||||
class KeyedRef(ref):
|
||||
|
@ -235,9 +259,29 @@ class WeakKeyDictionary(collections.MutableMapping):
|
|||
def remove(k, selfref=ref(self)):
|
||||
self = selfref()
|
||||
if self is not None:
|
||||
del self.data[k]
|
||||
if self._iterating:
|
||||
self._pending_removals.append(k)
|
||||
else:
|
||||
del self.data[k]
|
||||
self._remove = remove
|
||||
if dict is not None: self.update(dict)
|
||||
# A list of dead weakrefs (keys to be removed)
|
||||
self._pending_removals = []
|
||||
self._iterating = set()
|
||||
if dict is not None:
|
||||
self.update(dict)
|
||||
|
||||
def _commit_removals(self):
|
||||
# NOTE: We don't need to call this method before mutating the dict,
|
||||
# because a dead weakref never compares equal to a live weakref,
|
||||
# even if they happened to refer to equal objects.
|
||||
# However, it means keys may already have been removed.
|
||||
l = self._pending_removals
|
||||
d = self.data
|
||||
while l:
|
||||
try:
|
||||
del d[l.pop()]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.data[ref(key)]
|
||||
|
@ -284,34 +328,26 @@ class WeakKeyDictionary(collections.MutableMapping):
|
|||
return wr in self.data
|
||||
|
||||
def items(self):
|
||||
for wr, value in self.data.items():
|
||||
key = wr()
|
||||
if key is not None:
|
||||
yield key, value
|
||||
|
||||
def keyrefs(self):
|
||||
"""Return an iterator that yields the weak references to the keys.
|
||||
|
||||
The references are not guaranteed to be 'live' at the time
|
||||
they are used, so the result of calling the references needs
|
||||
to be checked before being used. This can be used to avoid
|
||||
creating references that will cause the garbage collector to
|
||||
keep the keys around longer than needed.
|
||||
|
||||
"""
|
||||
return self.data.keys()
|
||||
with _IterationGuard(self):
|
||||
for wr, value in self.data.items():
|
||||
key = wr()
|
||||
if key is not None:
|
||||
yield key, value
|
||||
|
||||
def keys(self):
|
||||
for wr in self.data.keys():
|
||||
obj = wr()
|
||||
if obj is not None:
|
||||
yield obj
|
||||
with _IterationGuard(self):
|
||||
for wr in self.data:
|
||||
obj = wr()
|
||||
if obj is not None:
|
||||
yield obj
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.keys())
|
||||
__iter__ = keys
|
||||
|
||||
def values(self):
|
||||
return iter(self.data.values())
|
||||
with _IterationGuard(self):
|
||||
for wr, value in self.data.items():
|
||||
if wr() is not None:
|
||||
yield value
|
||||
|
||||
def keyrefs(self):
|
||||
"""Return a list of weak references to the keys.
|
||||
|
@ -323,7 +359,7 @@ class WeakKeyDictionary(collections.MutableMapping):
|
|||
keep the keys around longer than needed.
|
||||
|
||||
"""
|
||||
return self.data.keys()
|
||||
return list(self.data)
|
||||
|
||||
def popitem(self):
|
||||
while 1:
|
||||
|
|
Loading…
Reference in New Issue