Index: Objects/dictobject.c =================================================================== --- Objects/dictobject.c (revision 42537) +++ Objects/dictobject.c (working copy) @@ -8,6 +8,7 @@ */ #include "Python.h" +#include "ceval.h" typedef PyDictEntry dictentry; typedef PyDictObject dictobject; @@ -882,8 +883,15 @@ return NULL; } v = (mp->ma_lookup)(mp, key, hash) -> me_value; - if (v == NULL) - PyErr_SetObject(PyExc_KeyError, key); + if (v == NULL) { + if (PyDict_CheckExact(mp)) { + /* Avoid infinite recursion via interning. */ + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + return PyEval_CallMethod((PyObject *)mp, + "on_missing", "(O)", key); + } else Py_INCREF(v); return v; @@ -1756,6 +1764,12 @@ return dictiter_new(dict, &PyDictIterItem_Type); } +static PyObject * +dict_on_missing(dictobject *dict, PyObject *key) +{ + PyErr_SetObject(PyExc_KeyError, key); + return NULL; +} PyDoc_STRVAR(has_key__doc__, "D.has_key(k) -> True if D has a key k, else False"); @@ -1811,6 +1825,10 @@ PyDoc_STRVAR(iteritems__doc__, "D.iteritems() -> an iterator over the (key, value) items of D"); +PyDoc_STRVAR(on_missing__doc__, +"D.on_missing(key) raises KeyError(key)\n\ +This is a hook called by __getitem__() when the key is not found."); + static PyMethodDef mapp_methods[] = { {"__contains__",(PyCFunction)dict_has_key, METH_O | METH_COEXIST, contains__doc__}, @@ -1846,6 +1864,8 @@ itervalues__doc__}, {"iteritems", (PyCFunction)dict_iteritems, METH_NOARGS, iteritems__doc__}, + {"on_missing", (PyCFunction)dict_on_missing, METH_O, + on_missing__doc__}, {NULL, NULL} /* sentinel */ }; Index: Lib/UserDict.py =================================================================== --- Lib/UserDict.py (revision 42537) +++ Lib/UserDict.py (working copy) @@ -14,7 +14,12 @@ else: return cmp(self.data, dict) def __len__(self): return len(self.data) - def __getitem__(self, key): return self.data[key] + def __getitem__(self, key): + if key in self.data: + return self.data[key] + 
return self.on_missing(key) + def on_missing(self, key): + raise KeyError(key) def __setitem__(self, key, item): self.data[key] = item def __delitem__(self, key): del self.data[key] def clear(self): self.data.clear() Index: Lib/test/test_dict.py =================================================================== --- Lib/test/test_dict.py (revision 42537) +++ Lib/test/test_dict.py (working copy) @@ -395,6 +395,46 @@ else: self.fail("< didn't raise Exc") + def test_on_missing(self): + # on_missing() must exist, and raise KeyError() + self.assert_(hasattr(dict, "on_missing")) + d = {} + self.assert_(hasattr(d, "on_missing")) + self.assertRaises(KeyError, d.on_missing, 42) + self.assertRaises(KeyError, dict.on_missing, d, 42) + + def test_on_missing_subclass(self): + class D(dict): + def on_missing(self, key): + return 42 + d = D({1: 2, 3: 4}) + self.assertEqual(d[1], 2) + self.assertEqual(d[3], 4) + self.assert_(2 not in d) + self.assert_(2 not in d.keys()) + self.assertEqual(d[2], 42) + class E(dict): + def on_missing(self, key): + raise RuntimeError(key) + e = E() + try: + e[42] + except RuntimeError, err: + self.assertEqual(err.args, (42,)) + else: + self.fail("e[42] didn't raise RuntimeError") + class F(dict): + def on_missing(self, key): + return super(F, self).on_missing(key) + f = F() + try: + f[42] + except KeyError, err: + self.assertEqual(err.args, (42,)) + else: + self.fail("f[42] didn't raise KeyError") + + import mapping_tests class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): Index: Lib/test/test_defaultdict.py =================================================================== --- Lib/test/test_defaultdict.py (revision 0) +++ Lib/test/test_defaultdict.py (revision 0) @@ -0,0 +1,90 @@ +"""Unit tests for collections.defaultdict.""" + +import os +import tempfile +import unittest + +from collections import defaultdict + +class TestDefaultDict(unittest.TestCase): + + def test_basic(self): + d1 = defaultdict() + 
self.assertEqual(d1.default_factory, None) + d1.default_factory = list + d1[12].append(42) + self.assertEqual(d1, {12: [42]}) + d1[12].append(24) + self.assertEqual(d1, {12: [42, 24]}) + d1[13] + d1[14] + self.assertEqual(d1, {12: [42, 24], 13: [], 14: []}) + self.assert_(d1[12] is not d1[13] is not d1[14]) + d2 = defaultdict(list, foo=1, bar=2) + self.assertEqual(d2.default_factory, list) + self.assertEqual(d2, {"foo": 1, "bar": 2}) + self.assertEqual(d2["foo"], 1) + self.assertEqual(d2["bar"], 2) + self.assertEqual(d2[42], []) + self.assert_("foo" in d2) + self.assert_("foo" in d2.keys()) + self.assert_("bar" in d2) + self.assert_("bar" in d2.keys()) + self.assert_(42 in d2) + self.assert_(42 in d2.keys()) + self.assert_(12 not in d2) + self.assert_(12 not in d2.keys()) + d2.default_factory = None + self.assertEqual(d2.default_factory, None) + try: + d2[15] + except KeyError, err: + self.assertEqual(err.args, (15,)) + else: + self.fail("d2[15] didn't raise KeyError") + + def test_on_missing(self): + d1 = defaultdict() + self.assertRaises(KeyError, d1.on_missing, 42) + d1.default_factory = list + self.assertEqual(d1.on_missing(42), []) + + def test_repr(self): + d1 = defaultdict() + self.assertEqual(d1.default_factory, None) + self.assertEqual(repr(d1), "defaultdict(None, {})") + d1[11] = 41 + self.assertEqual(repr(d1), "defaultdict(None, {11: 41})") + d2 = defaultdict(0) + self.assertEqual(d2.default_factory, 0) + d2[12] = 42 + self.assertEqual(repr(d2), "defaultdict(0, {12: 42})") + def foo(): return 43 + d3 = defaultdict(foo) + self.assert_(d3.default_factory is foo) + d3[13] + self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo)) + + def test_print(self): + d1 = defaultdict() + def foo(): return 42 + d2 = defaultdict(foo, {1: 2}) + # NOTE: We can't use tempfile.[Named]TemporaryFile since this + # code must exercise the tp_print C code, which only gets + # invoked for *real* files. 
+ tfn = tempfile.mktemp() + try: + f = open(tfn, "w+") + try: + print >>f, d1 + print >>f, d2 + f.seek(0) + self.assertEqual(f.readline(), repr(d1) + "\n") + self.assertEqual(f.readline(), repr(d2) + "\n") + finally: + f.close() + finally: + os.remove(tfn) + +if __name__ == "__main__": + unittest.main() Property changes on: Lib/test/test_defaultdict.py ___________________________________________________________________ Name: svn:keywords + Id Name: svn:eol-style + native Index: Modules/collectionsmodule.c =================================================================== --- Modules/collectionsmodule.c (revision 42537) +++ Modules/collectionsmodule.c (working copy) @@ -1065,10 +1065,204 @@ 0, }; +/* defaultdict type *********************************************************/ + +typedef struct { + PyDictObject dict; + PyObject *default_factory; +} defdictobject; + +PyDoc_STRVAR(on_missing_doc, +"on_missing(key)\n\ + # Pseudo-code:\n\ + if self.default_factory is None: return super.on_missing(key)\n\ + return self[key] = self.default_factory()\n\ +"); + +static PyObject * +defdict_on_missing(defdictobject *dd, PyObject *key) +{ + PyObject *factory = dd->default_factory; + PyObject *value; + if (factory == NULL || factory == Py_None) { + /* XXX Call dict.on_missing(key) */ + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + value = PyEval_CallObject(factory, NULL); + if (value == NULL) + return value; + if (PyObject_SetItem((PyObject *)dd, key, value) < 0) { + Py_DECREF(value); + return NULL; + } + return value; +} + +static PyMethodDef defdict_methods[] = { + {"on_missing", (PyCFunction)defdict_on_missing, METH_O, + on_missing_doc}, + {NULL} +}; + +static PyMemberDef defdict_members[] = { + {"default_factory", T_OBJECT, + offsetof(defdictobject, default_factory), 0, + PyDoc_STR("Factory for default value called by on_missing().")}, + {NULL} +}; + +static void +defdict_dealloc(defdictobject *dd) +{ + if (dd->default_factory != NULL) { + 
Py_DECREF(dd->default_factory); + dd->default_factory = NULL; + } + PyDict_Type.tp_dealloc((PyObject *)dd); +} + +static int +defdict_print(defdictobject *dd, FILE *fp, int flags) +{ + int sts; + fprintf(fp, "defaultdict("); + if (dd->default_factory == NULL) + fprintf(fp, "None"); + else { + PyObject_Print(dd->default_factory, fp, 0); + } + fprintf(fp, ", "); + sts = PyDict_Type.tp_print((PyObject *)dd, fp, 0); + fprintf(fp, ")"); + return sts; +} + +static PyObject * +defdict_repr(defdictobject *dd) +{ + PyObject *defrepr; + PyObject *baserepr; + PyObject *result; + baserepr = PyDict_Type.tp_repr((PyObject *)dd); + if (baserepr == NULL) + return NULL; + if (dd->default_factory == NULL) + defrepr = PyString_FromString("None"); + else + defrepr = PyObject_Repr(dd->default_factory); + if (defrepr == NULL) { + Py_DECREF(baserepr); + return NULL; + } + result = PyString_FromFormat("defaultdict(%s, %s)", + PyString_AS_STRING(defrepr), + PyString_AS_STRING(baserepr)); + Py_DECREF(defrepr); + Py_DECREF(baserepr); + return result; +} + +static int +defdict_traverse(PyObject *self, visitproc visit, void *arg) +{ + Py_VISIT(((defdictobject *)self)->default_factory); + return PyDict_Type.tp_traverse(self, visit, arg); +} + +static int +defdict_tp_clear(defdictobject *dd) +{ + if (dd->default_factory != NULL) { + Py_DECREF(dd->default_factory); + dd->default_factory = NULL; + } + return PyDict_Type.tp_clear((PyObject *)dd); +} + +static int +defdict_init(PyObject *self, PyObject *args, PyObject *kwds) +{ + defdictobject *dd = (defdictobject *)self; + PyObject *olddefault = dd->default_factory; + PyObject *newdefault = NULL; + PyObject *newargs; + int result; + if (args == NULL || !PyTuple_Check(args)) + newargs = PyTuple_New(0); + else { + Py_ssize_t n = PyTuple_GET_SIZE(args); + if (n > 0) + newdefault = PyTuple_GET_ITEM(args, 0); + newargs = PySequence_GetSlice(args, 1, n); + } + if (newargs == NULL) + return -1; + Py_XINCREF(newdefault); + dd->default_factory = newdefault; 
+ result = PyDict_Type.tp_init(self, newargs, kwds); + Py_DECREF(newargs); + Py_XDECREF(olddefault); + return result; +} + +PyDoc_STRVAR(defdict_doc, +"defaultdict(default_factory) --> dict with default factory\n\ +\n\ +The default factory is called without arguments to produce\n\ +a new value when a key is not present, in __getitem__ only."); + +static PyTypeObject defdict_type = { + PyObject_HEAD_INIT(NULL) + 0, /* ob_size */ + "collections.defaultdict", /* tp_name */ + sizeof(defdictobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)defdict_dealloc, /* tp_dealloc */ + (printfunc)defdict_print, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + (reprfunc)defdict_repr, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | + Py_TPFLAGS_HAVE_WEAKREFS, /* tp_flags */ + defdict_doc, /* tp_doc */ + (traverseproc)defdict_traverse, /* tp_traverse */ + (inquiry)defdict_tp_clear, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset*/ + 0, /* tp_iter */ + 0, /* tp_iternext */ + defdict_methods, /* tp_methods */ + defdict_members, /* tp_members */ + 0, /* tp_getset */ + &PyDict_Type, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)defdict_init, /* tp_init */ + PyType_GenericAlloc, /* tp_alloc */ + 0, /* tp_new */ + PyObject_GC_Del, /* tp_free */ +}; + /* module level code ********************************************************/ PyDoc_STRVAR(module_doc, "High performance data structures\n\ +- deque: ordered collection accessible from endpoints only\n\ +- defaultdict: dict subclass with a default value factory attribute\n\ "); PyMODINIT_FUNC @@ -1085,6 +1279,11 @@ Py_INCREF(&deque_type); 
PyModule_AddObject(m, "deque", (PyObject *)&deque_type); + if (PyType_Ready(&defdict_type) < 0) + return; + Py_INCREF(&defdict_type); + PyModule_AddObject(m, "defaultdict", (PyObject *)&defdict_type); + if (PyType_Ready(&dequeiter_type) < 0) return;