diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst --- a/Doc/library/collections.rst +++ b/Doc/library/collections.rst @@ -685,6 +685,10 @@ stack manipulations such as ``dup``, ``drop``, ``swap``, ``over``, ``pick``, initialized from the first argument to the constructor, if present, or to ``None``, if absent. + .. versionchanged:: 3.6 + :class:`defaultdict` is now available in pure Python for alternate + implementations. + :class:`defaultdict` Examples ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -14,8 +14,8 @@ list, set, and tuple. ''' -__all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList', - 'UserString', 'Counter', 'OrderedDict', 'ChainMap'] +__all__ = ['defaultdict', 'namedtuple', 'UserDict', 'UserList', + 'UserString', 'Counter', 'OrderedDict', 'ChainMap'] # For backwards compatibility, continue to make the collections ABCs # available through the collections module. @@ -37,6 +37,68 @@ except ImportError: pass else: MutableSequence.register(deque) + # we insert deque at the front of __all__ to keep it somewhat categorized + __all__.insert(0, 'deque') + +################################################################################ +### defaultdict +################################################################################ + +class defaultdict(dict): + """defaultdict(default_factory[, ...]) --> dict with default factory + + The default factory is called without arguments to produce + a new value when a key is not present, in __getitem__ only. + A defaultdict compares equal to a dict with the same items. + All remaining arguments are treated the same as if they were + passed to the dict constructor, including keyword arguments. + + """ + + def __init__(*args, **kwds): + if not args: + raise TypeError("descriptor '__init__' of 'collections." + "defaultdict' needs an argument") + self, *args = args + if args: + factory, *args = args + if factory is not None and not callable(factory): + raise TypeError("first argument must be callable or None") + self.default_factory = factory + else: + self.default_factory = None + + dict.__init__(self, *args, **kwds) + + def __missing__(self, key): + """Called by __getitem__ for missing key.""" + if self.default_factory is None: + raise KeyError(key) + self[key] = value = self.default_factory() + return value + + def copy(self): + """D.copy() -> a shallow copy of D.""" + return type(self)(self.default_factory, self) + + __copy__ = copy + + def __reduce__(self): + # __reduce__ must return a 5-tuple as follows: + # - factory function (constructor) + # - tuple of args for the factory function + # - additional state (here None) + # - sequence iterator (here None) + # - dictionary iterator (yielding successive (key, value) pairs) + return type(self), (self.default_factory,), None, None, iter(self.items()) + + def __repr__(self): + try: + factory_repr = repr(self.default_factory) + except RecursionError: + factory_repr = "..." + dict_repr = dict.__repr__(self) + return f"defaultdict({factory_repr}, {dict_repr})" try: from _collections import defaultdict diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py --- a/Lib/test/test_defaultdict.py +++ b/Lib/test/test_defaultdict.py @@ -1,20 +1,35 @@ """Unit tests for collections.defaultdict.""" import os +import sys import copy import pickle import tempfile import unittest +import contextlib -from collections import defaultdict +from test.support import import_fresh_module + +c_collections = import_fresh_module("collections", fresh=["_collections"]) +py_collections = import_fresh_module("collections", blocked=["_collections"]) + +# Taken from test_ordered_dict.py (Should this be in test.support?) +@contextlib.contextmanager +def replaced_module(name, replacement): + original_module = sys.modules[name] + sys.modules[name] = replacement + try: + yield + finally: + sys.modules[name] = original_module def foobar(): return list -class TestDefaultDict(unittest.TestCase): +class DefaultDictTests: def test_basic(self): - d1 = defaultdict() + d1 = self.defaultdict() self.assertEqual(d1.default_factory, None) d1.default_factory = list d1[12].append(42) @@ -25,7 +40,7 @@ class TestDefaultDict(unittest.TestCase): d1[14] self.assertEqual(d1, {12: [42, 24], 13: [], 14: []}) self.assertTrue(d1[12] is not d1[13] is not d1[14]) - d2 = defaultdict(list, foo=1, bar=2) + d2 = self.defaultdict(list, foo=1, bar=2) self.assertEqual(d2.default_factory, list) self.assertEqual(d2, {"foo": 1, "bar": 2}) self.assertEqual(d2["foo"], 1) @@ -47,35 +62,35 @@ class TestDefaultDict(unittest.TestCase): self.assertEqual(err.args, (15,)) else: self.fail("d2[15] didn't raise KeyError") - self.assertRaises(TypeError, defaultdict, 1) + self.assertRaises(TypeError, self.defaultdict, 1) def test_missing(self): - d1 = defaultdict() + d1 = self.defaultdict() self.assertRaises(KeyError, d1.__missing__, 42) d1.default_factory = list self.assertEqual(d1.__missing__(42), []) def test_repr(self): - d1 = defaultdict() + d1 = self.defaultdict() self.assertEqual(d1.default_factory, None) self.assertEqual(repr(d1), "defaultdict(None, {})") - self.assertEqual(eval(repr(d1)), d1) + self.assertEqual(eval("self." + repr(d1)), d1) d1[11] = 41 self.assertEqual(repr(d1), "defaultdict(None, {11: 41})") - d2 = defaultdict(int) + d2 = self.defaultdict(int) self.assertEqual(d2.default_factory, int) d2[12] = 42 self.assertEqual(repr(d2), "defaultdict(, {12: 42})") def foo(): return 43 - d3 = defaultdict(foo) + d3 = self.defaultdict(foo) self.assertTrue(d3.default_factory is foo) d3[13] self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo)) def test_print(self): - d1 = defaultdict() + d1 = self.defaultdict() def foo(): return 42 - d2 = defaultdict(foo, {1: 2}) + d2 = self.defaultdict(foo, {1: 2}) # NOTE: We can't use tempfile.[Named]TemporaryFile since this # code must exercise the tp_print C code, which only gets # invoked for *real* files. @@ -94,32 +109,32 @@ class TestDefaultDict(unittest.TestCase): os.remove(tfn) def test_copy(self): - d1 = defaultdict() + d1 = self.defaultdict() d2 = d1.copy() - self.assertEqual(type(d2), defaultdict) + self.assertEqual(type(d2), self.defaultdict) self.assertEqual(d2.default_factory, None) self.assertEqual(d2, {}) d1.default_factory = list d3 = d1.copy() - self.assertEqual(type(d3), defaultdict) + self.assertEqual(type(d3), self.defaultdict) self.assertEqual(d3.default_factory, list) self.assertEqual(d3, {}) d1[42] d4 = d1.copy() - self.assertEqual(type(d4), defaultdict) + self.assertEqual(type(d4), self.defaultdict) self.assertEqual(d4.default_factory, list) self.assertEqual(d4, {42: []}) d4[12] self.assertEqual(d4, {42: [], 12: []}) # Issue 6637: Copy fails for empty default dict - d = defaultdict() + d = self.defaultdict() d['a'] = 42 e = d.copy() self.assertEqual(e['a'], 42) def test_shallow_copy(self): - d1 = defaultdict(foobar, {1: 1}) + d1 = self.defaultdict(foobar, {1: 1}) d2 = copy.copy(d1) self.assertEqual(d2.default_factory, foobar) self.assertEqual(d2, d1) @@ -129,7 +144,7 @@ class TestDefaultDict(unittest.TestCase): self.assertEqual(d2, d1) def test_deep_copy(self): - d1 = defaultdict(foobar, {1: [1]}) + d1 = self.defaultdict(foobar, {1: [1]}) d2 = copy.deepcopy(d1) self.assertEqual(d2.default_factory, foobar) self.assertEqual(d2, d1) @@ -140,7 +155,7 @@ class TestDefaultDict(unittest.TestCase): self.assertEqual(d2, d1) def test_keyerror_without_factory(self): - d1 = defaultdict() + d1 = self.defaultdict() try: d1[(1,)] except KeyError as err: @@ -150,7 +165,7 @@ class TestDefaultDict(unittest.TestCase): def test_recursive_repr(self): # Issue2045: stack overflow when default_factory is a bound method - class sub(defaultdict): + class sub(self.defaultdict): def __init__(self): self.default_factory = self._factory def _factory(self): @@ -173,15 +188,29 @@ class TestDefaultDict(unittest.TestCase): os.remove(tfn) def test_callable_arg(self): - self.assertRaises(TypeError, defaultdict, {}) + self.assertRaises(TypeError, self.defaultdict, {}) - def test_pickleing(self): - d = defaultdict(int) + def _test_pickleing(self): + d = self.defaultdict(int) d[1] for proto in range(pickle.HIGHEST_PROTOCOL + 1): s = pickle.dumps(d, proto) o = pickle.loads(s) self.assertEqual(d, o) +class TestCDefaultDict(DefaultDictTests, unittest.TestCase): + defaultdict = c_collections.defaultdict + + def test_pickleing(self): + self._test_pickleing() + +class TestPyDefaultDict(DefaultDictTests, unittest.TestCase): + defaultdict = py_collections.defaultdict + + def test_pickleing(self): + # pickle directly pulls the module, so we have to fake it + with replaced_module("collections", py_collections): + self._test_pickleing() + if __name__ == "__main__": unittest.main()