diff --git a/Doc/library/weakref.rst b/Doc/library/weakref.rst --- a/Doc/library/weakref.rst +++ b/Doc/library/weakref.rst @@ -163,7 +163,7 @@ .. method:: WeakKeyDictionary.iterkeyrefs() - Return an :term:`iterator` that yields the weak references to the keys. + Return an iterable of the weak references to the keys. .. versionadded:: 2.5 @@ -195,7 +195,7 @@ .. method:: WeakValueDictionary.itervaluerefs() - Return an :term:`iterator` that yields the weak references to the values. + Return an iterable of the weak references to the values. .. versionadded:: 2.5 diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -4,6 +4,8 @@ import UserList import weakref import operator +import contextlib +import copy from test import test_support @@ -906,7 +908,7 @@ def check_len_cycles(self, dict_type, cons): N = 20 items = [RefCycle() for i in range(N)] - dct = dict_type(cons(o) for o in items) + dct = dict_type(cons(i, o) for i, o in enumerate(items)) # Keep an iterator alive it = dct.iteritems() try: @@ -916,18 +918,19 @@ del items gc.collect() n1 = len(dct) + list(it) del it gc.collect() n2 = len(dct) - # one item may be kept alive inside the iterator - self.assertIn(n1, (0, 1)) + # iteration should prevent garbage collection here + self.assertEqual(n1, N) self.assertEqual(n2, 0) def test_weak_keyed_len_cycles(self): - self.check_len_cycles(weakref.WeakKeyDictionary, lambda k: (k, 1)) + self.check_len_cycles(weakref.WeakKeyDictionary, lambda n, k: (k, n)) def test_weak_valued_len_cycles(self): - self.check_len_cycles(weakref.WeakValueDictionary, lambda k: (1, k)) + self.check_len_cycles(weakref.WeakValueDictionary, lambda n, k: (n, k)) def check_len_race(self, dict_type, cons): # Extended sanity checks for len() in the face of cyclic collection @@ -1093,6 +1096,86 @@ self.assertTrue(len(values) == 0, "itervalues() did not touch all values") + def 
check_weak_destroy_while_iterating(self, dict, objects, iter_name): + n = len(dict) + it = iter(getattr(dict, iter_name)()) + next(it) # Trigger internal iteration + # Destroy an object + del objects[-1] + gc.collect() # just in case + # We have removed either the first consumed object, or another one + self.assertIn(len(list(it)), [len(objects), len(objects) - 1]) + del it + # The removal has been committed + self.assertEqual(len(dict), n - 1) + + def check_weak_destroy_and_mutate_while_iterating(self, dict, testcontext): + # Check that we can explicitly mutate the weak dict without + # interfering with delayed removal. + # `testcontext` should create an iterator, destroy one of the + # weakref'ed objects and then return a new key/value pair corresponding + # to the destroyed object. + with testcontext() as (k, v): + self.assertFalse(k in dict) + with testcontext() as (k, v): + self.assertRaises(KeyError, dict.__delitem__, k) + self.assertFalse(k in dict) + with testcontext() as (k, v): + self.assertRaises(KeyError, dict.pop, k) + self.assertFalse(k in dict) + with testcontext() as (k, v): + dict[k] = v + self.assertEqual(dict[k], v) + ddict = copy.copy(dict) + with testcontext() as (k, v): + dict.update(ddict) + self.assertEqual(dict, ddict) + with testcontext() as (k, v): + dict.clear() + self.assertEqual(len(dict), 0) + + def test_weak_keys_destroy_while_iterating(self): + # Issue #7105: iterators shouldn't crash when a key is implicitly removed + dict, objects = self.make_weak_keyed_dict() + self.check_weak_destroy_while_iterating(dict, objects, 'iterkeys') + self.check_weak_destroy_while_iterating(dict, objects, 'iteritems') + self.check_weak_destroy_while_iterating(dict, objects, 'itervalues') + self.check_weak_destroy_while_iterating(dict, objects, 'iterkeyrefs') + dict, objects = self.make_weak_keyed_dict() + @contextlib.contextmanager + def testcontext(): + try: + it = iter(dict.iteritems()) + next(it) + # Schedule a key/value for removal and recreate it 
+ v = objects.pop().arg + gc.collect() # just in case + yield Object(v), v + finally: + it = None # should commit all removals + self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext) + + def test_weak_values_destroy_while_iterating(self): + # Issue #7105: iterators shouldn't crash when a value is implicitly removed + dict, objects = self.make_weak_valued_dict() + self.check_weak_destroy_while_iterating(dict, objects, 'iterkeys') + self.check_weak_destroy_while_iterating(dict, objects, 'iteritems') + self.check_weak_destroy_while_iterating(dict, objects, 'itervalues') + self.check_weak_destroy_while_iterating(dict, objects, 'itervaluerefs') + dict, objects = self.make_weak_valued_dict() + @contextlib.contextmanager + def testcontext(): + try: + it = iter(dict.iteritems()) + next(it) + # Schedule a key/value for removal and recreate it + k = objects.pop().arg + gc.collect() # just in case + yield k, Object(k) + finally: + it = None # should commit all removals + self.check_weak_destroy_and_mutate_while_iterating(dict, testcontext) + def test_make_weak_keyed_dict_from_dict(self): o = Object(3) dict = weakref.WeakKeyDictionary({o:364}) diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py --- a/Lib/test/test_weakset.py +++ b/Lib/test/test_weakset.py @@ -11,6 +11,7 @@ import collections import gc import contextlib +from UserString import UserString as ustr class Foo: @@ -448,6 +449,54 @@ self.assertGreaterEqual(n2, 0) self.assertLessEqual(n2, n1) + def test_weak_destroy_while_iterating(self): + # Issue #7105: iterators shouldn't crash when an item is implicitly removed + # Create new items to be sure no-one else holds a reference + items = [ustr(c) for c in ('a', 'b', 'c')] + s = WeakSet(items) + it = iter(s) + next(it) # Trigger internal iteration + # Destroy an item + del items[-1] + gc.collect() # just in case + # We have removed either the first consumed item, or another one + self.assertIn(len(list(it)), [len(items), len(items) - 1]) + del 
it + # The removal has been committed + self.assertEqual(len(s), len(items)) + + def test_weak_destroy_and_mutate_while_iterating(self): + # Issue #7105: iterators shouldn't crash when an item is implicitly removed + items = [ustr(c) for c in string.ascii_letters] + s = WeakSet(items) + @contextlib.contextmanager + def testcontext(): + try: + it = iter(s) + next(it) + # Schedule an item for removal and recreate it + u = ustr(str(items.pop())) + gc.collect() # just in case + yield u + finally: + it = None # should commit all removals + + with testcontext() as u: + self.assertFalse(u in s) + with testcontext() as u: + self.assertRaises(KeyError, s.remove, u) + self.assertFalse(u in s) + with testcontext() as u: + s.add(u) + self.assertTrue(u in s) + t = s.copy() + with testcontext() as u: + s.update(t) + self.assertEqual(len(s), len(t)) + with testcontext() as u: + s.clear() + self.assertEqual(len(s), 0) + def test_main(verbose=None): test_support.run_unittest(TestWeakSet) diff --git a/Lib/weakref.py b/Lib/weakref.py --- a/Lib/weakref.py +++ b/Lib/weakref.py @@ -20,7 +20,7 @@ ProxyType, ReferenceType) -from _weakrefset import WeakSet +from _weakrefset import WeakSet, _IterationGuard from exceptions import ReferenceError @@ -48,10 +48,24 @@ def remove(wr, selfref=ref(self)): self = selfref() if self is not None: - del self.data[wr.key] + if self._iterating: + self._pending_removals.append(wr.key) + else: + del self.data[wr.key] self._remove = remove + # A list of keys to be removed + self._pending_removals = [] + self._iterating = set() UserDict.UserDict.__init__(self, *args, **kw) + def _commit_removals(self): + l = self._pending_removals + d = self.data + # We shouldn't encounter any KeyError, because this method should + # always be called *before* mutating the dict. 
+ while l: + del d[l.pop()] + def __getitem__(self, key): o = self.data[key]() if o is None: @@ -59,6 +73,11 @@ else: return o + def __delitem__(self, key): + if self._pending_removals: + self._commit_removals() + del self.data[key] + def __contains__(self, key): try: o = self.data[key]() @@ -77,8 +96,15 @@ return "" % id(self) def __setitem__(self, key, value): + if self._pending_removals: + self._commit_removals() self.data[key] = KeyedRef(value, self._remove, key) + def clear(self): + if self._pending_removals: + self._commit_removals() + self.data.clear() + def copy(self): new = WeakValueDictionary() for key, wr in self.data.items(): @@ -120,16 +146,18 @@ return L def iteritems(self): - for wr in self.data.itervalues(): - value = wr() - if value is not None: - yield wr.key, value + with _IterationGuard(self): + for wr in self.data.itervalues(): + value = wr() + if value is not None: + yield wr.key, value def iterkeys(self): - return self.data.iterkeys() + with _IterationGuard(self): + for k in self.data.iterkeys(): + yield k - def __iter__(self): - return self.data.iterkeys() + __iter__ = iterkeys def itervaluerefs(self): """Return an iterator that yields the weak references to the values. @@ -141,15 +169,20 @@ keep the values around longer than needed. 
""" - return self.data.itervalues() + with _IterationGuard(self): + for wr in self.data.itervalues(): + yield wr def itervalues(self): - for wr in self.data.itervalues(): - obj = wr() - if obj is not None: - yield obj + with _IterationGuard(self): + for wr in self.data.itervalues(): + obj = wr() + if obj is not None: + yield obj def popitem(self): + if self._pending_removals: + self._commit_removals() while 1: key, wr = self.data.popitem() o = wr() @@ -157,6 +190,8 @@ return key, o def pop(self, key, *args): + if self._pending_removals: + self._commit_removals() try: o = self.data.pop(key)() except KeyError: @@ -172,12 +207,16 @@ try: wr = self.data[key] except KeyError: + if self._pending_removals: + self._commit_removals() self.data[key] = KeyedRef(default, self._remove, key) return default else: return wr() def update(self, dict=None, **kwargs): + if self._pending_removals: + self._commit_removals() d = self.data if dict is not None: if not hasattr(dict, "items"): @@ -245,9 +284,29 @@ def remove(k, selfref=ref(self)): self = selfref() if self is not None: - del self.data[k] + if self._iterating: + self._pending_removals.append(k) + else: + del self.data[k] self._remove = remove - if dict is not None: self.update(dict) + # A list of dead weakrefs (keys to be removed) + self._pending_removals = [] + self._iterating = set() + if dict is not None: + self.update(dict) + + def _commit_removals(self): + # NOTE: We don't need to call this method before mutating the dict, + # because a dead weakref never compares equal to a live weakref, + # even if they happened to refer to equal objects. + # However, it means keys may already have been removed. 
+ l = self._pending_removals + d = self.data + while l: + try: + del d[l.pop()] + except KeyError: + pass def __delitem__(self, key): del self.data[ref(key)] @@ -306,10 +365,11 @@ return L def iteritems(self): - for wr, value in self.data.iteritems(): - key = wr() - if key is not None: - yield key, value + with _IterationGuard(self): + for wr, value in self.data.iteritems(): + key = wr() + if key is not None: + yield key, value def iterkeyrefs(self): """Return an iterator that yields the weak references to the keys. @@ -321,19 +381,23 @@ keep the keys around longer than needed. """ - return self.data.iterkeys() + with _IterationGuard(self): + for wr in self.data.iterkeys(): + yield wr def iterkeys(self): - for wr in self.data.iterkeys(): - obj = wr() - if obj is not None: - yield obj + with _IterationGuard(self): + for wr in self.data.iterkeys(): + obj = wr() + if obj is not None: + yield obj - def __iter__(self): - return self.iterkeys() + __iter__ = iterkeys def itervalues(self): - return self.data.itervalues() + with _IterationGuard(self): + for value in self.data.itervalues(): + yield value def keyrefs(self): """Return a list of weak references to the keys.