Index: Lib/weakref.py =================================================================== --- Lib/weakref.py (revision 61050) +++ Lib/weakref.py (working copy) @@ -29,7 +29,6 @@ "WeakKeyDictionary", "ReferenceType", "ProxyType", "CallableProxyType", "ProxyTypes", "WeakValueDictionary"] - class WeakValueDictionary(UserDict.UserDict): """Mapping class that references values weakly. @@ -43,13 +42,31 @@ # way in). def __init__(self, *args, **kw): - def remove(wr, selfref=ref(self)): + def iter_is_dead(wr, selfref=ref(self)): self = selfref() - if self is not None: + if self is None: + return + self._iterators_wr.remove(wr) + if not self._iterators_wr: + for key in self._removal_list: + del self.data[key] + self._removal_list = [] + + def value_is_dead(wr, selfref=ref(self)): + self = selfref() + if self is None: + return + if self._iterators_wr: + self._removal_list.append(wr.key) + else: del self.data[wr.key] - self._remove = remove + + self._iter_is_dead = iter_is_dead + self._value_is_dead = value_is_dead + self._iterators_wr = set() + self._removal_list = [] UserDict.UserDict.__init__(self, *args, **kw) - + def __getitem__(self, key): o = self.data[key]() if o is None: @@ -75,7 +92,7 @@ return "" % id(self) def __setitem__(self, key, value): - self.data[key] = KeyedRef(value, self._remove, key) + self.data[key] = KeyedRef(value, self._value_is_dead, key) def copy(self): new = WeakValueDictionary() @@ -99,12 +116,7 @@ return o def items(self): - L = [] - for key, wr in self.data.items(): - o = wr() - if o is not None: - L.append((key, o)) - return L + return list(self.iteritems()) def iteritems(self): for wr in self.data.itervalues(): @@ -113,10 +125,12 @@ yield wr.key, value def iterkeys(self): - return self.data.iterkeys() + result = (key for key, wr in self.data.iteritems() if wr() is not None) + self._iterators_wr.add(ref(result, self._iter_is_dead)) + return result def __iter__(self): - return self.data.iterkeys() + return self.iterkeys() def itervaluerefs(self): """Return an iterator that yields the weak references to the values. @@ -159,18 +173,17 @@ try: wr = self.data[key] except KeyError: - self.data[key] = KeyedRef(default, self._remove, key) + self[key] = default return default else: return wr() def update(self, dict=None, **kwargs): - d = self.data if dict is not None: if not hasattr(dict, "items"): dict = type({})(dict) for key, o in dict.items(): - d[key] = KeyedRef(o, self._remove, key) + self[key] = o if len(kwargs): self.update(kwargs) @@ -187,14 +200,9 @@ return self.data.values() def values(self): - L = [] - for wr in self.data.values(): - o = wr() - if o is not None: - L.append(o) - return L + return list(self.itervalues()) + - class KeyedRef(ref): """Specialized reference that includes a key corresponding to the value. @@ -227,15 +235,32 @@ accesses. """ - def __init__(self, dict=None): - self.data = {} - def remove(k, selfref=ref(self)): + def __init__(self, *args, **kw): + def iter_is_dead(wr, selfref=ref(self)): self = selfref() - if self is not None: - del self.data[k] - self._remove = remove - if dict is not None: self.update(dict) - + if self is None: + return + self._iterators_wr.remove(wr) + if not self._iterators_wr: + for key in self._removal_list: + del self.data[key] + self._removal_list = [] + + def key_is_dead(wr, selfref=ref(self)): + self = selfref() + if self is None: + return + if self._iterators_wr: + self._removal_list.append(wr) + else: + del self.data[wr] + + self._iter_is_dead = iter_is_dead + self._key_is_dead = key_is_dead + self._iterators_wr = set() + self._removal_list = [] + UserDict.UserDict.__init__(self, *args, **kw) + def __delitem__(self, key): del self.data[ref(key)] @@ -246,7 +271,7 @@ return "" % id(self) def __setitem__(self, key, value): - self.data[ref(key, self._remove)] = value + self.data[ref(key, self._key_is_dead)] = value def copy(self): new = WeakKeyDictionary() @@ -274,12 +299,7 @@ return wr in self.data def items(self): - L = [] - for key, value in self.data.items(): - o = key() - if o is not None: - L.append((o, value)) - return L + return list(self.iteritems()) def iteritems(self): for wr, value in self.data.iteritems(): @@ -301,15 +321,17 @@ def iterkeys(self): for wr in self.data.iterkeys(): - obj = wr() - if obj is not None: - yield obj + key = wr() + if key is not None: + yield key def __iter__(self): return self.iterkeys() def itervalues(self): - return self.data.itervalues() + result = (val for key, val in self.iteritems()) + self._iterators_wr.add(ref(result, self._iter_is_dead)) + return result def keyrefs(self): """Return a list of weak references to the keys. @@ -324,12 +346,7 @@ return self.data.keys() def keys(self): - L = [] - for wr in self.data.keys(): - o = wr() - if o is not None: - L.append(o) - return L + return list(self.iterkeys()) def popitem(self): while 1: @@ -342,7 +359,7 @@ return self.data.pop(ref(key), *args) def setdefault(self, key, default=None): - return self.data.setdefault(ref(key, self._remove),default) + return self.data.setdefault(ref(key, self._key_is_dead),default) def update(self, dict=None, **kwargs): d = self.data @@ -350,6 +367,6 @@ if not hasattr(dict, "items"): dict = type({})(dict) for key, value in dict.items(): - d[ref(key, self._remove)] = value + d[ref(key, self._key_is_dead)] = value if len(kwargs): self.update(kwargs) Index: Lib/test/test_weakref.py =================================================================== --- Lib/test/test_weakref.py (revision 61050) +++ Lib/test/test_weakref.py (working copy) @@ -1042,6 +1042,66 @@ del d[o] self.assertEqual(len(d), 0) self.assertEqual(count, 2) + + def test_weak_valued_dict_gc_collect_during_iterkeys(self): + # if a value goes out during iteration, just don't iterate through it + # and finish iteration + self.COUNT = 2 + d, objects = self.make_weak_valued_dict() + it = d.iterkeys() + del objects[0] + key = it.next() + self.assertEqual(key, 1) + + def test_weak_valued_dict_flushed_dead_items_when_iters_go_out(self): + # remove dead values from the dict when iterators go out + d, objects = self.make_weak_valued_dict() + it1 = d.iterkeys() + it2 = d.iterkeys() + del objects[0] + self.assertEqual(len(d), self.COUNT) + del it1 + self.assertEqual(len(d), self.COUNT) + del it2 + self.assertEqual(len(d), self.COUNT - 1) + + def test_weak_keyed_dict_gc_collect_during_itervalues(self): + # if a key goes out during iteration, just don't iterate through it + # and finish iteration + self.COUNT = 2 + d, objects = self.make_weak_keyed_dict() + it = d.itervalues() + del objects[0] + value = it.next() + self.assertEqual(value, 1) + + def test_weak_keyed_dict_flushed_dead_items_when_iters_go_out(self): + # remove dead values from the dict when iterators go out + d, objects = self.make_weak_keyed_dict() + it1 = d.itervalues() + it2 = d.itervalues() + del objects[0] + self.assertEqual(len(d), self.COUNT) + del it1 + self.assertEqual(len(d), self.COUNT) + del it2 + self.assertEqual(len(d), self.COUNT - 1) + + def test_weak_keyed_dict_are_not_held_by_their_keys(self): + d = weakref.WeakKeyDictionary() + o = Object(None) + d[o] = 'foo' + wr = weakref.ref(d) + del(d) + self.assert_(wr() is None) + + def test_weak_valued_dict_are_not_held_by_their_values(self): + d = weakref.WeakValueDictionary() + o = Object(None) + d['foo'] = o + wr = weakref.ref(d) + del(d) + self.assert_(wr() is None) from test import mapping_tests