# HG changeset patch # Parent 40ce325744fcaca003d5e01aff82cd2f17970aa1 diff -r 40ce325744fc Lib/pickle.py --- a/Lib/pickle.py Mon Feb 27 20:07:10 2012 +0000 +++ b/Lib/pickle.py Thu Mar 01 14:14:39 2012 +0000 @@ -211,6 +211,18 @@ self.bin = protocol >= 1 self.fast = 0 self.fix_imports = fix_imports and protocol < 3 + try: + self._dispatch_table = self.dispatch_table + except AttributeError: + self._dispatch_table = dispatch_table + + @property + def dispatch_table(self): + return self._dispatch_table + + @dispatch_table.setter + def dispatch_table(self, value): + self._dispatch_table = value def clear_memo(self): """Clears the pickler's "memo". @@ -297,8 +309,8 @@ f(self, obj) # Call unbound method with explicit self return - # Check copyreg.dispatch_table - reduce = dispatch_table.get(t) + # Check self._dispatch_table + reduce = self._dispatch_table.get(t) if reduce: rv = reduce(obj) else: diff -r 40ce325744fc Lib/test/pickletester.py --- a/Lib/test/pickletester.py Mon Feb 27 20:07:10 2012 +0000 +++ b/Lib/test/pickletester.py Thu Mar 01 14:14:39 2012 +0000 @@ -1605,6 +1605,93 @@ self.assertEqual(unpickler.load(), data) +# Tests for dispatch_table attribute + +REDUCE_A = 'reduce_A' + +class AAA(object): + def __reduce__(self): + return str, (REDUCE_A,) + +class BBB(object): + pass + +class AbstractDispatchTableTests(unittest.TestCase): + + def test_class_dispatch_table(self): + dt = self.get_dispatch_table() + + class MyPickler(self.pickler_class): + dispatch_table = dt + + def dumps(obj, protocol=None): + f = io.BytesIO() + p = MyPickler(f, protocol) + p.dump(obj) + return f.getvalue() + + self._test_dispatch_table(dumps, dt) + + def test_instance_dispatch_table(self): + dt = self.get_dispatch_table() + + def dumps(obj, protocol=None): + f = io.BytesIO() + p = pickle.Pickler(f, protocol) + p.dispatch_table = dt + p.dump(obj) + return f.getvalue() + + self._test_dispatch_table(dumps, dt) + + def _test_dispatch_table(self, dumps, dispatch_table): + def custom_load_dump(obj): + return pickle.loads(dumps(obj, 0)) + + def default_load_dump(obj): + return pickle.loads(pickle.dumps(obj, 0)) + + # pickling complex numbers using protocol 0 relies on copyreg + # so check pickling a complex number still works + z = 1 + 2j + self.assertEqual(custom_load_dump(z), z) + self.assertEqual(default_load_dump(z), z) + + # modify pickling of complex + REDUCE_1 = 'reduce_1' + def reduce_1(obj): + return str, (REDUCE_1,) + dispatch_table[complex] = reduce_1 + self.assertEqual(custom_load_dump(z), REDUCE_1) + self.assertEqual(default_load_dump(z), z) + + # check picklability of AAA and BBB + a = AAA() + b = BBB() + self.assertEqual(custom_load_dump(a), REDUCE_A) + self.assertIsInstance(custom_load_dump(b), BBB) + self.assertEqual(default_load_dump(a), REDUCE_A) + self.assertIsInstance(default_load_dump(b), BBB) + + # modify pickling of BBB + dispatch_table[BBB] = reduce_1 + self.assertEqual(custom_load_dump(a), REDUCE_A) + self.assertEqual(custom_load_dump(b), REDUCE_1) + self.assertEqual(default_load_dump(a), REDUCE_A) + self.assertIsInstance(default_load_dump(b), BBB) + + # revert pickling of BBB and modify pickling of AAA + REDUCE_2 = 'reduce_2' + def reduce_2(obj): + return str, (REDUCE_2,) + dispatch_table[AAA] = reduce_2 + del dispatch_table[BBB] + self.assertEqual(custom_load_dump(a), REDUCE_2) + self.assertIsInstance(custom_load_dump(b), BBB) + self.assertEqual(default_load_dump(a), REDUCE_A) + self.assertIsInstance(default_load_dump(b), BBB) + + if __name__ == "__main__": # Print some stuff that can be used to rewrite DATA{0,1,2} from pickletools import dis diff -r 40ce325744fc Lib/test/test_pickle.py --- a/Lib/test/test_pickle.py Mon Feb 27 20:07:10 2012 +0000 +++ b/Lib/test/test_pickle.py Thu Mar 01 14:14:39 2012 +0000 @@ -1,5 +1,6 @@ import pickle import io +import collections from test import support @@ -7,6 +8,7 @@ from test.pickletester import AbstractPickleModuleTests from test.pickletester import AbstractPersistentPicklerTests from test.pickletester import AbstractPicklerUnpicklerObjectTests +from test.pickletester import AbstractDispatchTableTests from test.pickletester import BigmemPickleTests try: @@ -80,6 +82,18 @@ unpickler_class = pickle._Unpickler +class PyDispatchTableTests(AbstractDispatchTableTests): + pickler_class = pickle._Pickler + def get_dispatch_table(self): + return pickle.dispatch_table.copy() + + +class PyChainDispatchTableTests(AbstractDispatchTableTests): + pickler_class = pickle._Pickler + def get_dispatch_table(self): + return collections.ChainMap({}, pickle.dispatch_table) + + if has_c_implementation: class CPicklerTests(PyPicklerTests): pickler = _pickle.Pickler @@ -101,14 +115,26 @@ pickler_class = _pickle.Pickler unpickler_class = _pickle.Unpickler + class CDispatchTableTests(AbstractDispatchTableTests): + pickler_class = pickle.Pickler + def get_dispatch_table(self): + return pickle.dispatch_table.copy() + + class CChainDispatchTableTests(AbstractDispatchTableTests): + pickler_class = pickle.Pickler + def get_dispatch_table(self): + return collections.ChainMap({}, pickle.dispatch_table) + def test_main(): - tests = [PickleTests, PyPicklerTests, PyPersPicklerTests] + tests = [PickleTests, PyPicklerTests, PyPersPicklerTests, + PyDispatchTableTests, PyChainDispatchTableTests] if has_c_implementation: tests.extend([CPicklerTests, CPersPicklerTests, CDumpPickle_LoadPickle, DumpPickle_CLoadPickle, PyPicklerUnpicklerObjectTests, CPicklerUnpicklerObjectTests, + CDispatchTableTests, CChainDispatchTableTests, InMemoryPickleTests]) support.run_unittest(*tests) support.run_doctest(pickle) diff -r 40ce325744fc Modules/_pickle.c --- a/Modules/_pickle.c Mon Feb 27 20:07:10 2012 +0000 +++ b/Modules/_pickle.c Thu Mar 01 14:14:39 2012 +0000 @@ -319,6 +319,9 @@ objects to support self-referential objects pickling. */ PyObject *pers_func; /* persistent_id() method, can be NULL */ + PyObject *dispatch_table; /* If type(self).dispatch_table existed when + self was created, then this is a cached + reference to it. Otherwise it is NULL. */ PyObject *arg; PyObject *write; /* write() method of the output stream. */ @@ -764,6 +767,7 @@ return NULL; self->pers_func = NULL; + self->dispatch_table = NULL; self->arg = NULL; self->write = NULL; self->proto = 0; @@ -3176,17 +3180,24 @@ /* XXX: This part needs some unit tests. */ /* Get a reduction callable, and call it. This may come from - * copyreg.dispatch_table, the object's __reduce_ex__ method, - * or the object's __reduce__ method. + * self.dispatch_table, copyreg.dispatch_table, the object's + * __reduce_ex__ method, or the object's __reduce__ method. */ - reduce_func = PyDict_GetItem(dispatch_table, (PyObject *)type); + if (self->dispatch_table == NULL) { + reduce_func = PyDict_GetItem(dispatch_table, (PyObject *)type); + /* PyDict_GetItem() unlike PyObject_GetItem() and + PyObject_GetAttr() returns a borrowed ref */ + Py_XINCREF(reduce_func); + } else { + reduce_func = PyObject_GetItem(self->dispatch_table, (PyObject *)type); + if (reduce_func == NULL) { + if (PyErr_ExceptionMatches(PyExc_KeyError)) + PyErr_Clear(); + else + goto error; + } + } if (reduce_func != NULL) { - /* Here, the reference count of the reduce_func object returned by - PyDict_GetItem needs to be increased to be consistent with the one - returned by PyObject_GetAttr. This is allow us to blindly DECREF - reduce_func at the end of the save() routine. - */ - Py_INCREF(reduce_func); Py_INCREF(obj); reduce_value = _Pickler_FastCall(self, reduce_func, obj); } @@ -3359,6 +3370,7 @@ Py_XDECREF(self->output_buffer); Py_XDECREF(self->write); Py_XDECREF(self->pers_func); + Py_XDECREF(self->dispatch_table); Py_XDECREF(self->arg); Py_XDECREF(self->fast_memo); @@ -3372,6 +3384,7 @@ { Py_VISIT(self->write); Py_VISIT(self->pers_func); + Py_VISIT(self->dispatch_table); Py_VISIT(self->arg); Py_VISIT(self->fast_memo); return 0; @@ -3383,6 +3396,7 @@ Py_CLEAR(self->output_buffer); Py_CLEAR(self->write); Py_CLEAR(self->pers_func); + Py_CLEAR(self->dispatch_table); Py_CLEAR(self->arg); Py_CLEAR(self->fast_memo); @@ -3427,6 +3441,7 @@ PyObject *proto_obj = NULL; PyObject *fix_imports = Py_True; _Py_IDENTIFIER(persistent_id); + _Py_IDENTIFIER(dispatch_table); if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|OO:Pickler", kwlist, &file, &proto_obj, &fix_imports)) @@ -3468,6 +3483,13 @@ if (self->pers_func == NULL) return -1; } + self->dispatch_table = NULL; + if (_PyObject_HasAttrId((PyObject *)self, &PyId_dispatch_table)) { + self->dispatch_table = _PyObject_GetAttrId((PyObject *)self, + &PyId_dispatch_table); + if (self->dispatch_table == NULL) + return -1; + } return 0; } @@ -3749,6 +3771,7 @@ static PyMemberDef Pickler_members[] = { {"bin", T_INT, offsetof(PicklerObject, bin)}, {"fast", T_INT, offsetof(PicklerObject, fast)}, + {"dispatch_table", T_OBJECT_EX, offsetof(PicklerObject, dispatch_table)}, {NULL} };