diff -r 92dda5f00b0f Lib/pickle.py --- a/Lib/pickle.py Thu Nov 05 20:41:57 2015 +0100 +++ b/Lib/pickle.py Fri Nov 06 15:29:02 2015 +0200 @@ -402,7 +402,13 @@ class Pickler: write(REDUCE) if obj is not None: - self.memoize(obj) + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + if id(obj) in self.memo: + write(POP + self.get(self.memo[id(obj)][0])) + else: + self.memoize(obj) # More new special cases (that work with older protocols as # well): when __reduce__ returns a tuple with 4 or 5 items, diff -r 92dda5f00b0f Lib/test/pickletester.py --- a/Lib/test/pickletester.py Thu Nov 05 20:41:57 2015 +0100 +++ b/Lib/test/pickletester.py Fri Nov 06 15:29:02 2015 +0200 @@ -117,6 +117,17 @@ class E(C): def __getinitargs__(self): return () +class H(object): + pass + +# Hashable mutable key +class K(list): + def __reduce__(self): + # Shouldn't support recursion itself + return K, (tuple(self),) + + __hash__ = object.__hash__ + import __main__ __main__.C = C C.__module__ = "__main__" @@ -124,6 +135,10 @@ C.__module__ = "__main__" D.__module__ = "__main__" __main__.E = E E.__module__ = "__main__" +__main__.H = H +H.__module__ = "__main__" +__main__.K = K +K.__module__ = "__main__" class myint(int): def __init__(self, x): @@ -676,18 +691,21 @@ class AbstractPickleTests(unittest.TestC for proto in protocols: s = self.dumps(l, proto) x = self.loads(s) + self.assertIs(type(x), list) self.assertEqual(len(x), 1) - self.assertTrue(x is x[0]) + self.assertIs(x[0], x) - def test_recursive_tuple(self): + def test_recursive_tuple_and_list(self): t = ([],) t[0].append(t) for proto in protocols: s = self.dumps(t, proto) x = self.loads(s) + self.assertIs(type(x), tuple) self.assertEqual(len(x), 1) + self.assertIs(type(x[0]), list) self.assertEqual(len(x[0]), 1) - self.assertTrue(x is x[0][0]) + self.assertIs(x[0][0], x) def test_recursive_dict(self): d = {} @@ -695,8 +713,54 @@ class AbstractPickleTests(unittest.TestC for proto in protocols: s = self.dumps(d, proto) x = self.loads(s) + self.assertIs(type(x), dict) self.assertEqual(x.keys(), [1]) - self.assertTrue(x[1] is x) + self.assertIs(x[1], x) + + def test_recursive_dict_key(self): + k = K() + d = {} + d[k] = 1 + k.append(d) + for proto in protocols: + s = self.dumps(d, proto) + x = self.loads(s) + self.assertIs(type(x), dict) + self.assertEqual(len(x.keys()), 1) + self.assertIs(type(x.keys()[0]), K) + self.assertEqual(len(x.keys()[0]), 1) + self.assertIs(x.keys()[0][0], x) + + def test_recursive_list_subclass(self): + y = MyList() + y.append(y) + s = self.dumps(y, 2) + x = self.loads(s) + self.assertIs(type(x), MyList) + self.assertEqual(len(x), 1) + self.assertIs(x[0], x) + + def test_recursive_dict_subclass(self): + d = MyDict() + d[1] = d + s = self.dumps(d, 2) + x = self.loads(s) + self.assertIs(type(x), MyDict) + self.assertEqual(x.keys(), [1]) + self.assertIs(x[1], x) + + def test_recursive_dict_subclass_key(self): + k = K() + d = MyDict() + d[k] = 1 + k.append(d) + s = self.dumps(d, 2) + x = self.loads(s) + self.assertIs(type(x), MyDict) + self.assertEqual(len(x.keys()), 1) + self.assertIs(type(x.keys()[0]), K) + self.assertEqual(len(x.keys()[0]), 1) + self.assertIs(x.keys()[0][0], x) def test_recursive_inst(self): i = C() @@ -721,6 +785,42 @@ class AbstractPickleTests(unittest.TestC self.assertEqual(x[0].attr.keys(), [1]) self.assertTrue(x[0].attr[1] is x) + def check_recursive_collection_and_inst(self, cls): + h = H() + y = cls([h]) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIs(type(x), type(y)) + self.assertEqual(len(x), 1) + self.assertIs(type(list(x)[0]), H) + self.assertIs(list(x)[0].attr, x) + + def test_recursive_list_and_inst(self): + self.check_recursive_collection_and_inst(list) + + def test_recursive_tuple_and_inst(self): + self.check_recursive_collection_and_inst(tuple) + + def test_recursive_dict_and_inst(self): + self.check_recursive_collection_and_inst(dict.fromkeys) + + def test_recursive_set_and_inst(self): + self.check_recursive_collection_and_inst(set) + + def test_recursive_frozenset_and_inst(self): + self.check_recursive_collection_and_inst(frozenset) + + def test_recursive_list_subclass_and_inst(self): + self.check_recursive_collection_and_inst(MyList) + + def test_recursive_tuple_subclass_and_inst(self): + self.check_recursive_collection_and_inst(MyTuple) + + def test_recursive_dict_subclass_and_inst(self): + self.check_recursive_collection_and_inst(MyDict.fromkeys) + if have_unicode: def test_unicode(self): endcases = [u'', u'<\\u>', u'<\\\u1234>', u'<\n>', diff -r 92dda5f00b0f Lib/test/test_cpickle.py --- a/Lib/test/test_cpickle.py Thu Nov 05 20:41:57 2015 +0100 +++ b/Lib/test/test_cpickle.py Fri Nov 06 15:29:02 2015 +0200 @@ -1,6 +1,7 @@ import cPickle import cStringIO import io +import functools import unittest from test.pickletester import (AbstractUnpickleTests, AbstractPickleTests, @@ -151,31 +152,6 @@ class cPickleFastPicklerTests(AbstractPi finally: self.close(f) - def test_recursive_list(self): - self.assertRaises(ValueError, - AbstractPickleTests.test_recursive_list, - self) - - def test_recursive_tuple(self): - self.assertRaises(ValueError, - AbstractPickleTests.test_recursive_tuple, - self) - - def test_recursive_inst(self): - self.assertRaises(ValueError, - AbstractPickleTests.test_recursive_inst, - self) - - def test_recursive_dict(self): - self.assertRaises(ValueError, - AbstractPickleTests.test_recursive_dict, - self) - - def test_recursive_multi(self): - self.assertRaises(ValueError, - AbstractPickleTests.test_recursive_multi, - self) - def test_nonrecursive_deep(self): # If it's not cyclic, it should pickle OK even if the nesting # depth exceeds PY_CPICKLE_FAST_LIMIT. That happens to be @@ -187,6 +163,19 @@ class cPickleFastPicklerTests(AbstractPi b = self.loads(self.dumps(a)) self.assertEqual(a, b) +for name in dir(AbstractPickleTests): + if name.startswith('test_recursive_'): + func = getattr(AbstractPickleTests, name) + if '_subclass' in name and '_and_inst' not in name: + assert_args = RuntimeError, 'maximum recursion depth exceeded' + else: + assert_args = ValueError, "can't pickle cyclic objects" + def wrapper(self, func=func, assert_args=assert_args): + with self.assertRaisesRegexp(*assert_args): + func(self) + functools.update_wrapper(wrapper, func) + setattr(cPickleFastPicklerTests, name, wrapper) + class cStringIOCPicklerFastTests(cStringIOMixin, cPickleFastPicklerTests): pass diff -r 92dda5f00b0f Modules/cPickle.c --- a/Modules/cPickle.c Thu Nov 05 20:41:57 2015 +0100 +++ b/Modules/cPickle.c Fri Nov 06 15:29:02 2015 +0200 @@ -2533,6 +2533,27 @@ save_reduce(Picklerobject *self, PyObjec /* Memoize. */ /* XXX How can ob be NULL? */ if (ob != NULL) { + /* If the object is already in the memo, this means it is + recursive. In this case, throw away everything we put on the + stack, and fetch the object back from the memo. */ + if (Py_REFCNT(ob) > 1 && !self->fast) { + PyObject *py_ob_id = PyLong_FromVoidPtr(ob); + if (!py_ob_id) + return -1; + if (PyDict_GetItem(self->memo, py_ob_id)) { + const char pop_op = POP; + if (self->write_func(self, &pop_op, 1) < 0 || + get(self, py_ob_id) < 0) { + Py_DECREF(py_ob_id); + return -1; + } + Py_DECREF(py_ob_id); + return 0; + } + Py_DECREF(py_ob_id); + if (PyErr_Occurred()) + return -1; + } if (state && !PyDict_Check(state)) { if (put2(self, ob) < 0) return -1;