diff --git a/Lib/pickle.py b/Lib/pickle.py --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -476,6 +476,11 @@ class _Pickler: f(self, obj) # Call unbound method with explicit self return + self.save_user_defined(obj) + + def save_user_defined(self, obj): + # Save object with object-defined methods (reduce and copyreg) + t = type(obj) # Check private dispatch table if any, or else copyreg.dispatch_table reduce = getattr(self, 'dispatch_table', dispatch_table).get(t) if reduce is not None: @@ -973,7 +978,16 @@ class _Pickler: return self.save_reduce(type, (...,), obj=obj) return self.save_global(obj) - dispatch[FunctionType] = save_global + def save_function(self, obj): + try: + return self.save_global(obj) + except PicklingError: + # fall back to reduce and clear the exception + pass + + return self.save_user_defined(obj) + + dispatch[FunctionType] = save_function dispatch[type] = save_type diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -657,6 +657,13 @@ def create_data(): return x +def unpickle_lambda(name, argdefs, *co_args): + code_type = type(unpickle_lambda.__code__) + code = code_type(*co_args) + fun_type = type(unpickle_lambda) + return fun_type(code, {}, name, argdefs) + + class AbstractUnpickleTests(unittest.TestCase): # Subclass must define self.loads. @@ -2259,20 +2266,65 @@ class AbstractPickleTests(unittest.TestC pass # Since the function is local, lookup will fail for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): - with self.assertRaises((AttributeError, pickle.PicklingError)): + with self.assertRaises((AttributeError, TypeError, pickle.PicklingError)): pickletools.dis(self.dumps(f, proto)) # Same without a __module__ attribute (exercises a different path # in _pickle.c). del f.__module__ for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): - with self.assertRaises((AttributeError, pickle.PicklingError)): + with self.assertRaises((AttributeError, TypeError, pickle.PicklingError)): pickletools.dis(self.dumps(f, proto)) # Yet a different path. f.__name__ = f.__qualname__ for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): - with self.assertRaises((AttributeError, pickle.PicklingError)): + with self.assertRaises((AttributeError, TypeError, pickle.PicklingError)): pickletools.dis(self.dumps(f, proto)) + def test_custom_functions_pickle(self): + f = lambda: 42 + + def pickle_lambda(fun): + # Simplified pickling for testing purposes: + # only pure lambdas, even no imports handled + name = fun.__name__ + if name != '': + raise pickle.PicklingError( + "Can't pickle function object: %s" % (name,)) + + if fun.__closure__ is not None: + raise pickle.PicklingError( + "Can't pickle closures, only pure lambda allowed: %s" % (name,)) + + co = fun.__code__ + args = ( + name, + fun.__defaults__, + co.co_argcount, + co.co_kwonlyargcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + return unpickle_lambda, args + + dt = pickle.dispatch_table.copy() + dt[type(f)] = pickle_lambda + i = io.BytesIO() + + pickler = self.pickler(i) + pickler.dispatch_table = dt + pickler.dump(f) + self.assertEqual(f(), self.loads(i.getvalue())()) class BigmemPickleTests(unittest.TestCase): diff --git a/Modules/_pickle.c b/Modules/_pickle.c --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -3869,8 +3869,18 @@ save(PicklerObject *self, PyObject *obj, goto done; } else if (type == &PyFunction_Type) { + PickleState *st; status = save_global(self, obj, NULL); - goto done; + st = _Pickle_GetGlobalState(); + if (status < 0 + && (PyErr_ExceptionMatches(st->PicklingError) + || PyErr_ExceptionMatches(PyExc_AttributeError))) { + /* fall back to reduce */ + PyErr_Clear(); + } + else { + goto done; + } } /* XXX: This part needs some unit tests. */