diff --git a/Lib/pickle.py b/Lib/pickle.py --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -476,6 +476,11 @@ class _Pickler: f(self, obj) # Call unbound method with explicit self return + self.save_user_defined(obj) + + def save_user_defined(self, obj): + # Save object with object-defined methods (reduce and copyreg) + t = type(obj) # Check private dispatch table if any, or else copyreg.dispatch_table reduce = getattr(self, 'dispatch_table', dispatch_table).get(t) if reduce is not None: @@ -973,7 +978,16 @@ class _Pickler: return self.save_reduce(type, (...,), obj=obj) return self.save_global(obj) - dispatch[FunctionType] = save_global + def save_function(self, obj): + try: + return self.save_global(obj) + except PicklingError: + # fall back to reduce and clear the exception + pass + + return self.save_user_defined(obj) + + dispatch[FunctionType] = save_function dispatch[type] = save_type diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -3,6 +3,7 @@ import copyreg import dbm import io import functools +import marshal import pickle import pickletools import struct @@ -657,6 +658,12 @@ def create_data(): return x +def unpickle_lambda(name, argdefs, code): + code = marshal.loads(code) + fun_type = type(unpickle_lambda) + return fun_type(code, {}, name, argdefs) + + class AbstractUnpickleTests(unittest.TestCase): # Subclass must define self.loads. @@ -2259,21 +2266,20 @@ class AbstractPickleTests(unittest.TestC pass # Since the function is local, lookup will fail for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): - with self.assertRaises((AttributeError, pickle.PicklingError)): + with self.assertRaises((AttributeError, TypeError, pickle.PicklingError)): pickletools.dis(self.dumps(f, proto)) # Same without a __module__ attribute (exercises a different path # in _pickle.c). del f.__module__ for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): - with self.assertRaises((AttributeError, pickle.PicklingError)): + with self.assertRaises((AttributeError, TypeError, pickle.PicklingError)): pickletools.dis(self.dumps(f, proto)) # Yet a different path. f.__name__ = f.__qualname__ for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): - with self.assertRaises((AttributeError, pickle.PicklingError)): + with self.assertRaises((AttributeError, TypeError, pickle.PicklingError)): pickletools.dis(self.dumps(f, proto)) - class BigmemPickleTests(unittest.TestCase): # Binary protocols can serialize longs of up to 2GB-1 @@ -2774,6 +2780,40 @@ class AbstractPicklerUnpicklerObjectTest unpickler = self.unpickler_class(f) self.assertEqual(unpickler.load(), data) + def test_custom_functions_pickle(self): + f = lambda: 42 + + def pickle_lambda(fun): + # Simplified pickling for testing purposes: + # only pure lambdas, even no imports handled + name = fun.__name__ + if name != '': + raise pickle.PicklingError( + "Can't pickle function object: %s" % (name,)) + + if fun.__closure__ is not None: + raise pickle.PicklingError( + "Can't pickle closures, only pure lambda allowed: %s" % (name,)) + + co = fun.__code__ + args = ( + name, + fun.__defaults__, + marshal.dumps(co), + ) + return unpickle_lambda, args + + dt = pickle.dispatch_table.copy() + dt[type(f)] = pickle_lambda + i = io.BytesIO() + + pickler = self.pickler_class(i) + pickler.dispatch_table = dt + pickler.dump(f) + i = io.BytesIO(i.getvalue()) + unpickler = self.unpickler_class(i) + self.assertEqual(f(), unpickler.load()()) + # Tests for dispatch_table attribute diff --git a/Modules/_pickle.c b/Modules/_pickle.c --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -3869,8 +3869,18 @@ save(PicklerObject *self, PyObject *obj, goto done; } else if (type == &PyFunction_Type) { + PickleState *st; status = save_global(self, obj, NULL); - goto done; + st = _Pickle_GetGlobalState(); + if (status < 0 + && (PyErr_ExceptionMatches(st->PicklingError) + || PyErr_ExceptionMatches(PyExc_AttributeError))) { + /* fall back to reduce */ + PyErr_Clear(); + } + else { + goto done; + } } /* XXX: This part needs some unit tests. */