Issue #7105: Make WeakKeyDictionary and WeakValueDictionary robust against
the destruction of weakref'ed objects while iterating.
This commit is contained in:
Antoine Pitrou 2010-01-08 17:54:23 +00:00
parent dc2a61347b
commit c1baa601e2
6 changed files with 296 additions and 56 deletions

View File

@ -159,7 +159,7 @@ than needed.
.. method:: WeakKeyDictionary.keyrefs() .. method:: WeakKeyDictionary.keyrefs()
Return an :term:`iterator` that yields the weak references to the keys. Return an iterable of the weak references to the keys.
.. class:: WeakValueDictionary([dict]) .. class:: WeakValueDictionary([dict])
@ -182,7 +182,7 @@ These method have the same issues as the and :meth:`keyrefs` method of
.. method:: WeakValueDictionary.valuerefs() .. method:: WeakValueDictionary.valuerefs()
Return an :term:`iterator` that yields the weak references to the values. Return an iterable of the weak references to the values.
.. class:: WeakSet([elements]) .. class:: WeakSet([elements])

View File

@ -6,18 +6,57 @@ from _weakref import ref
__all__ = ['WeakSet'] __all__ = ['WeakSet']
class _IterationGuard:
# This context manager registers itself in the current iterators of the
# weak container, such as to delay all removals until the context manager
# exits.
# This technique should be relatively thread-safe (since sets are).
def __init__(self, weakcontainer):
# Don't create cycles
self.weakcontainer = ref(weakcontainer)
def __enter__(self):
w = self.weakcontainer()
if w is not None:
w._iterating.add(self)
return self
def __exit__(self, e, t, b):
w = self.weakcontainer()
if w is not None:
s = w._iterating
s.remove(self)
if not s:
w._commit_removals()
class WeakSet: class WeakSet:
def __init__(self, data=None): def __init__(self, data=None):
self.data = set() self.data = set()
def _remove(item, selfref=ref(self)): def _remove(item, selfref=ref(self)):
self = selfref() self = selfref()
if self is not None: if self is not None:
if self._iterating:
self._pending_removals.append(item)
else:
self.data.discard(item) self.data.discard(item)
self._remove = _remove self._remove = _remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
if data is not None: if data is not None:
self.update(data) self.update(data)
def _commit_removals(self):
l = self._pending_removals
discard = self.data.discard
while l:
discard(l.pop())
def __iter__(self): def __iter__(self):
with _IterationGuard(self):
for itemref in self.data: for itemref in self.data:
item = itemref() item = itemref()
if item is not None: if item is not None:
@ -34,15 +73,21 @@ class WeakSet:
getattr(self, '__dict__', None)) getattr(self, '__dict__', None))
def add(self, item): def add(self, item):
if self._pending_removals:
self._commit_removals()
self.data.add(ref(item, self._remove)) self.data.add(ref(item, self._remove))
def clear(self): def clear(self):
if self._pending_removals:
self._commit_removals()
self.data.clear() self.data.clear()
def copy(self): def copy(self):
return self.__class__(self) return self.__class__(self)
def pop(self): def pop(self):
if self._pending_removals:
self._commit_removals()
while True: while True:
try: try:
itemref = self.data.pop() itemref = self.data.pop()
@ -53,17 +98,24 @@ class WeakSet:
return item return item
def remove(self, item): def remove(self, item):
if self._pending_removals:
self._commit_removals()
self.data.remove(ref(item)) self.data.remove(ref(item))
def discard(self, item): def discard(self, item):
if self._pending_removals:
self._commit_removals()
self.data.discard(ref(item)) self.data.discard(ref(item))
def update(self, other): def update(self, other):
if self._pending_removals:
self._commit_removals()
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
self.data.update(other.data) self.data.update(other.data)
else: else:
for element in other: for element in other:
self.add(element) self.add(element)
def __ior__(self, other): def __ior__(self, other):
self.update(other) self.update(other)
return self return self
@ -82,11 +134,15 @@ class WeakSet:
__sub__ = difference __sub__ = difference
def difference_update(self, other): def difference_update(self, other):
if self._pending_removals:
self._commit_removals()
if self is other: if self is other:
self.data.clear() self.data.clear()
else: else:
self.data.difference_update(ref(item) for item in other) self.data.difference_update(ref(item) for item in other)
def __isub__(self, other): def __isub__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other: if self is other:
self.data.clear() self.data.clear()
else: else:
@ -98,8 +154,12 @@ class WeakSet:
__and__ = intersection __and__ = intersection
def intersection_update(self, other): def intersection_update(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other) self.data.intersection_update(ref(item) for item in other)
def __iand__(self, other): def __iand__(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other) self.data.intersection_update(ref(item) for item in other)
return self return self
@ -127,11 +187,15 @@ class WeakSet:
__xor__ = symmetric_difference __xor__ = symmetric_difference
def symmetric_difference_update(self, other): def symmetric_difference_update(self, other):
if self._pending_removals:
self._commit_removals()
if self is other: if self is other:
self.data.clear() self.data.clear()
else: else:
self.data.symmetric_difference_update(ref(item) for item in other) self.data.symmetric_difference_update(ref(item) for item in other)
def __ixor__(self, other): def __ixor__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other: if self is other:
self.data.clear() self.data.clear()
else: else:

View File

@ -4,6 +4,8 @@ import unittest
import collections import collections
import weakref import weakref
import operator import operator
import contextlib
import copy
from test import support from test import support
@ -788,6 +790,10 @@ class Object:
self.arg = arg self.arg = arg
def __repr__(self): def __repr__(self):
return "<Object %r>" % self.arg return "<Object %r>" % self.arg
def __eq__(self, other):
    """Compare by ``arg`` with another Object; defer for any other type."""
    if not isinstance(other, Object):
        return NotImplemented
    return self.arg == other.arg
def __lt__(self, other): def __lt__(self, other):
if isinstance(other, Object): if isinstance(other, Object):
return self.arg < other.arg return self.arg < other.arg
@ -935,6 +941,87 @@ class MappingTestCase(TestBase):
self.assertFalse(values, self.assertFalse(values,
"itervalues() did not touch all values") "itervalues() did not touch all values")
def check_weak_destroy_while_iterating(self, dict, objects, iter_name):
    """Check that a referent may be destroyed while ``dict`` is iterated.

    ``dict`` is the weak dictionary under test (the parameter shadows the
    builtin inside this helper).  ``objects`` is a list of the referents;
    its last element is deleted to destroy one of them.  ``iter_name`` is
    the name of the ``dict`` method whose result is iterated.
    """
    size_before = len(dict)
    # Keep only the bound method and the iterator itself: any extra
    # reference to the iterable could keep the iteration alive after
    # ``del it`` below.
    factory = getattr(dict, iter_name)
    it = iter(factory())
    next(it)  # Trigger internal iteration
    # Destroy an object
    del objects[-1]
    gc.collect()  # just in case
    # Either the object consumed above or some other one was removed, so
    # two counts are acceptable for the rest of the iteration.
    remaining = len(list(it))
    self.assertIn(remaining, [len(objects), len(objects) - 1])
    del it
    # Dropping the iterator must have committed the delayed removal.
    self.assertEqual(len(dict), size_before - 1)
def check_weak_destroy_and_mutate_while_iterating(self, dict, testcontext):
# Shared helper: check that we can explicitly mutate the weak dict
# without interfering with delayed removal.
#
# `dict` is the weak dictionary under test (the parameter name shadows
# the builtin inside this method).
# `testcontext` is called with no arguments and used as a context
# manager.  Each use should create an iterator, destroy one of the
# weakref'ed objects and then return a new key/value pair corresponding
# to the destroyed object.
#
# Step 1: the key of the destroyed entry is not reported as present.
with testcontext() as (k, v):
self.assertFalse(k in dict)
# Step 2: deleting the key of a destroyed entry raises KeyError.
with testcontext() as (k, v):
self.assertRaises(KeyError, dict.__delitem__, k)
self.assertFalse(k in dict)
# Step 3: popping the key of a destroyed entry raises KeyError too.
with testcontext() as (k, v):
self.assertRaises(KeyError, dict.pop, k)
self.assertFalse(k in dict)
# Step 4: the key can be assigned again and the new value read back.
with testcontext() as (k, v):
dict[k] = v
self.assertEqual(dict[k], v)
# Step 5: updating from a shallow copy leaves both mappings equal.
ddict = copy.copy(dict)
with testcontext() as (k, v):
dict.update(ddict)
self.assertEqual(dict, ddict)
# Step 6: clear() empties the dictionary.
with testcontext() as (k, v):
dict.clear()
self.assertEqual(len(dict), 0)
def test_weak_keys_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a key is implicitly removed
# Part 1: destroy one key while each kind of iterator is active.
# NOTE(review): make_weak_keyed_dict() is defined outside this view; it
# is presumably returning the weak-keyed dict plus the list of its key
# objects -- confirm.
dict, objects = self.make_weak_keyed_dict()
self.check_weak_destroy_while_iterating(dict, objects, 'keys')
self.check_weak_destroy_while_iterating(dict, objects, 'items')
self.check_weak_destroy_while_iterating(dict, objects, 'values')
self.check_weak_destroy_while_iterating(dict, objects, 'keyrefs')
# Part 2: start again from a fresh dict and mix explicit mutations with
# delayed removals.
dict, objects = self.make_weak_keyed_dict()
@contextlib.contextmanager
def testcontext():
try:
# Start an iteration so that removals get delayed.
it = iter(dict.items())
next(it)
# Schedule a key/value for removal and recreate it
# (the new key is a fresh Object built from the popped one's `arg`).
v = objects.pop().arg
gc.collect() # just in case
yield Object(v), v
finally:
it = None # should commit all removals
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
def test_weak_values_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when a value is implicitly removed
# Part 1: destroy one value while each kind of iterator is active.
# NOTE(review): make_weak_valued_dict() is defined outside this view; it
# is presumably returning the weak-valued dict plus the list of its
# value objects -- confirm.
dict, objects = self.make_weak_valued_dict()
self.check_weak_destroy_while_iterating(dict, objects, 'keys')
self.check_weak_destroy_while_iterating(dict, objects, 'items')
self.check_weak_destroy_while_iterating(dict, objects, 'values')
self.check_weak_destroy_while_iterating(dict, objects, 'itervaluerefs')
self.check_weak_destroy_while_iterating(dict, objects, 'valuerefs')
# Part 2: start again from a fresh dict and mix explicit mutations with
# delayed removals.
dict, objects = self.make_weak_valued_dict()
@contextlib.contextmanager
def testcontext():
try:
# Start an iteration so that removals get delayed.
it = iter(dict.items())
next(it)
# Schedule a key/value for removal and recreate it
# (the new value is a fresh Object built from the popped one's `arg`).
k = objects.pop().arg
gc.collect() # just in case
yield k, Object(k)
finally:
it = None # should commit all removals
self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext)
def test_make_weak_keyed_dict_from_dict(self): def test_make_weak_keyed_dict_from_dict(self):
o = Object(3) o = Object(3)
dict = weakref.WeakKeyDictionary({o:364}) dict = weakref.WeakKeyDictionary({o:364})

View File

@ -10,6 +10,8 @@ import sys
import warnings import warnings
import collections import collections
from collections import UserString as ustr from collections import UserString as ustr
import gc
import contextlib
class Foo: class Foo:
@ -307,6 +309,54 @@ class TestWeakSet(unittest.TestCase):
self.assertFalse(self.s == WeakSet([Foo])) self.assertFalse(self.s == WeakSet([Foo]))
self.assertFalse(self.s == 1) self.assertFalse(self.s == 1)
def test_weak_destroy_while_iterating(self):
# Issue #7105: iterators shouldn't crash when an item is implicitly removed
# Create new items to be sure no-one else holds a reference
items = [ustr(c) for c in ('a', 'b', 'c')]
s = WeakSet(items)
it = iter(s)
next(it) # Trigger internal iteration
# Destroy an item by dropping the reference held in `items`
del items[-1]
gc.collect() # just in case
# We have removed either the first consumed item, or another one, so two
# counts are acceptable for the rest of the iteration
self.assertIn(len(list(it)), [len(items), len(items) - 1])
del it
# The removal has been committed once the iterator is gone
self.assertEqual(len(s), len(items))
def test_weak_destroy_and_mutate_while_iterating(self):
# Issue #7105: iterators shouldn't crash when an item is implicitly removed
# Here explicit mutations of the set are mixed with removals that are
# delayed by an active iterator.
items = [ustr(c) for c in string.ascii_letters]
s = WeakSet(items)
@contextlib.contextmanager
def testcontext():
try:
# Start an iteration so that removals get delayed.
it = iter(s)
next(it)
# Schedule an item for removal and recreate it
# (`u` is a new UserString with the same text as the popped one).
u = ustr(str(items.pop()))
gc.collect() # just in case
yield u
finally:
it = None # should commit all removals
# The recreated item is not reported as a member.
with testcontext() as u:
self.assertFalse(u in s)
# Removing it explicitly raises KeyError.
with testcontext() as u:
self.assertRaises(KeyError, s.remove, u)
self.assertFalse(u in s)
# It can be added again.
with testcontext() as u:
s.add(u)
self.assertTrue(u in s)
# Updating from a copy keeps both sets the same size.
t = s.copy()
with testcontext() as u:
s.update(t)
self.assertEqual(len(s), len(t))
# clear() empties the set.
with testcontext() as u:
s.clear()
self.assertEqual(len(s), 0)
def test_main(verbose=None): def test_main(verbose=None):
support.run_unittest(TestWeakSet) support.run_unittest(TestWeakSet)

View File

@ -18,7 +18,7 @@ from _weakref import (
ProxyType, ProxyType,
ReferenceType) ReferenceType)
from _weakrefset import WeakSet from _weakrefset import WeakSet, _IterationGuard
import collections # Import after _weakref to avoid circular import. import collections # Import after _weakref to avoid circular import.
@ -46,11 +46,25 @@ class WeakValueDictionary(collections.MutableMapping):
def remove(wr, selfref=ref(self)): def remove(wr, selfref=ref(self)):
self = selfref() self = selfref()
if self is not None: if self is not None:
if self._iterating:
self._pending_removals.append(wr.key)
else:
del self.data[wr.key] del self.data[wr.key]
self._remove = remove self._remove = remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
self.data = d = {} self.data = d = {}
self.update(*args, **kw) self.update(*args, **kw)
def _commit_removals(self):
l = self._pending_removals
d = self.data
# We shouldn't encounter any KeyError, because this method should
# always be called *before* mutating the dict.
while l:
del d[l.pop()]
def __getitem__(self, key): def __getitem__(self, key):
o = self.data[key]() o = self.data[key]()
if o is None: if o is None:
@ -59,6 +73,8 @@ class WeakValueDictionary(collections.MutableMapping):
return o return o
def __delitem__(self, key): def __delitem__(self, key):
if self._pending_removals:
self._commit_removals()
del self.data[key] del self.data[key]
def __len__(self): def __len__(self):
@ -75,6 +91,8 @@ class WeakValueDictionary(collections.MutableMapping):
return "<WeakValueDictionary at %s>" % id(self) return "<WeakValueDictionary at %s>" % id(self)
def __setitem__(self, key, value): def __setitem__(self, key, value):
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(value, self._remove, key) self.data[key] = KeyedRef(value, self._remove, key)
def copy(self): def copy(self):
@ -110,24 +128,19 @@ class WeakValueDictionary(collections.MutableMapping):
return o return o
def items(self): def items(self):
L = [] with _IterationGuard(self):
for key, wr in self.data.items(): for k, wr in self.data.items():
o = wr() v = wr()
if o is not None: if v is not None:
L.append((key, o)) yield k, v
return L
def items(self):
for wr in self.data.values():
value = wr()
if value is not None:
yield wr.key, value
def keys(self): def keys(self):
return iter(self.data.keys()) with _IterationGuard(self):
for k, wr in self.data.items():
if wr() is not None:
yield k
def __iter__(self): __iter__ = keys
return iter(self.data.keys())
def itervaluerefs(self): def itervaluerefs(self):
"""Return an iterator that yields the weak references to the values. """Return an iterator that yields the weak references to the values.
@ -139,15 +152,20 @@ class WeakValueDictionary(collections.MutableMapping):
keep the values around longer than needed. keep the values around longer than needed.
""" """
return self.data.values() with _IterationGuard(self):
for wr in self.data.values():
yield wr
def values(self): def values(self):
with _IterationGuard(self):
for wr in self.data.values(): for wr in self.data.values():
obj = wr() obj = wr()
if obj is not None: if obj is not None:
yield obj yield obj
def popitem(self): def popitem(self):
if self._pending_removals:
self._commit_removals()
while 1: while 1:
key, wr = self.data.popitem() key, wr = self.data.popitem()
o = wr() o = wr()
@ -155,6 +173,8 @@ class WeakValueDictionary(collections.MutableMapping):
return key, o return key, o
def pop(self, key, *args): def pop(self, key, *args):
if self._pending_removals:
self._commit_removals()
try: try:
o = self.data.pop(key)() o = self.data.pop(key)()
except KeyError: except KeyError:
@ -170,12 +190,16 @@ class WeakValueDictionary(collections.MutableMapping):
try: try:
wr = self.data[key] wr = self.data[key]
except KeyError: except KeyError:
if self._pending_removals:
self._commit_removals()
self.data[key] = KeyedRef(default, self._remove, key) self.data[key] = KeyedRef(default, self._remove, key)
return default return default
else: else:
return wr() return wr()
def update(self, dict=None, **kwargs): def update(self, dict=None, **kwargs):
if self._pending_removals:
self._commit_removals()
d = self.data d = self.data
if dict is not None: if dict is not None:
if not hasattr(dict, "items"): if not hasattr(dict, "items"):
@ -195,7 +219,7 @@ class WeakValueDictionary(collections.MutableMapping):
keep the values around longer than needed. keep the values around longer than needed.
""" """
return self.data.values() return list(self.data.values())
class KeyedRef(ref): class KeyedRef(ref):
@ -235,9 +259,29 @@ class WeakKeyDictionary(collections.MutableMapping):
def remove(k, selfref=ref(self)): def remove(k, selfref=ref(self)):
self = selfref() self = selfref()
if self is not None: if self is not None:
if self._iterating:
self._pending_removals.append(k)
else:
del self.data[k] del self.data[k]
self._remove = remove self._remove = remove
if dict is not None: self.update(dict) # A list of dead weakrefs (keys to be removed)
self._pending_removals = []
self._iterating = set()
if dict is not None:
self.update(dict)
def _commit_removals(self):
# NOTE: We don't need to call this method before mutating the dict,
# because a dead weakref never compares equal to a live weakref,
# even if they happened to refer to equal objects.
# However, it means keys may already have been removed.
l = self._pending_removals
d = self.data
while l:
try:
del d[l.pop()]
except KeyError:
pass
def __delitem__(self, key): def __delitem__(self, key):
del self.data[ref(key)] del self.data[ref(key)]
@ -284,34 +328,26 @@ class WeakKeyDictionary(collections.MutableMapping):
return wr in self.data return wr in self.data
def items(self): def items(self):
with _IterationGuard(self):
for wr, value in self.data.items(): for wr, value in self.data.items():
key = wr() key = wr()
if key is not None: if key is not None:
yield key, value yield key, value
def keyrefs(self):
"""Return an iterator that yields the weak references to the keys.
The references are not guaranteed to be 'live' at the time
they are used, so the result of calling the references needs
to be checked before being used. This can be used to avoid
creating references that will cause the garbage collector to
keep the keys around longer than needed.
"""
return self.data.keys()
def keys(self): def keys(self):
for wr in self.data.keys(): with _IterationGuard(self):
for wr in self.data:
obj = wr() obj = wr()
if obj is not None: if obj is not None:
yield obj yield obj
def __iter__(self): __iter__ = keys
return iter(self.keys())
def values(self): def values(self):
return iter(self.data.values()) with _IterationGuard(self):
for wr, value in self.data.items():
if wr() is not None:
yield value
def keyrefs(self): def keyrefs(self):
"""Return a list of weak references to the keys. """Return a list of weak references to the keys.
@ -323,7 +359,7 @@ class WeakKeyDictionary(collections.MutableMapping):
keep the keys around longer than needed. keep the keys around longer than needed.
""" """
return self.data.keys() return list(self.data)
def popitem(self): def popitem(self):
while 1: while 1:

View File

@ -194,6 +194,9 @@ C-API
Library Library
------- -------
- Issue #7105: Make WeakKeyDictionary and WeakValueDictionary robust against
the destruction of weakref'ed objects while iterating.
- Issue #7455: Fix possible crash in cPickle on invalid input. Patch by - Issue #7455: Fix possible crash in cPickle on invalid input. Patch by
Victor Stinner. Victor Stinner.