diff -r 780722877a3e Lib/pickle.py --- a/Lib/pickle.py Wed May 01 13:16:11 2013 -0700 +++ b/Lib/pickle.py Sat May 11 03:06:28 2013 +0300 @@ -23,7 +23,7 @@ """ -from types import FunctionType, BuiltinFunctionType +from types import FunctionType, BuiltinFunctionType, MethodType, ModuleType from copyreg import dispatch_table from copyreg import _extension_registry, _inverted_registry, _extension_cache from itertools import islice @@ -34,10 +34,44 @@ import io import codecs import _compat_pickle +import builtins +from inspect import ismodule, isclass __all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler", "Unpickler", "dump", "dumps", "load", "loads"] +# Issue 15397: Unbinding of methods +# Adds the possibility to unbind methods as well as a few definitions missing +# from the types module. + +_MethodDescriptorType = type(list.append) +_WrapperDescriptorType = type(list.__add__) +_MethodWrapperType = type([].__add__) + +def _unbind(f): + """Unbinds a bound method.""" + self = getattr(f, '__self__', None) + if self is not None and not isinstance(self, ModuleType) \ + and not isinstance(self, type): + if hasattr(f, '__func__'): + return f.__func__ + return getattr(type(f.__self__), f.__name__) + raise TypeError('not a bound method') + + +def _bind_method(self, func): + """This method is used internally to pickle bound methods using the REDUCE + opcode.""" + return func.__get__(self) + +def _isclassmethod(func): + """Tests if a given function is a classmethod.""" + if type(func) not in [MethodType, BuiltinFunctionType]: + return False + if hasattr(func, '__self__') and type(func.__self__) is type: + return True + return False + # Shortcut for use in isinstance testing bytes_types = (bytes, bytearray) @@ -173,6 +207,8 @@ ADDITEMS = b'\x90' # modify set by adding topmost stack items EMPTY_FROZENSET = b'\x91' # push empty frozenset on the stack FROZENSET = b'\x92' # build frozenset from topmost stack items +BINGLOBAL = b'\x93' # push a global (like GLOBAL) 
+BINGLOBAL_BIG = b'\x94' # push an unusually large global name __all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)]) @@ -726,23 +762,34 @@ write = self.write memo = self.memo + getattr_func = getattr_recurse if self.proto >= 4 else getattr + if name is None: - name = obj.__name__ + if self.proto >= 4: + name = obj.__qualname__ + else: + name = obj.__name__ module = getattr(obj, "__module__", None) if module is None: - module = whichmodule(obj, name) + module = whichmodule(obj, name, getattr_func) try: __import__(module, level=0) mod = sys.modules[module] - klass = getattr(mod, name) - except (ImportError, KeyError, AttributeError): + klass = getattr_func(mod, name) + except (ImportError, KeyError, AttributeError) as e: raise PicklingError( "Can't pickle %r: it's not found as %s.%s" % - (obj, module, name)) + (obj, module, name)) from e else: - if klass is not obj: + # Note: The 'is' operator does not currently work as expected when + # applied on functions which are classmethods ("dict.fromkeys is + # dict.fromkeys" is False). Therefore, we only perform the check + # below if the object we are dealing with ("obj") is not a + # classmethod. + # XXX remove the additional check when this is fixed + if klass is not obj and not _isclassmethod(obj): raise PicklingError( "Can't pickle %r: it's not the same object as %s.%s" % (obj, module, name)) @@ -758,10 +805,31 @@ else: write(EXT4 + pack("= 4 and module == '__builtins__': + module = 'builtins' # Non-ASCII identifiers are supported only with protocols >= 3. 
if self.proto >= 3: - write(GLOBAL + bytes(module, "utf-8") + b'\n' + - bytes(name, "utf-8") + b'\n') + module_bin = bytes(module, 'utf-8') + name_bin = bytes(name, 'utf-8') + if self.proto >= 4 and len(module_bin) <= 255 and \ + len(name_bin) <= 255: + write(BINGLOBAL + bytes([len(module_bin)]) + + module_bin + bytes([len(name_bin)]) + name_bin) + # use BINGLOBAL_BIG for representing unusually large globals in + # pickle >= 4 + elif self.proto >= 4: + assert len(module_bin) <= 65535 + assert len(name_bin) <= 65535 + write(BINGLOBAL_BIG + pack('>> getattr(sys.modules['os'], 'path.isdir') + Traceback (most recent call last): + ... + AttributeError: 'module' object has no attribute 'path.isdir' + >>> getattr_recurse(sys.modules['os'], 'path.isdir')('.') + True + >>> getattr_recurse(sys.modules['os'], 'path.foo') + Traceback (most recent call last): + ... + AttributeError: 'module' object has no attribute 'foo' + """ + ret = module + for attr in name.split('.'): + if attr == '': + raise TypeError('Cannot work with the locals of '+ + ret.__qualname__) + + if default is _None: + ret = getattr(ret, attr) + """ + raise AttributeError('\'%s\' object has no attribute \'%s\'' % + (type(ret), n)) + ret = ret_ + """ + else: + try: + ret = getattr(ret, attr) + except AttributeError: + return default + return ret + + +def whichmodule(func, funcname, getattr_func=getattr): """Figure out the module in which a function occurs. Search sys.modules for the module. @@ -799,13 +926,20 @@ mod = getattr(func, "__module__", None) if mod is not None: return mod + # XXX this is for classmethods. 
since whichmodule() uses `is' to compare + # for equality of functions and "dict.fromkeys is dict.fromkeys" evaluates + # to False, whichmodule(dict.fromkeys, 'dict.fromkeys') would incorrectly + # return '__main__' + elif hasattr(func, "__self__") and hasattr(func.__self__, "__module__") \ + and func.__self__.__module__ is not None: + return func.__self__.__module__ if func in classmap: return classmap[func] for name, module in list(sys.modules.items()): if module is None: continue # skip dummy package entries - if name != '__main__' and getattr(module, funcname, None) is func: + if name != '__main__' and getattr_func(module, funcname, None) is func: break else: name = '__main__' @@ -1180,10 +1314,20 @@ module, name = _compat_pickle.NAME_MAPPING[(module, name)] if module in _compat_pickle.IMPORT_MAPPING: module = _compat_pickle.IMPORT_MAPPING[module] - __import__(module, level=0) - mod = sys.modules[module] - klass = getattr(mod, name) - return klass + + try: + __import__(module, level=0) + + mod = sys.modules[module] + + if self.proto >= 4: + klass = getattr_recurse(mod, name) + else: + klass = getattr(mod, name) + return klass + except Exception as e: + raise UnpicklingError("Couldn't find class %s.%s" % + (module,name)) from e def load_reduce(self): stack = self.stack @@ -1333,6 +1477,24 @@ raise _Stop(value) dispatch[STOP[0]] = load_stop + def load_binglobal(self): + module_size = self.read(1)[0] + module = self.read(module_size).decode('utf-8') + name_size = self.read(1)[0] + name = self.read(name_size).decode('utf-8') + klass = self.find_class(module, name) + self.append(klass) + dispatch[BINGLOBAL[0]] = load_binglobal + + def load_binglobal_big(self): + module_size = unpack('>> import io + >>> read_unicodestringu2(io.BytesIO(b'\x00\x00abc')) + '' + >>> read_unicodestringu2(io.BytesIO(b'\x03\x00abc')) + 'abc' + >>> read_unicodestringu2(io.BytesIO(b'\x04\x00' + ('\U0001D223'.encode('utf-8')))) + '\U0001d223' + >>> read_unicodestringu2(io.BytesIO(b'\x0d\x00' + 
('\ufb93' * 4).encode('utf-8'))) + Traceback (most recent call last): + ... + ValueError: expected 13 bytes in a unicodestringu2, but only 12 remain + """ + n = read_uint2(f) + assert n >= 0 + + data = f.read(n) + if len(data) == n: + return str(data, 'utf-8', 'surrogatepass') + raise ValueError("expected %d bytes in a unicodestringu2, but only %d " + "remain" % (n, len(data))) + +unicodestringu2 = ArgumentDescriptor( + name='unicodestringu2', + n=TAKEN_FROM_ARGUMENT2U, + reader=read_unicodestringu2, + doc="""A counted semi-short Unicode string. + + The first argument is a 2-byte little-endian unsigned short, giving + the number of bytes in the string, and the second argument is + the UTF-8 encoding of the Unicode string + """) +def read_unicodestring1_pair(f): + r""" + >>> import io + >>> read_unicodestring1_pair(io.BytesIO(b"\x00\x00whatever")) + ' ' + >>> read_unicodestring1_pair(io.BytesIO(b"\x05hello\x06world!blabla")) + 'hello world!' + """ + return "%s %s" % (read_unicodestring1(f), read_unicodestring1(f)) + +unicodestring1_pair = ArgumentDescriptor( + name='unicodestring1_pair', + n=TAKEN_FROM_ARGUMENT1, + reader=read_unicodestring1_pair, + doc="""Read a pair of small unicode strings. + + Both of the strings are preceded by an uint1 + indicating the length of the utf-8 encoded + string to follow""") + +def read_unicodestringu2_pair(f): + r""" + >>> import io + >>> read_unicodestringu2_pair(io.BytesIO(b"\x00\x00\x00\x00whatever")) + ' ' + >>> read_unicodestringu2_pair(io.BytesIO( + ... b"\x05\x00hello\x06\x00world!blabla")) + 'hello world!' + """ + return "%s %s" % (read_unicodestringu2(f), read_unicodestringu2(f)) + +unicodestringu2_pair = ArgumentDescriptor( + name='unicodestringu2_pair', + n=TAKEN_FROM_ARGUMENT2U, + reader=read_unicodestringu2_pair, + doc="""Read a pair of semi-small unicode strings. 
+ + Both of the strings are preceded by a + little-endian uint2 indicating the length + of the utf-8 encoded string to follow""") def read_string1(f): r""" @@ -2107,6 +2182,35 @@ ID is passed to self.persistent_load(), and whatever object that returns is pushed on the stack. See PERSID for more detail. """), + I(name='BINGLOBAL', + code='\x93', + arg=unicodestring1_pair, + stack_before=[], + stack_after=[anyobject], + proto=4, + doc="""Push a global object (module.obj) on the stack. + + This works in a similar way to GLOBAL, but instead of taking a pair of + newline-terminated strings as parameters (representing the module name + and the attribute respectively), it takes a pair of two small utf-8 + encoded strings, with their 8bit size prepended to them (the + equivalent of two consecutive SHORT_BINUNICODE opcodes). + + On versions 4 and above, this object is automatically memoized by the + unpickler (there's no need for BINPUT after this opcode). + """), + + I(name='BINGLOBAL_BIG', + code='\x94', + arg=unicodestringu2_pair, + stack_before=[], + stack_after=[anyobject], + proto=4, + doc="""Push a global object (module.obj) on the stack. + + This is used instead of BINGLOBAL for unusually large global names (i.e. + >255 bytes). + """), ] del I @@ -2414,6 +2518,20 @@ if stack: raise ValueError("stack not empty after STOP: %r" % stack) +def maxversion(pickled_data): + """Find the maximum version amongst the used opcodes in the given pickled + data. + + Like in `dis', pickle is a file-like object, or string, containing at least + one pickle. The pickle is disassembled from the current position until the + first STOP opcode is encountered, and the maximum version of the + encountered opcodes is returned. + """ + ret = -1 + for opcode, arg, pos in genops(pickled_data): + ret = max(ret, opcode.proto) + return ret + # For use in the doctest, simply as an example of a class to pickle. 
class _Example: def __init__(self, value): diff -r 780722877a3e Lib/test/pickletester.py --- a/Lib/test/pickletester.py Wed May 01 13:16:11 2013 -0700 +++ b/Lib/test/pickletester.py Sat May 11 03:06:28 2013 +0300 @@ -433,6 +433,43 @@ x.append(5) return x +class Nested: + n = 'Nested' + + class B: + n = 'Nested.B' + + def f(): + return 'Nested.B.f' + def ff(self): + return 'Nested.B.ff' + + @classmethod + def cm(klass): + return klass.n + + @staticmethod + def sm(): + return 'sm' + + class C: + n = 'Nested.B.C' + + def __init__(self): + self.a = 123 + + def f(): + return 'Nested.B.C.f' + def ff(self): + return 'Nested.B.C.ff' + + def get_a(self): + return self.a + +# used to test pickling of unusually large names +class _aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa: + pass + class AbstractPickleTests(unittest.TestCase): # Subclass must define self.dumps, self.loads. @@ -1291,6 +1328,105 @@ self._check_pickling_with_opcode(obj, pickle.SETITEM, proto) else: self._check_pickling_with_opcode(obj, pickle.SETITEMS, proto) + + def _loads(self, data, version=pickle.HIGHEST_PROTOCOL, minversion=-1, + *kargs, **kwargs): + """Uses loads, but first makes sure there aren't any opcodes of too + high or too low of a version number. + + Usecase: + data = self.dumps([1, 2, 3], proto) + undata = self._loads(data, proto) + + v3_feature = .. 
+ data = self.dumps(v3_feature, 4) + undata = self._loads(v3_feature, 4, 3) + """ + maxv = pickletools.maxversion(data) + self.assertLessEqual(maxv, version) + self.assertLessEqual(minversion, maxv) + return self.loads(data, *kargs, **kwargs) + + def _used_opcodes(self, data): + opcodes=set() + for opcode, arg, pos in pickletools.genops(data): + opcodes.add(opcode.name) + return opcodes + + def test_v4_binglobal_big(self): + klass=_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + for proto in protocols: + data=self.dumps(klass, proto) + klass_=self._loads(data, proto) + self.assertEqual(klass_, klass) + opcodes=self._used_opcodes(data) + if proto < 4: + self.assertNotIn('BINGLOBAL', opcodes) + self.assertNotIn('BINGLOBAL_BIG', opcodes) + self.assertIn('GLOBAL', opcodes) + else: + self.assertNotIn('GLOBAL', opcodes) + self.assertNotIn('BINGLOBAL', opcodes) + self.assertIn('BINGLOBAL_BIG', opcodes) + + def test_v4_nested_classes(self): + """test pickling nested classes""" + for proto in range(4, 1+pickle.HIGHEST_PROTOCOL): + for klass in (Nested, Nested.B, Nested.B.C): + data = self.dumps(klass, proto) + undata = self._loads(data, proto, 4) + + self.assertEqual(klass.n, undata.n) + self.assertEqual(klass.n, undata.__qualname__) + self.assertEqual(klass.__qualname__, undata.__qualname__) + + for func in (Nested.B.f, Nested.B.C.f): + data = self.dumps(func, proto) + undata = self._loads(data, proto, 4) + + self.assertEqual(func.__qualname__, undata.__qualname__) + self.assertEqual(func(), undata()) + self.assertLessEqual(4, pickletools.maxversion(data)) + + inst = Nested.B.C() + inst.a = 42 + + data = self.dumps(inst, proto) + undata = self._loads(data, proto, 4) + + self.assertEqual(inst.a, undata.get_a()) + + data = 
self.dumps( [(inst, Nested.B), (Nested.B.C.f, Nested.B.f, + Nested.B.C.f), + Nested, Nested.B.C, inst, Nested.B.f], proto) + inst.a = -42 + undata = self._loads(data, proto, 4) + + self.assertEqual(42, undata[0][0].a) + self.assertEqual('Nested.B.f', undata[0][1].f()) + self.assertEqual('Nested.B.C.f', undata[1][0]()) + self.assertEqual('Nested.B.f', undata[1][1]()) + self.assertEqual('Nested.B.C.f', undata[1][2]()) + self.assertEqual('Nested', undata[2].n) + self.assertEqual('Nested.B.C', undata[3].n) + self.assertEqual(42, undata[4].get_a()) + self.assertEqual('Nested.B.f', undata[5]()) + + def test_v4_weird_funcs(self): + funcs = [list.append, list.__add__, dict.fromkeys, len, Nested.B.cm, + Nested.B.sm] + for proto in range(4, 1+pickle.HIGHEST_PROTOCOL): + data=self.dumps(funcs, proto) + funcs_=self._loads(data, proto) + l=[] + funcs_[0](l, 1) # l.append(1) + l=funcs_[1](l, [2,3]) # l += [2,3] + self.assertEqual([1,2,3], l) + self.assertEqual(3, funcs_[3](l)) # len(l) + # dict.fromkeys([1,2]) = {1: None, 2: None} + self.assertEqual({1 : None, 2 : None}, funcs_[2]([1,2])) + self.assertEqual('Nested.B', funcs_[4]()) # Nested.B.cm() + self.assertEqual('sm', funcs_[5]()) # Nested.B.sm() class BigmemPickleTests(unittest.TestCase): diff -r 780722877a3e Modules/_pickle.c --- a/Modules/_pickle.c Wed May 01 13:16:11 2013 -0700 +++ b/Modules/_pickle.c Sat May 11 03:06:28 2013 +0300 @@ -80,7 +80,9 @@ EMPTY_SET = '\x8f', ADDITEMS = '\x90', EMPTY_FROZENSET = '\x91', - FROZENSET = '\x92' + FROZENSET = '\x92', + BINGLOBAL = '\x93', + BINGLOBAL_BIG = '\x94' }; /* These aren't opcodes -- they're ways to pickle bools before protocol 2 @@ -145,6 +147,161 @@ /* For looking up name pairs in copyreg._extension_registry. 
*/ static PyObject *two_tuple = NULL; +static PyObject *v4_common_modules = NULL; + +static PyObject * +unbind (PyObject *func) +{ + PyObject *self = NULL, *unbound = NULL, *name; + static PyObject *self_str = NULL, *func_str = NULL, *name_str = NULL; + + if (!self_str) { + self_str = PyUnicode_InternFromString("__self__"); + if (!self_str) return NULL; + } + + self = PyObject_GetAttr(func, self_str); + PyErr_Clear(); + if (!self || PyModule_Check(self) || PyType_Check(self)) { + PyErr_SetString(PyExc_TypeError, "not a bound method"); + Py_XDECREF(self); + return NULL; + } + else { + if (!func_str) { + func_str = PyUnicode_InternFromString("__func__"); + if (!func_str) goto done; + } + unbound = PyObject_GetAttr(func, func_str); + if (unbound) goto done; + else { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) + PyErr_Clear(); + else return NULL; + if (!name_str) { + name_str = PyUnicode_InternFromString("__name__"); + if (!name_str) goto done; + } + name = PyObject_GetAttr(func, name_str); + if (!name) goto done; + unbound = PyObject_GetAttr((PyObject*)Py_TYPE(self), name); + Py_DECREF(name); + } + } + +done: + Py_DECREF(self); + return unbound; +} + +static int isclassmethod (PyObject *func) +{ + PyObject *self; + static PyObject *self_str = NULL; + + if (Py_TYPE(func) != &PyMethod_Type && + Py_TYPE(func) != &PyCFunction_Type && + Py_TYPE(func) != &PyClassMethod_Type && + Py_TYPE(func) != &PyClassMethodDescr_Type) return 0; + + if (!self_str) { + self_str = PyUnicode_InternFromString("__self__"); + if (!self_str) return 0; + } + + self = PyObject_GetAttr(func, self_str); + if (self && PyType_Check(self)) { Py_DECREF(self); return 1; } + Py_XDECREF(self); + return 0; +} + +static PyObject * +getattr_recurse (PyObject *obj, PyObject *attr) +{ + static PyObject *locals_str = NULL, *qualname_str = NULL, *dot = NULL; + PyObject *attr_parts, *iter, *item, *crt = obj, *prev; + + assert(PyUnicode_Check(attr)); + + if (locals_str == NULL) { + /* + appears as a token 
in __qualname__. E.g.: + >>> def f(): + ... def g(): + ... pass + ... return g.__qualname__ + ... + >>> f() + 'f..g' + */ + locals_str = PyUnicode_InternFromString(""); + if (locals_str == NULL) return NULL; + } + if (qualname_str == NULL) { + qualname_str = PyUnicode_InternFromString("__qualname__"); + if (qualname_str == NULL) return NULL; + } + if (dot == NULL) { + dot = PyUnicode_InternFromString("."); + if (dot == NULL) return NULL; + } + + attr_parts = PyUnicode_Split(attr, dot, 128); + if (!attr_parts) + return NULL; + + iter = PyObject_GetIter(attr_parts); + + // Making sure that the first call to Py_DECREF(prev) below won't decrement + // obj's refcount + Py_INCREF(obj); + + while ( (item = PyIter_Next(iter)) ) { + //check item=="" + PyObject *is_locals = PyUnicode_RichCompare(item, locals_str, Py_EQ); + + if (is_locals == Py_True) { + PyObject *qualname = PyObject_GetAttr(crt, qualname_str); + if (qualname == NULL) { crt = NULL; goto error; } + PyErr_Format(PyExc_TypeError, + "Cannot work with the locals of %U", qualname); + Py_DECREF(item); + Py_DECREF(qualname); + Py_DECREF(is_locals); + crt = NULL; + goto error; + } + else if (is_locals == Py_NotImplemented) { + PyErr_BadInternalCall(); + crt = NULL; + Py_DECREF(item); + Py_DECREF(is_locals); + goto error; + } + else if (is_locals == NULL) { + crt = NULL; + Py_DECREF(item); + goto error; + } + + prev = crt; + crt = PyObject_GetAttr(crt, item); + Py_DECREF(prev); + Py_DECREF(is_locals); + if (crt == NULL) { Py_DECREF(item); goto error; } + + Py_DECREF(item); + } + + //iteration failed + if (PyErr_Occurred()) crt = NULL; + +error: + Py_DECREF(iter); + Py_DECREF(attr_parts); + return crt; +} + static int stack_underflow(void) { @@ -1339,9 +1496,11 @@ Py_ssize_t i, j; static PyObject *module_str = NULL; static PyObject *main_str = NULL; + static PyObject *self_str = NULL; PyObject *module_name; PyObject *modules_dict; PyObject *module; + PyObject *self; PyObject *obj; if (module_str == NULL) { @@ -1351,27 
+1510,47 @@ main_str = PyUnicode_InternFromString("__main__"); if (main_str == NULL) return NULL; + self_str = PyUnicode_InternFromString("__self__"); + if (self_str == NULL) + return NULL; } module_name = PyObject_GetAttr(global, module_str); /* In some rare cases (e.g., bound methods of extension types), - __module__ can be None. If it is so, then search sys.modules - for the module of global. */ - if (module_name == Py_None) { + __module__ can be None. If it is so, then search sys.modules for the + module of global. Before doing so, check if the global has a __self__ + attribute which in turn has a __module__. */ + if (!module_name) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) + PyErr_Clear(); + else + return NULL; + } + else if (module_name == Py_None) { Py_DECREF(module_name); - goto search; - } - - if (module_name) { - return module_name; - } - if (PyErr_ExceptionMatches(PyExc_AttributeError)) - PyErr_Clear(); - else - return NULL; - - search: + } + else return module_name; + + self = PyObject_GetAttr(global, self_str); + if (!self) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) + PyErr_Clear(); + else + return NULL; + } + else { + module_name = PyObject_GetAttr(self, module_str); + Py_DECREF(self); + if (!module_name) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) + PyErr_Clear(); + else + return NULL; + } + else return module_name; + } + modules_dict = PySys_GetObject("modules"); if (modules_dict == NULL) return NULL; @@ -1382,7 +1561,7 @@ if (PyObject_RichCompareBool(module_name, main_str, Py_EQ) == 1) continue; - obj = PyObject_GetAttr(module, global_name); + obj = getattr_recurse(module, global_name); if (obj == NULL) { if (PyErr_ExceptionMatches(PyExc_AttributeError)) PyErr_Clear(); @@ -2620,6 +2799,69 @@ return status; } +static int save_global_nonbinary( + PicklerObject *self, + PyObject *module_name, + PyObject *global_name) +{ + static char global_op = GLOBAL; + PyObject *encoded; + PyObject *(*unicode_encoder)(PyObject *); + 
+ /* Since Python 3.0 now supports non-ASCII identifiers, we encode both + the module name and the global name using UTF-8. We do so only when + we are using the pickle protocol newer than version 3. This is to + ensure compatibility with older Unpickler running on Python 2.x. */ + if (self->proto >= 3) { + unicode_encoder = PyUnicode_AsUTF8String; + } + else { + unicode_encoder = PyUnicode_AsASCIIString; + } + + if ( _Pickler_Write(self, &global_op, 1) < 0) + return -1; + + /* Save the name of the module. */ + encoded = unicode_encoder(module_name); + if (encoded == NULL) { + if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) + PyErr_Format(PicklingError, + "can't pickle module identifier '%S' using " + "pickle protocol %i", module_name, self->proto); + return -1; + } + if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), + PyBytes_GET_SIZE(encoded)) < 0) { + Py_DECREF(encoded); + return -1; + } + Py_DECREF(encoded); + if(_Pickler_Write(self, "\n", 1) < 0) + return -1; + + /* Save the name of the global. */ + encoded = unicode_encoder(global_name); + if (encoded == NULL) { + if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) + PyErr_Format(PicklingError, + "can't pickle global identifier '%S' using " + "pickle protocol %i", global_name, self->proto); + return -1; + } + if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), + PyBytes_GET_SIZE(encoded)) < 0) { + Py_DECREF(encoded); + return -1; + } + Py_DECREF(encoded); + if(_Pickler_Write(self, "\n", 1) < 0) + return -1; + + return 0; +} + + static int save_set(PicklerObject *self, PyObject *obj) { @@ -2772,30 +3014,159 @@ return 0; } +/* + * Only for pickle >= 4. 
+ * Uses opcodes BINGLOBAL, BINGLOBAL_BIG + */ +static int save_global_binary( + PicklerObject *self, + PyObject *module_name, + PyObject *global_name) +{ + char global_op; + int return_code = 0; + PyObject *encoded_module_name, *encoded_global_name; + Py_ssize_t encoded_module_size, encoded_global_size; + + assert(module_name != NULL && global_name != NULL); + + encoded_module_name = PyUnicode_AsUTF8String(module_name); + if (encoded_module_name == NULL) { + if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) + PyErr_Format(PicklingError, + "can't pickle module identifier '%S' using " + "pickle protocol %i", encoded_module_name, + self->proto); + return -1; + } + encoded_module_size = PyBytes_GET_SIZE(encoded_module_name); + if (encoded_module_size < 0) { + Py_DECREF(encoded_module_name); + return -1; + } + + encoded_global_name = PyUnicode_AsUTF8String(global_name); + if (encoded_global_name == NULL) { + if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) + PyErr_Format(PicklingError, + "can't pickle global identifier '%S' using " + "pickle protocol %i", global_name, self->proto); + Py_DECREF(encoded_module_name); + return -1; + } + encoded_global_size = PyBytes_GET_SIZE(encoded_global_name); + if (encoded_global_size < 0) goto error; + + /* BINGLOBAL */ + if (encoded_module_size <= 0xff && encoded_global_size <= 0xff) { + char module_size_byte = encoded_module_size, + global_size_byte = encoded_global_size; + + /* write the opcode */ + global_op = BINGLOBAL; + if (_Pickler_Write(self, &global_op, 1) < 0) + goto error; + + /* write the size of the module (1 byte) */ + if (_Pickler_Write(self, &module_size_byte, 1) < 0) + goto error; + + /* write the module name */ + if (_Pickler_Write(self, PyBytes_AS_STRING(encoded_module_name), + encoded_module_size) < 0) + goto error; + + /* write the size of the global (1 byte) */ + if (_Pickler_Write(self, &global_size_byte, 1) < 0) + goto error; + + /* write the global name */ + if (_Pickler_Write(self, 
PyBytes_AS_STRING(encoded_global_name), + encoded_global_size) < 0) + goto error; + + } + /* BINGLOBAL_BIG */ + else { + char data[2]; + /* nearly useless checks */ + if (encoded_module_size > 0xffff) { + PyErr_Format(PyExc_OverflowError, "Unusually large module name."); + goto error; + } + else if (encoded_global_size > 0xffff) { + PyErr_Format(PyExc_OverflowError, "Unusually large global name."); + goto error; + } + + /* write the opcode */ + global_op = BINGLOBAL_BIG; + if (_Pickler_Write(self, &global_op, 1) < 0) + goto error; + + /* write the size of the module (2 bytes) */ + data[0] = (unsigned char)(encoded_module_size & 0xff); + data[1] = (unsigned char)((encoded_module_size >> 8) & 0xff); + if (_Pickler_Write(self, data, 2) < 0) + goto error; + + /* write the module name */ + if (_Pickler_Write(self, PyBytes_AS_STRING(encoded_module_name), + encoded_module_size) < 0) + goto error; + + /* write the size of the global (2 bytes) */ + data[0] = (unsigned char)(encoded_global_size & 0xff); + data[1] = (unsigned char)((encoded_global_size >> 8) & 0xff); + if (_Pickler_Write(self, data, 2) < 0) + goto error; + + /* write the global name */ + if (_Pickler_Write(self, PyBytes_AS_STRING(encoded_global_name), + encoded_global_size) < 0) + goto error; + } + + if (0) { + // only goto error after both encoded_global_name + // and encoded_module_name have been initialized +error: + return_code = -1; + } + Py_DECREF(encoded_module_name); + Py_DECREF(encoded_global_name); + return return_code; +} + static int save_global(PicklerObject *self, PyObject *obj, PyObject *name) { - static PyObject *name_str = NULL; + static PyObject *name_str = NULL, + *qualname_str = NULL; PyObject *global_name = NULL; PyObject *module_name = NULL; PyObject *module = NULL; PyObject *cls; int status = 0; - const char global_op = GLOBAL; - - if (name_str == NULL) { + if (self->proto < 4 && name_str == NULL) { name_str = PyUnicode_InternFromString("__name__"); if (name_str == NULL) goto error; } + 
else if (self->proto >= 4 && qualname_str == NULL) { + qualname_str = PyUnicode_InternFromString("__qualname__"); + if (qualname_str == NULL) + goto error; + } if (name) { global_name = name; Py_INCREF(global_name); } else { - global_name = PyObject_GetAttr(obj, name_str); + global_name = PyObject_GetAttr(obj, + self->proto >= 4 ? qualname_str : name_str); if (global_name == NULL) goto error; } @@ -2819,14 +3190,21 @@ obj, module_name); goto error; } - cls = PyObject_GetAttr(module, global_name); + if (self->proto < 4) { + cls = PyObject_GetAttr(module, global_name); + } + else { + cls = getattr_recurse(module, global_name); + } if (cls == NULL) { PyErr_Format(PicklingError, "Can't pickle %R: attribute lookup %S.%S failed", obj, module_name, global_name); goto error; } - if (cls != obj) { + // we ignore this step for classmethods because + // "dict.fromkeys is dict.fromkeys" evaluates to false + if (cls != obj && !isclassmethod(obj)) { Py_DECREF(cls); PyErr_Format(PicklingError, "Can't pickle %R: it's not the same object as %S.%S", @@ -2867,8 +3245,8 @@ if (code <= 0 || code > 0x7fffffffL) { if (!PyErr_Occurred()) PyErr_Format(PicklingError, - "Can't pickle %R: extension code %ld is out of range", - obj, code); + "Can't pickle %R: extension code %ld is out of range", + obj, code); goto error; } @@ -2900,23 +3278,9 @@ /* Generate a normal global opcode if we are using a pickle protocol <= 2, or if the object is not registered in the extension registry. */ - PyObject *encoded; - PyObject *(*unicode_encoder)(PyObject *); gen_global: - if (_Pickler_Write(self, &global_op, 1) < 0) - goto error; - - /* Since Python 3.0 now supports non-ASCII identifiers, we encode both - the module name and the global name using UTF-8. We do so only when - we are using the pickle protocol newer than version 3. This is to - ensure compatibility with older Unpickler running on Python 2.x. 
*/ - if (self->proto >= 3) { - unicode_encoder = PyUnicode_AsUTF8String; - } - else { - unicode_encoder = PyUnicode_AsASCIIString; - } + /* For protocol < 3 and if the user didn't request against doing so, we convert module names to the old 2.x module names. */ @@ -2974,42 +3338,17 @@ goto error; } } - - /* Save the name of the module. */ - encoded = unicode_encoder(module_name); - if (encoded == NULL) { - if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) - PyErr_Format(PicklingError, - "can't pickle module identifier '%S' using " - "pickle protocol %i", module_name, self->proto); - goto error; + + if (self->proto < 4) { + //uses opcode GLOBAL + if (save_global_nonbinary(self, module_name, global_name) < 0) + goto error; } - if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), - PyBytes_GET_SIZE(encoded)) < 0) { - Py_DECREF(encoded); - goto error; + else if (self->proto >= 4) { + //uses one of the opcodes: BINGLOBAL or BINGLOBAL_BIG + if (save_global_binary(self, module_name, global_name) < 0) + goto error; } - Py_DECREF(encoded); - if(_Pickler_Write(self, "\n", 1) < 0) - goto error; - - /* Save the name of the module. */ - encoded = unicode_encoder(global_name); - if (encoded == NULL) { - if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) - PyErr_Format(PicklingError, - "can't pickle global identifier '%S' using " - "pickle protocol %i", global_name, self->proto); - goto error; - } - if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), - PyBytes_GET_SIZE(encoded)) < 0) { - Py_DECREF(encoded); - goto error; - } - Py_DECREF(encoded); - if(_Pickler_Write(self, "\n", 1) < 0) - goto error; /* Memoize the object. 
*/ if (memo_put(self, obj) < 0) @@ -3028,6 +3367,77 @@ } static int +save_global_or_method(PicklerObject *self, PyObject *obj) +{ + PyObject *unbound, *obj_self = NULL, *tuple, *inner_tuple; + static PyObject *str_self = NULL, *binding_function = NULL, + *pickle_str = NULL; + int ret = -1; + + unbound = unbind(obj); + if (unbound == NULL) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + PyErr_Clear(); + return save_global(self, obj, NULL); + } + return -1; + } + else if (self->proto < 4) { + PyErr_SetString(PicklingError, + "Can't pickle bound methods in pickle<4"); + Py_DECREF(unbound); + return -1; + } + else { + if (pickle_str == NULL) { + pickle_str = PyUnicode_InternFromString("pickle"); + if (pickle_str == NULL) { + Py_DECREF(unbound); + return -1; + } + } + if (binding_function == NULL) { + PyObject *pickle_module = PyImport_Import(pickle_str); + if (pickle_module == NULL) { + Py_DECREF(unbound); + return -1; + } + binding_function = PyObject_GetAttrString(pickle_module, + "_bind_method"); + if (binding_function == NULL) { + Py_DECREF(unbound); + return -1; + } + } + if (str_self == NULL) { + str_self = PyUnicode_InternFromString("__self__"); + if (str_self == NULL) { + Py_DECREF(unbound); + return -1; + } + } + + obj_self = PyObject_GetAttr(obj, str_self); + if (obj_self == NULL) { + Py_DECREF(unbound); + return -1; + } + + inner_tuple = PyTuple_Pack(2, obj_self, unbound); + if (!inner_tuple) goto done; + + tuple = PyTuple_Pack(2, binding_function, inner_tuple); + if (!tuple) goto done; + + ret = save_reduce(self, tuple, obj); + Py_DECREF(tuple); +done: + Py_DECREF(obj_self); + Py_DECREF(unbound); + return ret; + } +} +static int save_ellipsis(PicklerObject *self, PyObject *obj) { PyObject *str = PyUnicode_FromString("Ellipsis"); @@ -3441,7 +3851,13 @@ goto done; } } - else if (type == &PyCFunction_Type) { + else if (type == &PyCFunction_Type || type == &PyMethod_Type || + type == &_PyMethodWrapper_Type) { + status = save_global_or_method(self, obj); + 
goto done; + } + else if (type == &PyWrapperDescr_Type || type == &PyMethodDescr_Type || + type == &PyClassMethodDescr_Type) { status = save_global(self, obj, NULL); goto done; } @@ -4840,6 +5256,104 @@ } static int +load_binglobal(UnpicklerObject *self) +{ + /* Operand layout: 1-byte module-name length, UTF-8 module name, + * 1-byte global-name length, UTF-8 global name. The object returned by + * find_class() is pushed on the stack. */ + PyObject *module_name, *global_name, *global = NULL; + char *s; + Py_ssize_t encoded_size; + + /* read module's size (1 byte) */ + if (_Unpickler_Read(self, &s, 1) < 1) + return -1; + encoded_size = (unsigned char)s[0]; + + /* read module name */ + if (_Unpickler_Read(self, &s, encoded_size) < encoded_size) + return -1; + module_name = PyUnicode_DecodeUTF8(s, encoded_size, "strict"); + if (!module_name) + return -1; + + /* read global's size */ + if (_Unpickler_Read(self, &s, 1) < 1) { + Py_DECREF(module_name); /* don't leak the decoded module name */ + return -1; + } + encoded_size = (unsigned char)s[0]; + + /* read global name */ + if (_Unpickler_Read(self, &s, encoded_size) < encoded_size) { + Py_DECREF(module_name); + return -1; + } + global_name = PyUnicode_DecodeUTF8(s, encoded_size, "strict"); + + if (global_name) { + global = find_class(self, module_name, global_name); + Py_DECREF(global_name); + } + + Py_DECREF(module_name); + + if (global) { + PDATA_PUSH(self->stack, global, -1); + return 0; + } + return -1; +} + +static int +load_binglobal_big(UnpicklerObject *self) +{ + /* Same as load_binglobal, but each length is 2 bytes, little-endian. */ + PyObject *module_name, *global_name, *global = NULL; + char *s; + Py_ssize_t encoded_size; + + /* read module's size (2 bytes) */ + if (_Unpickler_Read(self, &s, 2) < 2) + return -1; + encoded_size = (unsigned char)s[0] | ((unsigned char)s[1] << 8); + + /* read module name */ + if (_Unpickler_Read(self, &s, encoded_size) < encoded_size) + return -1; + module_name = PyUnicode_DecodeUTF8(s, encoded_size, "strict"); + if (!module_name) + return -1; + + /* read global's size */ + if (_Unpickler_Read(self, &s, 2) < 2) { + Py_DECREF(module_name); /* don't leak the decoded module name */ + return -1; + } + encoded_size = (unsigned char)s[0] | ((unsigned char)s[1] << 8); + + /* read global name */ + if (_Unpickler_Read(self, &s, encoded_size) < 
encoded_size) { + Py_DECREF(module_name); + return -1; + } + global_name = PyUnicode_DecodeUTF8(s, encoded_size, "strict"); + + if (global_name) { + global = find_class(self, module_name, global_name); + Py_DECREF(global_name); + } + + Py_DECREF(module_name); + + if (global) { + PDATA_PUSH(self->stack, global, -1); + return 0; + } + return -1; +} + +static int load_global(UnpicklerObject *self) { PyObject *global = NULL; @@ -5645,6 +6152,8 @@ OP(INST, load_inst) OP(NEWOBJ, load_newobj) OP(GLOBAL, load_global) + OP(BINGLOBAL, load_binglobal) + OP(BINGLOBAL_BIG, load_binglobal_big) OP(APPEND, load_append) OP(APPENDS, load_appends) OP(BUILD, load_build) @@ -5814,11 +6323,15 @@ module = PyImport_Import(module_name); if (module == NULL) return NULL; + if (self->proto < 4) global = PyObject_GetAttr(module, global_name); + else global = getattr_recurse(module, global_name); + Py_DECREF(module); + } + else if (self->proto < 4) { global = PyObject_GetAttr(module, global_name); - Py_DECREF(module); } else { - global = PyObject_GetAttr(module, global_name); + global = getattr_recurse(module, global_name); } return global; }