diff -r a75b88048339 -r 8434af450da0 Lib/copyreg.py --- a/Lib/copyreg.py Thu Nov 14 16:16:29 2013 -0800 +++ b/Lib/copyreg.py Fri Nov 15 03:07:56 2013 -0800 @@ -87,6 +87,12 @@ def __newobj__(cls, *args): return cls.__new__(cls, *args) +def __newobj_ex__(cls, args, kwargs): + """Used by pickle protocol 4, instead of __newobj__ to allow classes with + keyword-only arguments to be pickled correctly. + """ + return cls.__new__(cls, *args, **kwargs) + def _slotnames(cls): """Return a list of slot names for a given class. diff -r a75b88048339 -r 8434af450da0 Lib/pickle.py --- a/Lib/pickle.py Thu Nov 14 16:16:29 2013 -0800 +++ b/Lib/pickle.py Fri Nov 15 03:07:56 2013 -0800 @@ -23,7 +23,7 @@ """ -from types import FunctionType, BuiltinFunctionType +from types import FunctionType, BuiltinFunctionType, ModuleType from copyreg import dispatch_table from copyreg import _extension_registry, _inverted_registry, _extension_cache from itertools import islice @@ -42,17 +42,18 @@ bytes_types = (bytes, bytearray) # These are purely informational; no code uses these. -format_version = "3.0" # File format version we write +format_version = "4.0" # File format version we write compatible_formats = ["1.0", # Original protocol 0 "1.1", # Protocol 0 with INST added "1.2", # Original protocol 1 "1.3", # Protocol 1 with BINFLOAT added "2.0", # Protocol 2 "3.0", # Protocol 3 + "4.0", # Protocol 4 ] # Old format versions we can read # This is the highest protocol number we know how to read. -HIGHEST_PROTOCOL = 3 +HIGHEST_PROTOCOL = 4 # The protocol we write by default. May be less than HIGHEST_PROTOCOL. # We intentionally write a protocol that Python 2.x cannot read; @@ -164,7 +165,190 @@ BINBYTES = b'B' # push bytes; counted binary string argument SHORT_BINBYTES = b'C' # " " ; " " " " < 256 bytes -__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$",x)]) +# Protocol 4 +SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes +BINUNICODE8 = b'\x8d' # push very long string +BINBYTES8 = b'\x8e' # push very long bytes string +EMPTY_SET = b'\x8f' # push empty set on the stack +ADDITEMS = b'\x90' # modify set by adding topmost stack items +EMPTY_FROZENSET = b'\x91' # push empty frozenset on the stack +FROZENSET = b'\x92' # build frozenset from topmost stack items +NEWOBJ_EX = b'\x93' # like NEWOBJ but work with keyword only arguments +STACK_GLOBAL = b'\x94' # same as GLOBAL but using names on the stacks + +__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)]) + + +class _Framer: + + _FRAME_SIZE_TARGET = 64 * 1024 + + def __init__(self, file_write): + self.file_write = file_write + self.current_frame = None + + def _commit_frame(self): + f = self.current_frame + with f.getbuffer() as data: + n = len(data) + write = self.file_write + write(pack("= self._FRAME_SIZE_TARGET: + self._commit_frame() + return f.write(data) + +class _Unframer: + + def __init__(self, file_read, file_readline, file_tell=None): + self.file_read = file_read + self.file_readline = file_readline + self.file_tell = file_tell + self.framing_enabled = False + self.current_frame = None + self.frame_start = None + + def read(self, n): + if n == 0: + return b'' + _file_read = self.file_read + if not self.framing_enabled: + return _file_read(n) + f = self.current_frame + if f is not None: + data = f.read(n) + if data: + if len(data) < n: + raise UnpicklingError( + "pickle exhausted before end of frame") + return data + frame_size, = unpack(" sys.maxsize: + raise ValueError("frame size > sys.maxsize: %d" % frame_size) + if self.file_tell is not None: + self.frame_start = self.file_tell() + f = self.current_frame = io.BytesIO(_file_read(frame_size)) + self.readline = f.readline + data = f.read(n) + assert len(data) == n, (len(data), n) + return data + + def readline(self): + if not self.framing_enabled: + return self.file_readline() + else: + return self.current_frame.readline() + + def tell(self): + if self.file_tell is None: + return None + elif self.current_frame is None: + return self.file_tell() + else: + return self.frame_start + self.current_frame.tell() + + +# Tools used for pickling. + +def _getattribute(obj, name, allow_qualname=False): + dotted_path = name.split(".") + if not allow_qualname and len(dotted_path) > 1: + raise AttributeError("Can't get qualified attribute {!r} on {!r}; " + + "use protocols >= 4 to enable support" + .format(name, obj)) + for subpath in dotted_path: + if subpath == '': + raise AttributeError("Can't get local attribute {!r} on {!r}" + .format(name, obj)) + try: + obj = getattr(obj, subpath) + except AttributeError: + raise AttributeError("Can't get attribute {!r} on {!r}" + .format(name, obj)) + return obj + +def whichmodule(obj, name, allow_qualname=False): + """Find the module an object belong to.""" + module_name = getattr(obj, '__module__', None) + if module_name is not None: + return module_name + for module_name, module in sys.modules.items(): + if module_name == '__main__' or module is None: + continue + try: + if _getattribute(module, name, allow_qualname) is obj: + return module_name + except AttributeError: + pass + return '__main__' + +def encode_long(x): + r"""Encode a long to a two's complement little-endian binary string. + Note that 0 is a special case, returning an empty string, to save a + byte in the LONG1 pickling context. + + >>> encode_long(0) + b'' + >>> encode_long(255) + b'\xff\x00' + >>> encode_long(32767) + b'\xff\x7f' + >>> encode_long(-256) + b'\x00\xff' + >>> encode_long(-32768) + b'\x00\x80' + >>> encode_long(-128) + b'\x80' + >>> encode_long(127) + b'\x7f' + >>> + """ + if x == 0: + return b'' + nbytes = (x.bit_length() >> 3) + 1 + result = x.to_bytes(nbytes, byteorder='little', signed=True) + if x < 0 and nbytes > 1: + if result[-1] == 0xff and (result[-2] & 0x80) != 0: + result = result[:-1] + return result + +def decode_long(data): + r"""Decode a long from a two's complement little-endian binary string. + + >>> decode_long(b'') + 0 + >>> decode_long(b"\xff\x00") + 255 + >>> decode_long(b"\xff\x7f") + 32767 + >>> decode_long(b"\x00\xff") + -256 + >>> decode_long(b"\x00\x80") + -32768 + >>> decode_long(b"\x80") + -128 + >>> decode_long(b"\x7f") + 127 + """ + return int.from_bytes(data, byteorder='little', signed=True) + # Pickling machinery @@ -174,9 +358,9 @@ """This takes a binary file for writing a pickle data stream. The optional protocol argument tells the pickler to use the - given protocol; supported protocols are 0, 1, 2, 3. The default - protocol is 3; a backward-incompatible protocol designed for - Python 3.0. + given protocol; supported protocols are 0, 1, 2, 3 and 4. The + default protocol is 3; a backward-incompatible protocol designed for + Python 3. Specifying a negative protocol version selects the highest protocol version supported. The higher the protocol used, the @@ -189,8 +373,8 @@ meets this interface. If fix_imports is True and protocol is less than 3, pickle will try to - map the new Python 3.x names to the old module names used in Python - 2.x, so that the pickle data stream is readable with Python 2.x. + map the new Python 3 names to the old module names used in Python 2, + so that the pickle data stream is readable with Python 2. """ if protocol is None: protocol = DEFAULT_PROTOCOL @@ -199,7 +383,7 @@ elif not 0 <= protocol <= HIGHEST_PROTOCOL: raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL) try: - self.write = file.write + self._file_write = file.write except AttributeError: raise TypeError("file must have a 'write' attribute") self.memo = {} @@ -223,13 +407,22 @@ """Write a pickled representation of obj to the open file.""" # Check whether Pickler was initialized correctly. This is # only needed to mimic the behavior of _pickle.Pickler.dump(). - if not hasattr(self, "write"): + if not hasattr(self, "_file_write"): raise PicklingError("Pickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) if self.proto >= 2: - self.write(PROTO + pack("= 4: + framer = _Framer(self._file_write) + framer.start_framing() + self.write = framer.write + else: + framer = None + self.write = self._file_write self.save(obj) self.write(STOP) + if framer is not None: + framer.end_framing() def memoize(self, obj): """Store an object in the memo.""" @@ -349,24 +542,33 @@ else: self.write(PERSID + str(pid).encode("ascii") + b'\n') - def save_reduce(self, func, args, state=None, - listitems=None, dictitems=None, obj=None): + def save_reduce(self, func, args, state=None, listitems=None, + dictitems=None, obj=None): # This API is called by some subclasses - # Assert that args is a tuple if not isinstance(args, tuple): - raise PicklingError("args from save_reduce() should be a tuple") - - # Assert that func is callable + raise PicklingError("args from save_reduce() must be a tuple") if not callable(func): - raise PicklingError("func from save_reduce() should be callable") + raise PicklingError("func from save_reduce() must be callable") save = self.save write = self.write - # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ - if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__": - # A __reduce__ implementation can direct protocol 2 to + func_name = getattr(func, "__name__", "") + if self.proto >= 4 and func_name == "__newobj_ex__": + cls, args, kwargs = args + if not hasattr(cls, "__new__"): + raise PicklingError("args[0] from {} args has no __new__" + .format(func_name)) + if obj is not None and cls is not obj.__class__: + raise PicklingError("args[0] from {} args has the wrong class" + .format(func_name)) + save(cls) + save(args) + save(kwargs) + write(NEWOBJ_EX) + elif self.proto >= 2 and func_name == "__newobj__": + # A __reduce__ implementation can direct protocol 2 or newer to # use the more efficient NEWOBJ opcode, while still # allowing protocol 0 and 1 to work normally. For this to # work, the function returned by __reduce__ should be @@ -409,7 +611,13 @@ write(REDUCE) if obj is not None: - self.memoize(obj) + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + if id(obj) in self.memo: + write(POP + self.get(self.memo[id(obj)][0])) + else: + self.memoize(obj) # More new special cases (that work with older protocols as # well): when __reduce__ returns a tuple with 4 or 5 items, @@ -493,8 +701,10 @@ (str(obj, 'latin1'), 'latin1'), obj=obj) return n = len(obj) - if n < 256: + if n <= 0xff: self.write(SHORT_BINBYTES + pack(" 0xffffffff and self.proto >= 4: + self.write(BINBYTES8 + pack("= 4: + self.write(SHORT_BINUNICODE + pack(" 0xffffffff and self.proto >= 4: + self.write(BINUNICODE8 + pack(" 0: + write(MARK) + for item in batch: + save(item) + write(ADDITEMS) + if n < self._BATCHSIZE: + return + dispatch[set] = save_set + + def save_frozenset(self, obj): + save = self.save + write = self.write + + if self.proto < 4: + self.save_reduce(set, (list(obj),), obj=obj) + return + + if not obj: + write(EMPTY_FROZENSET) + return + + write(MARK) + for item in obj: + save(item) + + if id(obj) in self.memo: + # If the object is already in the memo, this means it is + # recursive. In this case, throw away everything we put on the + # stack, and fetch the object back from the memo. + write(POP_MARK + self.get(self.memo[id(obj)][0])) + return + + write(FROZENSET) + self.memoize(obj) + dispatch[frozenset] = save_frozenset + def save_global(self, obj, name=None): write = self.write memo = self.memo + if name is None and self.proto >= 4: + name = getattr(obj, '__qualname__', None) if name is None: name = obj.__name__ - module = getattr(obj, "__module__", None) - if module is None: - module = whichmodule(obj, name) - + module_name = whichmodule(obj, name, allow_qualname=self.proto >= 4) try: - __import__(module, level=0) - mod = sys.modules[module] - klass = getattr(mod, name) + __import__(module_name, level=0) + module = sys.modules[module_name] + obj2 = _getattribute(module, name, allow_qualname=self.proto >= 4) + if obj2 is not obj: + raise PicklingError( + "Can't pickle %r: it's not the same object as %s.%s" % + (obj, module_name, name)) except (ImportError, KeyError, AttributeError): raise PicklingError( "Can't pickle %r: it's not found as %s.%s" % - (obj, module, name)) - else: - if klass is not obj: - raise PicklingError( - "Can't pickle %r: it's not the same object as %s.%s" % - (obj, module, name)) + (obj, module_name, name)) if self.proto >= 2: - code = _extension_registry.get((module, name)) + code = _extension_registry.get((module_name, name)) if code: assert code > 0 if code <= 0xff: @@ -684,17 +949,23 @@ write(EXT4 + pack("= 3. - if self.proto >= 3: - write(GLOBAL + bytes(module, "utf-8") + b'\n' + + if self.proto >= 4: + self.save(module_name) + self.save(name) + write(STACK_GLOBAL) + elif self.proto >= 3: + write(GLOBAL + bytes(module_name, "utf-8") + b'\n' + bytes(name, "utf-8") + b'\n') else: if self.fix_imports: - if (module, name) in _compat_pickle.REVERSE_NAME_MAPPING: - module, name = _compat_pickle.REVERSE_NAME_MAPPING[(module, name)] - if module in _compat_pickle.REVERSE_IMPORT_MAPPING: - module = _compat_pickle.REVERSE_IMPORT_MAPPING[module] + r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING + r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING + if (module_name, name) in r_name_mapping: + module_name, name = r_name_mapping[(module_name, name)] + if module_name in r_import_mapping: + module_name = r_import_mapping[module_name] try: - write(GLOBAL + bytes(module, "ascii") + b'\n' + + write(GLOBAL + bytes(module_name, "ascii") + b'\n' + bytes(name, "ascii") + b'\n') except UnicodeEncodeError: raise PicklingError( @@ -703,40 +974,16 @@ self.memoize(obj) + def save_method(self, obj): + if obj.__self__ is None or type(obj.__self__) is ModuleType: + self.save_global(obj) + else: + self.save_reduce(getattr, (obj.__self__, obj.__name__), obj=obj) + dispatch[FunctionType] = save_global - dispatch[BuiltinFunctionType] = save_global + dispatch[BuiltinFunctionType] = save_method dispatch[type] = save_global -# A cache for whichmodule(), mapping a function object to the name of -# the module in which the function was found. - -classmap = {} # called classmap for backwards compatibility - -def whichmodule(func, funcname): - """Figure out the module in which a function occurs. - - Search sys.modules for the module. - Cache in classmap. - Return a module name. - If the function cannot be found, return "__main__". - """ - # Python functions should always get an __module__ from their globals. - mod = getattr(func, "__module__", None) - if mod is not None: - return mod - if func in classmap: - return classmap[func] - - for name, module in list(sys.modules.items()): - if module is None: - continue # skip dummy package entries - if name != '__main__' and getattr(module, funcname, None) is func: - break - else: - name = '__main__' - classmap[func] = name - return name - # Unpickling machinery @@ -764,8 +1011,8 @@ instances pickled by Python 2.x; these default to 'ASCII' and 'strict', respectively. """ - self.readline = file.readline - self.read = file.read + self._file_readline = file.readline + self._file_read = file.read self.memo = {} self.encoding = encoding self.errors = errors @@ -779,12 +1026,16 @@ """ # Check whether Unpickler was initialized correctly. This is # only needed to mimic the behavior of _pickle.Unpickler.dump(). - if not hasattr(self, "read"): + if not hasattr(self, "_file_read"): raise UnpicklingError("Unpickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) + self._unframer = _Unframer(self._file_read, self._file_readline) + self.read = self._unframer.read + self.readline = self._unframer.readline self.mark = object() # any new unique object self.stack = [] self.append = self.stack.append + self.proto = 0 read = self.read dispatch = self.dispatch try: @@ -822,6 +1073,8 @@ if not 0 <= proto <= HIGHEST_PROTOCOL: raise ValueError("unsupported pickle protocol: %d" % proto) self.proto = proto + if proto >= 4: + self._unframer.framing_enabled = True dispatch[PROTO[0]] = load_proto def load_persid(self): @@ -940,6 +1193,14 @@ self.append(str(self.read(len), 'utf-8', 'surrogatepass')) dispatch[BINUNICODE[0]] = load_binunicode + def load_binunicode8(self): + len, = unpack(' maxsize: + raise UnpicklingError("BINUNICODE8 exceeds system's maximum size " + "of %d bytes" % maxsize) + self.append(str(self.read(len), 'utf-8', 'surrogatepass')) + dispatch[BINUNICODE8[0]] = load_binunicode8 + def load_short_binstring(self): len = self.read(1)[0] data = self.read(len) @@ -952,6 +1213,11 @@ self.append(self.read(len)) dispatch[SHORT_BINBYTES[0]] = load_short_binbytes + def load_short_binunicode(self): + len = self.read(1)[0] + self.append(str(self.read(len), 'utf-8', 'surrogatepass')) + dispatch[SHORT_BINUNICODE[0]] = load_short_binunicode + def load_tuple(self): k = self.marker() self.stack[k:] = [tuple(self.stack[k+1:])] @@ -981,6 +1247,19 @@ self.append({}) dispatch[EMPTY_DICT[0]] = load_empty_dictionary + def load_empty_set(self): + self.append(set()) + dispatch[EMPTY_SET[0]] = load_empty_set + + def load_empty_frozenset(self): + self.append(frozenset()) + dispatch[EMPTY_FROZENSET[0]] = load_empty_frozenset + + def load_frozenset(self): + k = self.marker() + self.stack[k:] = [frozenset(self.stack[k+1:])] + dispatch[FROZENSET[0]] = load_frozenset + def load_list(self): k = self.marker() self.stack[k:] = [self.stack[k+1:]] @@ -1029,11 +1308,19 @@ def load_newobj(self): args = self.stack.pop() - cls = self.stack[-1] + cls = self.stack.pop() obj = cls.__new__(cls, *args) - self.stack[-1] = obj + self.append(obj) dispatch[NEWOBJ[0]] = load_newobj + def load_newobj_ex(self): + kwargs = self.stack.pop() + args = self.stack.pop() + cls = self.stack.pop() + obj = cls.__new__(cls, *args, **kwargs) + self.append(obj) + dispatch[NEWOBJ_EX[0]] = load_newobj_ex + def load_global(self): module = self.readline()[:-1].decode("utf-8") name = self.readline()[:-1].decode("utf-8") @@ -1041,6 +1328,14 @@ self.append(klass) dispatch[GLOBAL[0]] = load_global + def load_stack_global(self): + name = self.stack.pop() + module = self.stack.pop() + if type(name) is not str or type(module) is not str: + raise UnpicklingError("STACK_GLOBAL requires str") + self.append(self.find_class(module, name)) + dispatch[STACK_GLOBAL[0]] = load_stack_global + def load_ext1(self): code = self.read(1)[0] self.get_extension(code) @@ -1080,9 +1375,8 @@ if module in _compat_pickle.IMPORT_MAPPING: module = _compat_pickle.IMPORT_MAPPING[module] __import__(module, level=0) - mod = sys.modules[module] - klass = getattr(mod, name) - return klass + return _getattribute(sys.modules[module], name, + allow_qualname=self.proto >= 4) def load_reduce(self): stack = self.stack @@ -1185,6 +1479,20 @@ del stack[mark:] dispatch[SETITEMS[0]] = load_setitems + def load_additems(self): + stack = self.stack + mark = self.marker() + set_obj = stack[mark - 1] + items = stack[mark + 1:] + if isinstance(set_obj, set): + set_obj.update(items) + else: + add = set_obj.add + for item in items: + add(item) + del stack[mark:] + dispatch[ADDITEMS[0]] = load_additems + def load_build(self): stack = self.stack state = stack.pop() @@ -1218,86 +1526,46 @@ raise _Stop(value) dispatch[STOP[0]] = load_stop -# Encode/decode ints. - -def encode_long(x): - r"""Encode a long to a two's complement little-endian binary string. - Note that 0 is a special case, returning an empty string, to save a - byte in the LONG1 pickling context. - - >>> encode_long(0) - b'' - >>> encode_long(255) - b'\xff\x00' - >>> encode_long(32767) - b'\xff\x7f' - >>> encode_long(-256) - b'\x00\xff' - >>> encode_long(-32768) - b'\x00\x80' - >>> encode_long(-128) - b'\x80' - >>> encode_long(127) - b'\x7f' - >>> - """ - if x == 0: - return b'' - nbytes = (x.bit_length() >> 3) + 1 - result = x.to_bytes(nbytes, byteorder='little', signed=True) - if x < 0 and nbytes > 1: - if result[-1] == 0xff and (result[-2] & 0x80) != 0: - result = result[:-1] - return result - -def decode_long(data): - r"""Decode an int from a two's complement little-endian binary string. - - >>> decode_long(b'') - 0 - >>> decode_long(b"\xff\x00") - 255 - >>> decode_long(b"\xff\x7f") - 32767 - >>> decode_long(b"\x00\xff") - -256 - >>> decode_long(b"\x00\x80") - -32768 - >>> decode_long(b"\x80") - -128 - >>> decode_long(b"\x7f") - 127 - """ - return int.from_bytes(data, byteorder='little', signed=True) # Shorthands -def dump(obj, file, protocol=None, *, fix_imports=True): - Pickler(file, protocol, fix_imports=fix_imports).dump(obj) +def _dump(obj, file, protocol=None, *, fix_imports=True): + _Pickler(file, protocol, fix_imports=fix_imports).dump(obj) -def dumps(obj, protocol=None, *, fix_imports=True): +def _dumps(obj, protocol=None, *, fix_imports=True): f = io.BytesIO() - Pickler(f, protocol, fix_imports=fix_imports).dump(obj) + _Pickler(f, protocol, fix_imports=fix_imports).dump(obj) res = f.getvalue() assert isinstance(res, bytes_types) return res -def load(file, *, fix_imports=True, encoding="ASCII", errors="strict"): - return Unpickler(file, fix_imports=fix_imports, +def _load(file, *, fix_imports=True, encoding="ASCII", errors="strict"): + return _Unpickler(file, fix_imports=fix_imports, encoding=encoding, errors=errors).load() -def loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"): +def _loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"): if isinstance(s, str): raise TypeError("Can't load pickle from unicode string") file = io.BytesIO(s) - return Unpickler(file, fix_imports=fix_imports, - encoding=encoding, errors=errors).load() + return _Unpickler(file, fix_imports=fix_imports, + encoding=encoding, errors=errors).load() # Use the faster _pickle if possible try: - from _pickle import * + from _pickle import ( + PickleError, + PicklingError, + UnpicklingError, + Pickler, + Unpickler, + dump, + dumps, + load, + loads + ) except ImportError: Pickler, Unpickler = _Pickler, _Unpickler + dump, dumps, load, loads = _dump, _dumps, _load, _loads # Doctest def _test(): diff -r a75b88048339 -r 8434af450da0 Lib/pickletools.py --- a/Lib/pickletools.py Thu Nov 14 16:16:29 2013 -0800 +++ b/Lib/pickletools.py Fri Nov 15 03:07:56 2013 -0800 @@ -11,6 +11,7 @@ ''' import codecs +import io import pickle import re import sys @@ -168,6 +169,7 @@ TAKEN_FROM_ARGUMENT1 = -2 # num bytes is 1-byte unsigned int TAKEN_FROM_ARGUMENT4 = -3 # num bytes is 4-byte signed little-endian int TAKEN_FROM_ARGUMENT4U = -4 # num bytes is 4-byte unsigned little-endian int +TAKEN_FROM_ARGUMENT8U = -5 # num bytes is 8-byte unsigned little-endian int class ArgumentDescriptor(object): __slots__ = ( @@ -175,7 +177,7 @@ 'name', # length of argument, in bytes; an int; UP_TO_NEWLINE and - # TAKEN_FROM_ARGUMENT{1,4} are negative values for variable-length + # TAKEN_FROM_ARGUMENT{1,4,8} are negative values for variable-length # cases 'n', @@ -196,7 +198,8 @@ n in (UP_TO_NEWLINE, TAKEN_FROM_ARGUMENT1, TAKEN_FROM_ARGUMENT4, - TAKEN_FROM_ARGUMENT4U)) + TAKEN_FROM_ARGUMENT4U, + TAKEN_FROM_ARGUMENT8U)) self.n = n self.reader = reader @@ -288,6 +291,27 @@ doc="Four-byte unsigned integer, little-endian.") +def read_uint8(f): + r""" + >>> import io + >>> read_uint8(io.BytesIO(b'\xff\x00\x00\x00\x00\x00\x00\x00')) + 255 + >>> read_uint8(io.BytesIO(b'\xff' * 8)) == 2**64-1 + True + """ + + data = f.read(8) + if len(data) == 8: + return _unpack(">> import io @@ -381,6 +405,36 @@ a single blank separating the two strings. """) + +def read_string1(f): + r""" + >>> import io + >>> read_string1(io.BytesIO(b"\x00")) + '' + >>> read_string1(io.BytesIO(b"\x03abcdef")) + 'abc' + """ + + n = read_uint1(f) + assert n >= 0 + data = f.read(n) + if len(data) == n: + return data.decode("latin-1") + raise ValueError("expected %d bytes in a string1, but only %d remain" % + (n, len(data))) + +string1 = ArgumentDescriptor( + name="string1", + n=TAKEN_FROM_ARGUMENT1, + reader=read_string1, + doc="""A counted string. + + The first argument is a 1-byte unsigned int giving the number + of bytes in the string, and the second argument is that many + bytes. + """) + + def read_string4(f): r""" >>> import io @@ -415,28 +469,28 @@ """) -def read_string1(f): +def read_bytes1(f): r""" >>> import io - >>> read_string1(io.BytesIO(b"\x00")) - '' - >>> read_string1(io.BytesIO(b"\x03abcdef")) - 'abc' + >>> read_bytes1(io.BytesIO(b"\x00")) + b'' + >>> read_bytes1(io.BytesIO(b"\x03abcdef")) + b'abc' """ n = read_uint1(f) assert n >= 0 data = f.read(n) if len(data) == n: - return data.decode("latin-1") - raise ValueError("expected %d bytes in a string1, but only %d remain" % + return data + raise ValueError("expected %d bytes in a bytes1, but only %d remain" % (n, len(data))) -string1 = ArgumentDescriptor( - name="string1", +bytes1 = ArgumentDescriptor( + name="bytes1", n=TAKEN_FROM_ARGUMENT1, - reader=read_string1, - doc="""A counted string. + reader=read_bytes1, + doc="""A counted bytes string. The first argument is a 1-byte unsigned int giving the number of bytes in the string, and the second argument is that many @@ -486,6 +540,7 @@ """ n = read_uint4(f) + assert n >= 0 if n > sys.maxsize: raise ValueError("bytes4 byte count > sys.maxsize: %d" % n) data = f.read(n) @@ -505,6 +560,39 @@ """) +def read_bytes8(f): + r""" + >>> import io + >>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x00\x00abc")) + b'' + >>> read_bytes8(io.BytesIO(b"\x03\x00\x00\x00\x00\x00\x00\x00abcdef")) + b'abc' + >>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x03\x00abcdef")) + Traceback (most recent call last): + ... + ValueError: expected 844424930131968 bytes in a bytes8, but only 6 remain + """ + + n = read_uint8(f) + assert n >= 0 + if n > sys.maxsize: + raise ValueError("bytes8 byte count > sys.maxsize: %d" % n) + data = f.read(n) + if len(data) == n: + return data + raise ValueError("expected %d bytes in a bytes8, but only %d remain" % + (n, len(data))) + +bytes8 = ArgumentDescriptor( + name="bytes8", + n=TAKEN_FROM_ARGUMENT8U, + reader=read_bytes8, + doc="""A counted bytes string. + + The first argument is a 8-byte little-endian unsigned int giving + the number of bytes, and the second argument is that many bytes. + """) + def read_unicodestringnl(f): r""" >>> import io @@ -530,6 +618,46 @@ escape sequences. """) + +def read_unicodestring1(f): + r""" + >>> import io + >>> s = 'abcd\uabcd' + >>> enc = s.encode('utf-8') + >>> enc + b'abcd\xea\xaf\x8d' + >>> n = bytes([len(enc)]) # little-endian 1-byte length + >>> t = read_unicodestring1(io.BytesIO(n + enc + b'junk')) + >>> s == t + True + + >>> read_unicodestring1(io.BytesIO(n + enc[:-1])) + Traceback (most recent call last): + ... + ValueError: expected 7 bytes in a unicodestring1, but only 6 remain + """ + + n = read_uint1(f) + assert n >= 0 + data = f.read(n) + if len(data) == n: + return str(data, 'utf-8', 'surrogatepass') + raise ValueError("expected %d bytes in a unicodestring1, but only %d " + "remain" % (n, len(data))) + +unicodestring1 = ArgumentDescriptor( + name="unicodestring1", + n=TAKEN_FROM_ARGUMENT1, + reader=read_unicodestring1, + doc="""A counted Unicode string. + + The first argument is a 1-byte little-endian signed int + giving the number of bytes in the string, and the second + argument-- the UTF-8 encoding of the Unicode string -- + contains that many bytes. + """) + + def read_unicodestring4(f): r""" >>> import io @@ -549,6 +677,7 @@ """ n = read_uint4(f) + assert n >= 0 if n > sys.maxsize: raise ValueError("unicodestring4 byte count > sys.maxsize: %d" % n) data = f.read(n) @@ -570,6 +699,47 @@ """) +def read_unicodestring8(f): + r""" + >>> import io + >>> s = 'abcd\uabcd' + >>> enc = s.encode('utf-8') + >>> enc + b'abcd\xea\xaf\x8d' + >>> n = bytes([len(enc)]) + bytes(7) # little-endian 8-byte length + >>> t = read_unicodestring8(io.BytesIO(n + enc + b'junk')) + >>> s == t + True + + >>> read_unicodestring8(io.BytesIO(n + enc[:-1])) + Traceback (most recent call last): + ... + ValueError: expected 7 bytes in a unicodestring8, but only 6 remain + """ + + n = read_uint8(f) + assert n >= 0 + if n > sys.maxsize: + raise ValueError("unicodestring8 byte count > sys.maxsize: %d" % n) + data = f.read(n) + if len(data) == n: + return str(data, 'utf-8', 'surrogatepass') + raise ValueError("expected %d bytes in a unicodestring8, but only %d " + "remain" % (n, len(data))) + +unicodestring8 = ArgumentDescriptor( + name="unicodestring8", + n=TAKEN_FROM_ARGUMENT8U, + reader=read_unicodestring8, + doc="""A counted Unicode string. + + The first argument is a 8-byte little-endian signed int + giving the number of bytes in the string, and the second + argument-- the UTF-8 encoding of the Unicode string -- + contains that many bytes. + """) + + def read_decimalnl_short(f): r""" >>> import io @@ -859,6 +1029,16 @@ obtype=dict, doc="A Python dict object.") +pyset = StackObject( + name="set", + obtype=set, + doc="A Python set object.") + +pyfrozenset = StackObject( + name="frozenset", + obtype=set, + doc="A Python frozenset object.") + anyobject = StackObject( name='any', obtype=object, @@ -1142,6 +1322,19 @@ literally as the string content. """), + I(name='BINBYTES8', + code='\x8e', + arg=bytes8, + stack_before=[], + stack_after=[pybytes], + proto=4, + doc="""Push a Python bytes object. + + There are two arguments: the first is a 8-byte unsigned int giving + the number of bytes in the string, and the second is that many bytes, + which are taken literally as the string content. + """), + # Ways to spell None. I(name='NONE', @@ -1190,6 +1383,19 @@ until the next newline character. """), + I(name='SHORT_BINUNICODE', + code='\x8c', + arg=unicodestring1, + stack_before=[], + stack_after=[pyunicode], + proto=4, + doc="""Push a Python Unicode string object. + + There are two arguments: the first is a 1-byte little-endian signed int + giving the number of bytes in the string. The second is that many + bytes, and is the UTF-8 encoding of the Unicode string. + """), + I(name='BINUNICODE', code='X', arg=unicodestring4, @@ -1203,6 +1409,19 @@ bytes, and is the UTF-8 encoding of the Unicode string. """), + I(name='BINUNICODE8', + code='\x8d', + arg=unicodestring8, + stack_before=[], + stack_after=[pyunicode], + proto=4, + doc="""Push a Python Unicode string object. + + There are two arguments: the first is a 8-byte little-endian signed int + giving the number of bytes in the string. The second is that many + bytes, and is the UTF-8 encoding of the Unicode string. + """), + # Ways to spell floats. I(name='FLOAT', @@ -1428,6 +1647,62 @@ 1, 2, ..., n, and in that order. """), + # Ways to build sets + + I(name='EMPTY_SET', + code='\x8f', + arg=None, + stack_before=[], + stack_after=[pyset], + proto=4, + doc="Push an empty set."), + + I(name='ADDITEMS', + code='\x90', + arg=None, + stack_before=[pyset, markobject, stackslice], + stack_after=[pyset], + proto=4, + doc="""Add an arbitrary number of items to an existing set. + + The slice of the stack following the topmost markobject is taken as + a sequence of items, added to the set immediately under the topmost + markobject. Everything at and after the topmost markobject is popped, + leaving the mutated set at the top of the stack. + + Stack before: ... pyset markobject item_1 ... item_n + Stack after: ... pyset + + where pyset has been modified via pyset.add(item_i) = item_i for i in + 1, 2, ..., n, and in that order. + """), + + # Ways to build frozensets + + I(name='EMPTY_FROZENSET', + code='\x91', + arg=None, + stack_before=[], + stack_after=[pyfrozenset], + proto=4, + doc="Push an empty frozenset."), + + I(name='FROZENSET', + code='\x92', + arg=None, + stack_before=[markobject, stackslice], + stack_after=[pyfrozenset], + proto=4, + doc="""Build a frozenset out of the topmost slice, after markobject. + + All the stack entries following the topmost markobject are placed into + a single Python frozenset, which single frozenset object replaces all + of the stack from the topmost markobject onward. For example, + + Stack before: ... markobject 1 2 3 + Stack after: ... frozenset({1, 2, 3}) + """), + # Stack manipulation. I(name='POP', @@ -1614,6 +1889,15 @@ stack, so unpickling subclasses can override this form of lookup. """), + I(name='STACK_GLOBAL', + code='\x94', + arg=None, + stack_before=[pyunicode, pyunicode], + stack_after=[anyobject], + proto=0, + doc="""Push a global object (module.attr) on the stack. + """), + # Ways to build objects of classes pickle doesn't know about directly # (user-defined classes). I despair of documenting this accurately # and comprehensibly -- you really have to read the pickle code to @@ -1770,6 +2054,21 @@ onto the stack. """), + I(name='NEWOBJ_EX', + code='\x93', + arg=None, + stack_before=[anyobject, anyobject, anyobject], + stack_after=[anyobject], + proto=4, + doc="""Build an object instance. + + The stack before should be thought of as containing a class + object followed by an argument tuple and by a keyword argument dict + (the dict being the stack top). Call these cls and args. They are + popped off the stack, and the value returned by + cls.__new__(cls, *args, *kwargs) is pushed back onto the stack. + """), + # Machine control. I(name='PROTO', @@ -1903,42 +2202,19 @@ ############################################################################## # A pickle opcode generator. -def genops(pickle): - """Generate all the opcodes in a pickle. - - 'pickle' is a file-like object, or string, containing the pickle. - - Each opcode in the pickle is generated, from the current pickle position, - stopping after a STOP opcode is delivered. A triple is generated for - each opcode: - - opcode, arg, pos - - opcode is an OpcodeInfo record, describing the current opcode. - - If the opcode has an argument embedded in the pickle, arg is its decoded - value, as a Python object. If the opcode doesn't have an argument, arg - is None. - - If the pickle has a tell() method, pos was the value of pickle.tell() - before reading the current opcode. If the pickle is a bytes object, - it's wrapped in a BytesIO object, and the latter's tell() result is - used. Else (the pickle doesn't have a tell(), and it's not obvious how - to query its current position) pos is None. - """ - - if isinstance(pickle, bytes_types): - import io - pickle = io.BytesIO(pickle) - - if hasattr(pickle, "tell"): - getpos = pickle.tell - else: - getpos = lambda: None +def _genops(data, yield_end_pos=False): + if isinstance(data, bytes_types): + data = io.BytesIO(data) + + unframer = pickle._Unframer(data.read, data.readline, + getattr(data, "tell", None)) + getpos = unframer.tell while True: - pos = getpos() - code = pickle.read(1) + code = unframer.read(1) + # So that the opcode's actual pos is announced, not the frame start + arg_pos = getpos() + pos = arg_pos - 1 if arg_pos is not None else None opcode = code2op.get(code.decode("latin-1")) if opcode is None: if code == b"": @@ -1950,38 +2226,81 @@ if opcode.arg is None: arg = None else: - arg = opcode.arg.reader(pickle) - yield opcode, arg, pos + arg = opcode.arg.reader(unframer) + if yield_end_pos: + yield opcode, arg, pos, getpos() + else: + yield opcode, arg, pos if code == b'.': assert opcode.name == 'STOP' break + elif code == b'\x80': + assert opcode.name == 'PROTO' + if arg >= 4: + unframer.framing_enabled = True + +def genops(pickle): + """Generate all the opcodes in a pickle. + + 'pickle' is a file-like object, or string, containing the pickle. + + Each opcode in the pickle is generated, from the current pickle position, + stopping after a STOP opcode is delivered. A triple is generated for + each opcode: + + opcode, arg, pos + + opcode is an OpcodeInfo record, describing the current opcode. + + If the opcode has an argument embedded in the pickle, arg is its decoded + value, as a Python object. If the opcode doesn't have an argument, arg + is None. + + If the pickle has a tell() method, pos was the value of pickle.tell() + before reading the current opcode. If the pickle is a bytes object, + it's wrapped in a BytesIO object, and the latter's tell() result is + used. Else (the pickle doesn't have a tell(), and it's not obvious how + to query its current position) pos is None. + """ + return _genops(pickle) ############################################################################## # A pickle optimizer. def optimize(p): 'Optimize a pickle string by removing unused PUT opcodes' - gets = set() # set of args used by a GET opcode - puts = [] # (arg, startpos, stoppos) for the PUT opcodes - prevpos = None # set to pos if previous opcode was a PUT - for opcode, arg, pos in genops(p): - if prevpos is not None: - puts.append((prevarg, prevpos, pos)) - prevpos = None + not_a_put = object() + gets = { not_a_put } # set of args used by a GET opcode + opcodes = [] # (startpos, stoppos, putid) + proto = 0 + for opcode, arg, pos, end_pos in _genops(p, yield_end_pos=True): if 'PUT' in opcode.name: - prevarg, prevpos = arg, pos - elif 'GET' in opcode.name: - gets.add(arg) - - # Copy the pickle string except for PUTS without a corresponding GET - s = [] - i = 0 - for arg, start, stop in puts: - j = stop if (arg in gets) else start - s.append(p[i:j]) - i = stop - s.append(p[i:]) - return b''.join(s) + opcodes.append((pos, end_pos, arg)) + else: + if 'GET' in opcode.name: + gets.add(arg) + elif opcode.name == 'PROTO': + assert pos == 0 + proto = arg + opcodes.append((pos, end_pos, not_a_put)) + prevpos, prevarg = pos, None + + # Copy the opcodes except for PUTS without a corresponding GET + out = io.BytesIO() + opcodes = iter(opcodes) + if proto >= 2: + # Write the PROTO header before any framing + start, stop, _ = next(opcodes) + out.write(p[start:stop]) + buf = pickle._Framer(out.write) + if proto >= 4: + buf.start_framing() + for start, stop, putid in opcodes: + if putid in gets: + buf.write(p[start:stop]) + if proto >= 4: + buf.end_framing() + return out.getvalue() ############################################################################## # A symbolic pickle disassembler. diff -r a75b88048339 -r 8434af450da0 Lib/test/pickletester.py --- a/Lib/test/pickletester.py Thu Nov 14 16:16:29 2013 -0800 +++ b/Lib/test/pickletester.py Fri Nov 15 03:07:56 2013 -0800 @@ -1,9 +1,10 @@ +import copyreg import io -import unittest import pickle import pickletools +import random import sys -import copyreg +import unittest import weakref from http.cookies import SimpleCookie @@ -95,6 +96,9 @@ def __getinitargs__(self): return () +class H(object): + pass + import __main__ __main__.C = C C.__module__ = "__main__" @@ -102,6 +106,8 @@ D.__module__ = "__main__" __main__.E = E E.__module__ = "__main__" +__main__.H = H +H.__module__ = "__main__" class myint(int): def __init__(self, x): @@ -574,6 +580,26 @@ self.assertEqual(list(x.keys()), [1]) self.assertTrue(x[1] is x) + def test_recursive_set(self): + h = H() + y = set({h}) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIs(list(x)[0].attr, x) + self.assertEqual(len(x), 1) + + def test_recursive_frozenset(self): + h = H() + y = frozenset({h}) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIs(list(x)[0].attr, x) + self.assertEqual(len(x), 1) + def test_recursive_inst(self): i = C() i.attr = i @@ -751,7 +777,12 @@ for proto in protocols: expected = build_none if proto >= 2: - expected = pickle.PROTO + bytes([proto]) + expected + build_proto = pickle.PROTO + bytes([proto]) + if proto >= 4: + # A 2-bytes frame + expected = build_proto + bytes([2] + [0] * 7) + expected + else: + expected = build_proto + expected p = self.dumps(None, proto) self.assertEqual(p, expected) @@ -817,7 +848,7 @@ s = self.dumps(x, proto) y = self.loads(s) self.assertEqual(x, y, (proto, x, s, y)) - expected = expected_opcode[proto, len(x)] + expected = expected_opcode[min(proto, 3), len(x)] self.assertEqual(opcode_in_pickle(expected, s), True) def test_singletons(self): @@ -842,7 +873,7 @@ s = self.dumps(x, proto) y = self.loads(s) self.assertTrue(x is y, (proto, x, s, y)) - expected = expected_opcode[proto, x] + expected = expected_opcode[min(proto, 3), x] self.assertEqual(opcode_in_pickle(expected, s), True) def test_newobj_tuple(self): @@ -990,12 +1021,40 @@ else: self.assertTrue(num_setitems >= 2) + def test_set_chunking(self): + n = 10 # too small to chunk + x = set(range(n)) + for proto in protocols: + s = self.dumps(x, proto) + y = self.loads(s) + self.assertEqual(x, y) + num_additems = count_opcode(pickle.ADDITEMS, s) + if proto < 4: + self.assertEqual(num_additems, 0) + else: + self.assertEqual(num_additems, 1) + + n = 2500 # expect at least two chunks when proto >= 4 + x = set(range(n)) + for proto in protocols: + s = self.dumps(x, proto) + y = self.loads(s) + self.assertEqual(x, y) + num_additems = count_opcode(pickle.ADDITEMS, s) + if proto < 4: + self.assertEqual(num_additems, 0) + else: + self.assertGreaterEqual(num_additems, 2) + def test_simple_newobj(self): x = object.__new__(SimpleNewObj) # avoid __init__ x.abc = 666 for proto in protocols: s = self.dumps(x, proto) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto < 4) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), + proto >= 4) y = self.loads(s) # will raise TypeError if __init__ called self.assertEqual(y.abc, 666) self.assertEqual(x.__dict__, y.__dict__) @@ -1058,11 +1117,10 @@ @no_tracing def test_bad_getattr(self): + # Issue #3514: crash when there is an infinite loop in __getattr__ x = BadGetattr() - for proto in 0, 1: + for proto in protocols: self.assertRaises(RuntimeError, self.dumps, x, proto) - # protocol 2 don't raise a RuntimeError. - d = self.dumps(x, 2) def test_reduce_bad_iterator(self): # Issue4176: crash when 4th and 5th items of __reduce__() @@ -1158,7 +1216,9 @@ sizes = [len(self.dumps(2**n, proto)) for n in range(70)] # the size function is monotonic self.assertEqual(sorted(sizes), sizes) - if proto >= 2: + if proto >= 4: + self.assertLessEqual(sizes[-1], 22) + elif proto >= 2: self.assertLessEqual(sizes[-1], 14) def check_negative_32b_binXXX(self, dumped): @@ -1242,6 +1302,134 @@ else: self._check_pickling_with_opcode(obj, pickle.SETITEMS, proto) + # Exercise framing (proto >= 4) for significant workloads + + FRAME_SIZE_TARGET = 64 * 1024 + + def test_framing_many_objects(self): + obj = list(range(10**5)) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + pickled = self.dumps(obj, proto) + unpickled = self.loads(pickled) + self.assertEqual(obj, unpickled) + # Test the framing heuristic is sane, + # assuming a given frame size target. + bytes_per_frame = len(pickled) / pickled.count(b'\x00\x00\x00\x00\x00') + self.assertGreater(bytes_per_frame, self.FRAME_SIZE_TARGET / 2) + self.assertLessEqual(bytes_per_frame, self.FRAME_SIZE_TARGET * 1) + + def test_framing_large_objects(self): + N = 1024 * 1024 + obj = [b'x' * N, b'y' * N, b'z' * N] + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + pickled = self.dumps(obj, proto) + unpickled = self.loads(pickled) + self.assertEqual(obj, unpickled) + # At least one frame was emitted per large bytes object. + n_frames = pickled.count(b'\x00\x00\x00\x00\x00') + self.assertGreaterEqual(n_frames, len(obj)) + + def test_nested_names(self): + global Nested + class Nested: + class A: + class B: + class C: + pass + + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for obj in [Nested.A, Nested.A.B, Nested.A.B.C]: + with self.subTest(proto=proto, obj=obj): + unpickled = self.loads(self.dumps(obj, proto)) + self.assertIs(obj, unpickled) + + def test_py_methods(self): + global PyMethodsTest + class PyMethodsTest: + @staticmethod + def cheese(): + return "cheese" + @classmethod + def wine(cls): + assert cls is PyMethodsTest + return "wine" + def biscuits(self): + assert isinstance(self, PyMethodsTest) + return "biscuits" + class Nested: + "Nested class" + @staticmethod + def ketchup(): + return "ketchup" + @classmethod + def maple(cls): + assert cls is PyMethodsTest.Nested + return "maple" + def pie(self): + assert isinstance(self, PyMethodsTest.Nested) + return "pie" + + py_methods = ( + PyMethodsTest.cheese, + PyMethodsTest.wine, + PyMethodsTest().biscuits, + PyMethodsTest.Nested.ketchup, + PyMethodsTest.Nested.maple, + PyMethodsTest.Nested().pie + ) + py_unbound_methods = ( + (PyMethodsTest.biscuits, PyMethodsTest), + (PyMethodsTest.Nested.pie, PyMethodsTest.Nested) + ) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for method in py_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(), unpickled()) + for method, cls in py_unbound_methods: + obj = cls() + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(obj), unpickled(obj)) + + + def test_c_methods(self): + global Subclass + class Subclass(tuple): + class Nested(str): + pass + + c_methods = ( + # bound built-in method + ("abcd".index, ("c",)), + # unbound built-in method + (str.index, ("abcd", "c")), + # bound "slot" method + ([1, 2, 3].__len__, ()), + # unbound "slot" method + (list.__len__, ([1, 2, 3],)), + # bound "coexist" method + ({1, 2}.__contains__, (2,)), + # unbound "coexist" method + (set.__contains__, ({1, 2}, 2)), + # built-in class method + (dict.fromkeys, (("a", 1), ("b", 2))), + # built-in static method + (bytearray.maketrans, (b"abc", b"xyz")), + # subclass methods + (Subclass([1,2,2]).count, (2,)), + (Subclass.count, (Subclass([1,2,2]), 2)), + (Subclass.Nested("sweet").count, ("e",)), + (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")), + ) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for method, args in c_methods: + with self.subTest(proto=proto, method=method): + unpickled = self.loads(self.dumps(method, proto)) + self.assertEqual(method(*args), unpickled(*args)) + class BigmemPickleTests(unittest.TestCase): @@ -1252,10 +1440,11 @@ data = 1 << (8 * size) try: for proto in protocols: - if proto < 2: - continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto < 2: + continue + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) finally: data = None @@ -1268,14 +1457,15 @@ data = b"abcd" * (size // 4) try: for proto in protocols: - if proto < 3: - continue - try: - pickled = self.dumps(data, protocol=proto) - self.assertTrue(b"abcd" in pickled[:15]) - self.assertTrue(b"abcd" in pickled[-15:]) - finally: - pickled = None + with self.subTest(proto=proto): + if proto < 3: + continue + try: + pickled = self.dumps(data, protocol=proto) + self.assertTrue(b"abcd" in pickled[:19]) + self.assertTrue(b"abcd" in pickled[-18:]) + finally: + pickled = None finally: data = None @@ -1284,10 +1474,11 @@ data = b"a" * size try: for proto in protocols: - if proto < 3: - continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto < 3: + continue + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) finally: data = None @@ -1299,27 +1490,38 @@ data = "abcd" * (size // 4) try: for proto in protocols: - try: - pickled = self.dumps(data, protocol=proto) - self.assertTrue(b"abcd" in pickled[:15]) - self.assertTrue(b"abcd" in pickled[-15:]) - finally: - pickled = None + with self.subTest(proto=proto): + try: + pickled = self.dumps(data, protocol=proto) + self.assertTrue(b"abcd" in pickled[:19]) + self.assertTrue(b"abcd" in pickled[-18:]) + finally: + pickled = None finally: data = None - # BINUNICODE (protocols 1, 2 and 3) cannot carry more than - # 2**32 - 1 bytes of utf-8 encoded unicode. + # BINUNICODE (protocols 1, 2 and 3) cannot carry more than 2**32 - 1 bytes + # of utf-8 encoded unicode. BINUNICODE8 (protocol 4) supports these huge + # unicode strings however. - @bigmemtest(size=_4G, memuse=1 + ascii_char_size, dry_run=False) + @bigmemtest(size=_4G, memuse=2 + ascii_char_size, dry_run=False) def test_huge_str_64b(self, size): - data = "a" * size + data = "abcd" * (size // 4) try: for proto in protocols: - if proto == 0: - continue - with self.assertRaises((ValueError, OverflowError)): - self.dumps(data, protocol=proto) + with self.subTest(proto=proto): + if proto == 0: + continue + if proto < 4: + with self.assertRaises((ValueError, OverflowError)): + self.dumps(data, protocol=proto) + else: + try: + pickled = self.dumps(data, protocol=proto) + self.assertTrue(b"abcd" in pickled[:19]) + self.assertTrue(b"abcd" in pickled[-18:]) + finally: + pickled = None finally: data = None @@ -1415,10 +1617,16 @@ class MyDict(dict): sample = {"a": 1, "b": 2} +class MySet(set): + sample = {"a", "b"} + +class MyFrozenSet(frozenset): + sample = frozenset({"a", "b"}) + myclasses = [MyInt, MyFloat, MyComplex, MyStr, MyUnicode, - MyTuple, MyList, MyDict] + MyTuple, MyList, MyDict, MySet, MyFrozenSet] class SlotList(MyList): @@ -1429,6 +1637,7 @@ # raise an error, to make sure this isn't called raise TypeError("SimpleNewObj.__init__() didn't expect to get called") + class BadGetattr: def __getattr__(self, key): self.foo @@ -1464,7 +1673,7 @@ def test_highest_protocol(self): # Of course this needs to be changed when HIGHEST_PROTOCOL changes. - self.assertEqual(pickle.HIGHEST_PROTOCOL, 3) + self.assertEqual(pickle.HIGHEST_PROTOCOL, 4) def test_callapi(self): f = io.BytesIO() @@ -1645,22 +1854,23 @@ def _check_multiple_unpicklings(self, ioclass): for proto in protocols: - data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] - f = ioclass() - pickler = self.pickler_class(f, protocol=proto) - pickler.dump(data1) - pickled = f.getvalue() + with self.subTest(proto=proto): + data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] + f = ioclass() + pickler = self.pickler_class(f, protocol=proto) + pickler.dump(data1) + pickled = f.getvalue() - N = 5 - f = ioclass(pickled * N) - unpickler = self.unpickler_class(f) - for i in range(N): - if f.seekable(): - pos = f.tell() - self.assertEqual(unpickler.load(), data1) - if f.seekable(): - self.assertEqual(f.tell(), pos + len(pickled)) - self.assertRaises(EOFError, unpickler.load) + N = 5 + f = ioclass(pickled * N) + unpickler = self.unpickler_class(f) + for i in range(N): + if f.seekable(): + pos = f.tell() + self.assertEqual(unpickler.load(), data1) + if f.seekable(): + self.assertEqual(f.tell(), pos + len(pickled)) + self.assertRaises(EOFError, unpickler.load) def test_multiple_unpicklings_seekable(self): self._check_multiple_unpicklings(io.BytesIO) diff -r a75b88048339 -r 8434af450da0 Lib/test/test_descr.py --- a/Lib/test/test_descr.py Thu Nov 14 16:16:29 2013 -0800 +++ b/Lib/test/test_descr.py Fri Nov 15 03:07:56 2013 -0800 @@ -1,8 +1,11 @@ import builtins +import copyreg import gc +import itertools +import math +import pickle import sys import types -import math import unittest import weakref @@ -3142,176 +3145,6 @@ self.assertEqual(e.a, 1) self.assertEqual(can_delete_dict(e), can_delete_dict(ValueError())) - def test_pickles(self): - # Testing pickling and copying new-style classes and objects... - import pickle - - def sorteditems(d): - L = list(d.items()) - L.sort() - return L - - global C - class C(object): - def __init__(self, a, b): - super(C, self).__init__() - self.a = a - self.b = b - def __repr__(self): - return "C(%r, %r)" % (self.a, self.b) - - global C1 - class C1(list): - def __new__(cls, a, b): - return super(C1, cls).__new__(cls) - def __getnewargs__(self): - return (self.a, self.b) - def __init__(self, a, b): - self.a = a - self.b = b - def __repr__(self): - return "C1(%r, %r)<%r>" % (self.a, self.b, list(self)) - - global C2 - class C2(int): - def __new__(cls, a, b, val=0): - return super(C2, cls).__new__(cls, val) - def __getnewargs__(self): - return (self.a, self.b, int(self)) - def __init__(self, a, b, val=0): - self.a = a - self.b = b - def __repr__(self): - return "C2(%r, %r)<%r>" % (self.a, self.b, int(self)) - - global C3 - class C3(object): - def __init__(self, foo): - self.foo = foo - def __getstate__(self): - return self.foo - def __setstate__(self, foo): - self.foo = foo - - global C4classic, C4 - class C4classic: # classic - pass - class C4(C4classic, object): # mixed inheritance - pass - - for bin in 0, 1: - for cls in C, C1, C2: - s = pickle.dumps(cls, bin) - cls2 = pickle.loads(s) - self.assertTrue(cls2 is cls) - - a = C1(1, 2); a.append(42); a.append(24) - b = C2("hello", "world", 42) - s = pickle.dumps((a, b), bin) - x, y = pickle.loads(s) - self.assertEqual(x.__class__, a.__class__) - self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__)) - self.assertEqual(y.__class__, b.__class__) - self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__)) - self.assertEqual(repr(x), repr(a)) - self.assertEqual(repr(y), repr(b)) - # Test for __getstate__ and __setstate__ on new style class - u = C3(42) - s = pickle.dumps(u, bin) - v = pickle.loads(s) - self.assertEqual(u.__class__, v.__class__) - self.assertEqual(u.foo, v.foo) - # Test for picklability of hybrid class - u = C4() - u.foo = 42 - s = pickle.dumps(u, bin) - v = pickle.loads(s) - self.assertEqual(u.__class__, v.__class__) - self.assertEqual(u.foo, v.foo) - - # Testing copy.deepcopy() - import copy - for cls in C, C1, C2: - cls2 = copy.deepcopy(cls) - self.assertTrue(cls2 is cls) - - a = C1(1, 2); a.append(42); a.append(24) - b = C2("hello", "world", 42) - x, y = copy.deepcopy((a, b)) - self.assertEqual(x.__class__, a.__class__) - self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__)) - self.assertEqual(y.__class__, b.__class__) - self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__)) - self.assertEqual(repr(x), repr(a)) - self.assertEqual(repr(y), repr(b)) - - def test_pickle_slots(self): - # Testing pickling of classes with __slots__ ... - import pickle - # Pickling of classes with __slots__ but without __getstate__ should fail - # (if using protocol 0 or 1) - global B, C, D, E - class B(object): - pass - for base in [object, B]: - class C(base): - __slots__ = ['a'] - class D(C): - pass - try: - pickle.dumps(C(), 0) - except TypeError: - pass - else: - self.fail("should fail: pickle C instance - %s" % base) - try: - pickle.dumps(C(), 0) - except TypeError: - pass - else: - self.fail("should fail: pickle D instance - %s" % base) - # Give C a nice generic __getstate__ and __setstate__ - class C(base): - __slots__ = ['a'] - def __getstate__(self): - try: - d = self.__dict__.copy() - except AttributeError: - d = {} - for cls in self.__class__.__mro__: - for sn in cls.__dict__.get('__slots__', ()): - try: - d[sn] = getattr(self, sn) - except AttributeError: - pass - return d - def __setstate__(self, d): - for k, v in list(d.items()): - setattr(self, k, v) - class D(C): - pass - # Now it should work - x = C() - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(hasattr(y, 'a'), 0) - x.a = 42 - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.a, 42) - x = D() - x.a = 42 - x.b = 100 - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.a + y.b, 142) - # A subclass that adds a slot should also work - class E(C): - __slots__ = ['b'] - x = E() - x.a = 42 - x.b = "foo" - y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.a, x.a) - self.assertEqual(y.b, x.b) - def test_binary_operator_override(self): # Testing overrides of binary operations... class I(int): @@ -4679,11 +4512,439 @@ self.assertEqual(X.mykey2, 'from Base2') +class PicklingTests(unittest.TestCase): + + def _check_reduce(self, proto, obj, args=(), kwargs={}, state=None, + listitems=None, dictitems=None): + if proto >= 4: + reduce_value = obj.__reduce_ex__(proto) + self.assertEqual(reduce_value[:3], + (copyreg.__newobj_ex__, + (type(obj), args, kwargs), + state)) + if listitems is not None: + self.assertListEqual(list(reduce_value[3]), listitems) + else: + self.assertIsNone(reduce_value[3]) + if dictitems is not None: + self.assertDictEqual(dict(reduce_value[4]), dictitems) + else: + self.assertIsNone(reduce_value[4]) + elif proto >= 2: + reduce_value = obj.__reduce_ex__(proto) + self.assertEqual(reduce_value[:3], + (copyreg.__newobj__, + (type(obj),) + args, + state)) + if listitems is not None: + self.assertListEqual(list(reduce_value[3]), listitems) + else: + self.assertIsNone(reduce_value[3]) + if dictitems is not None: + self.assertDictEqual(dict(reduce_value[4]), dictitems) + else: + self.assertIsNone(reduce_value[4]) + else: + base_type = type(obj).__base__ + reduce_value = (copyreg._reconstructor, + (type(obj), + base_type, + None if base_type is object else base_type(obj))) + if state is not None: + reduce_value += (state,) + self.assertEqual(obj.__reduce_ex__(proto), reduce_value) + self.assertEqual(obj.__reduce__(), reduce_value) + + def test_reduce(self): + protocols = range(pickle.HIGHEST_PROTOCOL + 1) + args = (-101, "spam") + kwargs = {'bacon': -201, 'fish': -301} + state = {'cheese': -401} + + class C1: + def __getnewargs__(self): + return args + obj = C1() + for proto in protocols: + self._check_reduce(proto, obj, args) + + for name, value in state.items(): + setattr(obj, name, value) + for proto in protocols: + self._check_reduce(proto, obj, args, state=state) + + class C2: + def __getnewargs__(self): + return "bad args" + obj = C2() + for proto in protocols: + if proto >= 2: + with self.assertRaises(TypeError): + obj.__reduce_ex__(proto) + + class C3: + def __getnewargs_ex__(self): + return (args, kwargs) + obj = C3() + for proto in protocols: + if proto >= 4: + self._check_reduce(proto, obj, args, kwargs) + elif proto >= 2: + with self.assertRaises(ValueError): + obj.__reduce_ex__(proto) + + class C4: + def __getnewargs_ex__(self): + return (args, "bad dict") + class C5: + def __getnewargs_ex__(self): + return ("bad tuple", kwargs) + class C6: + def __getnewargs_ex__(self): + return () + class C7: + def __getnewargs_ex__(self): + return "bad args" + for proto in protocols: + for cls in C4, C5, C6, C7: + obj = cls() + if proto >= 2: + with self.assertRaises((TypeError, ValueError)): + obj.__reduce_ex__(proto) + + class C8: + def __getnewargs_ex__(self): + return (args, kwargs) + obj = C8() + for proto in protocols: + if 2 <= proto < 4: + with self.assertRaises(ValueError): + obj.__reduce_ex__(proto) + class C9: + def __getnewargs_ex__(self): + return (args, {}) + obj = C9() + for proto in protocols: + self._check_reduce(proto, obj, args) + + class C10: + def __getnewargs_ex__(self): + raise IndexError + obj = C10() + for proto in protocols: + if proto >= 2: + with self.assertRaises(IndexError): + obj.__reduce_ex__(proto) + + class C11: + def __getstate__(self): + return state + obj = C11() + for proto in protocols: + self._check_reduce(proto, obj, state=state) + + class C12: + def __getstate__(self): + return "not dict" + obj = C12() + for proto in protocols: + self._check_reduce(proto, obj, state="not dict") + + class C13: + def __getstate__(self): + raise IndexError + obj = C13() + for proto in protocols: + with self.assertRaises(IndexError): + obj.__reduce_ex__(proto) + if proto < 2: + with self.assertRaises(IndexError): + obj.__reduce__() + + class C14: + __slots__ = tuple(state) + def __init__(self): + for name, value in state.items(): + setattr(self, name, value) + + obj = C14() + for proto in protocols: + if proto >= 2: + self._check_reduce(proto, obj, state=(None, state)) + else: + with self.assertRaises(TypeError): + obj.__reduce_ex__(proto) + with self.assertRaises(TypeError): + obj.__reduce__() + + class C15(dict): + pass + obj = C15({"quebec": -601}) + for proto in protocols: + self._check_reduce(proto, obj, dictitems=dict(obj)) + + class C16(list): + pass + obj = C16(["yukon"]) + for proto in protocols: + self._check_reduce(proto, obj, listitems=list(obj)) + + def _assert_is_copy(self, obj, objcopy, msg=None): + """Utility method to verify if two objects are copies of each others. + """ + if msg is None: + msg = "{!r} is not a copy of {!r}".format(obj, objcopy) + if type(obj).__repr__ is object.__repr__: + # We have this limitation for now because we use the object's repr + # to help us verify that the two objects are copies. This allows + # us to delegate the non-generic verification logic to the objects + # themselves. + raise ValueError("object passed to _assert_is_copy must " + + "override the __repr__ method.") + self.assertIsNot(obj, objcopy, msg=msg) + self.assertIs(type(obj), type(objcopy), msg=msg) + if hasattr(obj, '__dict__'): + self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg) + self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg) + if hasattr(obj, '__slots__'): + self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg) + for slot in obj.__slots__: + self.assertEqual( + hasattr(obj, slot), hasattr(objcopy, slot), msg=msg) + self.assertEqual(getattr(obj, slot, None), + getattr(objcopy, slot, None), msg=msg) + self.assertEqual(repr(obj), repr(objcopy), msg=msg) + + @staticmethod + def _generate_pickle_copiers(): + """Utility method to generate the many possible pickle configurations. + """ + class PickleCopier: + "This class copies object using pickle." + def __init__(self, proto, dumps, loads): + self.proto = proto + self.dumps = dumps + self.loads = loads + def copy(self, obj): + return self.loads(self.dumps(obj, self.proto)) + def __repr__(self): + # We try to be as descriptive as possible here since this is + # the string which we will allow us to tell the pickle + # configuration we are using during debugging. + return ("PickleCopier(proto={}, dumps={}.{}, loads={}.{})" + .format(self.proto, + self.dumps.__module__, self.dumps.__qualname__, + self.loads.__module__, self.loads.__qualname__)) + return (PickleCopier(*args) for args in + itertools.product(range(pickle.HIGHEST_PROTOCOL + 1), + {pickle.dumps, pickle._dumps}, + {pickle.loads, pickle._loads})) + + def test_pickle_slots(self): + # Tests pickling of classes with __slots__. + + # Pickling of classes with __slots__ but without __getstate__ should + # fail (if using protocol 0 or 1) + global C + class C: + __slots__ = ['a'] + with self.assertRaises(TypeError): + pickle.dumps(C(), 0) + + global D + class D(C): + pass + with self.assertRaises(TypeError): + pickle.dumps(D(), 0) + + class C: + "A class with __getstate__ and __setstate__ implemented." + __slots__ = ['a'] + def __getstate__(self): + state = getattr(self, '__dict__', {}).copy() + for cls in type(self).__mro__: + for slot in cls.__dict__.get('__slots__', ()): + try: + state[slot] = getattr(self, slot) + except AttributeError: + pass + return state + def __setstate__(self, state): + for k, v in state.items(): + setattr(self, k, v) + def __repr__(self): + return "%s()<%r>" % (type(self).__name__, self.__getstate__()) + + class D(C): + "A subclass of a class with slots." + pass + + global E + class E(C): + "A subclass with an extra slot." + __slots__ = ['b'] + + # Now it should work + for pickle_copier in self._generate_pickle_copiers(): + with self.subTest(pickle_copier=pickle_copier): + x = C() + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + x.a = 42 + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + x = D() + x.a = 42 + x.b = 100 + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + x = E() + x.a = 42 + x.b = "foo" + y = pickle_copier.copy(x) + self._assert_is_copy(x, y) + + def test_reduce_copying(self): + # Tests pickling and copying new-style classes and objects. + global C1 + class C1: + "The state of this class is copyable via its instance dict." + ARGS = (1, 2) + NEED_DICT_COPYING = True + def __init__(self, a, b): + super().__init__() + self.a = a + self.b = b + def __repr__(self): + return "C1(%r, %r)" % (self.a, self.b) + + global C2 + class C2(list): + "A list subclass copyable via __getnewargs__." + ARGS = (1, 2) + NEED_DICT_COPYING = False + def __new__(cls, a, b): + self = super().__new__(cls) + self.a = a + self.b = b + return self + def __init__(self, *args): + super().__init__() + # This helps testing that __init__ is not called during the + # unpickling process, which would cause extra appends. + self.append("cheese") + @classmethod + def __getnewargs__(cls): + return cls.ARGS + def __repr__(self): + return "C2(%r, %r)<%r>" % (self.a, self.b, list(self)) + + global C3 + class C3(list): + "A list subclass copyable via __getstate__." + ARGS = (1, 2) + NEED_DICT_COPYING = False + def __init__(self, a, b): + self.a = a + self.b = b + # This helps testing that __init__ is not called during the + # unpickling process, which would cause extra appends. + self.append("cheese") + @classmethod + def __getstate__(cls): + return cls.ARGS + def __setstate__(self, state): + a, b = state + self.a = a + self.b = b + def __repr__(self): + return "C3(%r, %r)<%r>" % (self.a, self.b, list(self)) + + global C4 + class C4(int): + "An int subclass copyable via __getnewargs__." + ARGS = ("hello", "world", 1) + NEED_DICT_COPYING = False + def __new__(cls, a, b, value): + self = super().__new__(cls, value) + self.a = a + self.b = b + return self + @classmethod + def __getnewargs__(cls): + return cls.ARGS + def __repr__(self): + return "C4(%r, %r)<%r>" % (self.a, self.b, int(self)) + + global C5 + class C5(int): + "An int subclass copyable via __getnewargs_ex__." + ARGS = (1, 2) + KWARGS = {'value': 3} + NEED_DICT_COPYING = False + def __new__(cls, a, b, *, value=0): + self = super().__new__(cls, value) + self.a = a + self.b = b + return self + @classmethod + def __getnewargs_ex__(cls): + return (cls.ARGS, cls.KWARGS) + def __repr__(self): + return "C5(%r, %r)<%r>" % (self.a, self.b, int(self)) + + test_classes = (C1, C2, C3, C4, C5) + # Testing copying through pickle + pickle_copiers = self._generate_pickle_copiers() + for cls, pickle_copier in itertools.product(test_classes, pickle_copiers): + with self.subTest(cls=cls, pickle_copier=pickle_copier): + kwargs = getattr(cls, 'KWARGS', {}) + obj = cls(*cls.ARGS, **kwargs) + proto = pickle_copier.proto + if 2 <= proto < 4 and hasattr(cls, '__getnewargs_ex__'): + with self.assertRaises(ValueError): + pickle_copier.dumps(obj, proto) + continue + objcopy = pickle_copier.copy(obj) + self._assert_is_copy(obj, objcopy) + # For test classes that supports this, make sure we didn't go + # around the reduce protocol by simply copying the attribute + # dictionary. We clear attributes using the previous copy to + # not mutate the original argument. + if proto >= 2 and not cls.NEED_DICT_COPYING: + objcopy.__dict__.clear() + objcopy2 = pickle_copier.copy(objcopy) + self._assert_is_copy(obj, objcopy2) + + # Testing copying through copy.deepcopy() + for cls in test_classes: + with self.subTest(cls=cls): + kwargs = getattr(cls, 'KWARGS', {}) + obj = cls(*cls.ARGS, **kwargs) + # XXX: We need to modify the copy module to support PEP 3154's + # reduce protocol 4. + if hasattr(cls, '__getnewargs_ex__'): + continue + objcopy = deepcopy(obj) + self._assert_is_copy(obj, objcopy) + # For test classes that supports this, make sure we didn't go + # around the reduce protocol by simply copying the attribute + # dictionary. We clear attributes using the previous copy to + # not mutate the original argument. + if not cls.NEED_DICT_COPYING: + objcopy.__dict__.clear() + objcopy2 = deepcopy(objcopy) + self._assert_is_copy(obj, objcopy2) + + def test_main(): # Run all local test cases, with PTypesLongInitTest first. support.run_unittest(PTypesLongInitTest, OperatorsTest, ClassPropertiesAndMethods, DictProxyTests, - MiscTests) + MiscTests, PicklingTests) if __name__ == "__main__": test_main() diff -r a75b88048339 -r 8434af450da0 Modules/_pickle.c --- a/Modules/_pickle.c Thu Nov 14 16:16:29 2013 -0800 +++ b/Modules/_pickle.c Fri Nov 15 03:07:56 2013 -0800 @@ -6,7 +6,7 @@ /* Bump this when new opcodes are added to the pickle protocol. */ enum { - HIGHEST_PROTOCOL = 3, + HIGHEST_PROTOCOL = 4, DEFAULT_PROTOCOL = 3 }; @@ -71,7 +71,18 @@ /* Protocol 3 (Python 3.x) */ BINBYTES = 'B', - SHORT_BINBYTES = 'C' + SHORT_BINBYTES = 'C', + + /* Protocol 4 */ + SHORT_BINUNICODE = '\x8c', + BINUNICODE8 = '\x8d', + BINBYTES8 = '\x8e', + EMPTY_SET = '\x8f', + ADDITEMS = '\x90', + EMPTY_FROZENSET = '\x91', + FROZENSET = '\x92', + NEWOBJ_EX = '\x93', + STACK_GLOBAL = '\x94' }; /* These aren't opcodes -- they're ways to pickle bools before protocol 2 @@ -103,7 +114,9 @@ MAX_WRITE_BUF_SIZE = 64 * 1024, /* Prefetch size when unpickling (disabled on unpeekable streams) */ - PREFETCH = 8192 * 16 + PREFETCH = 8192 * 16, + + FRAME_SIZE_TARGET = 64 * 1024 }; /* Exception classes for pickle. These should override the ones defined in @@ -136,9 +149,6 @@ /* For looking up name pairs in copyreg._extension_registry. */ static PyObject *two_tuple = NULL; -_Py_IDENTIFIER(__name__); -_Py_IDENTIFIER(modules); - static int stack_underflow(void) { @@ -332,7 +342,12 @@ Py_ssize_t max_output_len; /* Allocation size of output_buffer. */ int proto; /* Pickle protocol number, >= 0 */ int bin; /* Boolean, true if proto > 0 */ - Py_ssize_t buf_size; /* Size of the current buffered pickle data */ + int framing; /* True when framing is enabled, proto >= 4 */ + Py_ssize_t frame_start; /* Position in output_buffer where the + where the current frame begins. -1 if there + is no frame currently open. */ + + Py_ssize_t buf_size; /* Size of the current buffered pickle data */ int fast; /* Enable fast mode if set to a true value. The fast mode disable the usage of memo, therefore speeding the pickling process by @@ -362,7 +377,9 @@ char *input_line; Py_ssize_t input_len; Py_ssize_t next_read_idx; + Py_ssize_t frame_end_idx; Py_ssize_t prefetched_idx; /* index of first prefetched byte */ + PyObject *read; /* read() method of the input stream. */ PyObject *readline; /* readline() method of the input stream. */ PyObject *peek; /* peek() method of the input stream, or NULL */ @@ -380,6 +397,7 @@ int proto; /* Protocol of the pickle loaded. */ int fix_imports; /* Indicate whether Unpickler should fix the name of globals pickled by Python 2.x. */ + int framing; /* True when framing is enabled, proto >= 4 */ } UnpicklerObject; /* Forward declarations */ @@ -673,15 +691,62 @@ if (self->output_buffer == NULL) return -1; self->output_len = 0; + self->frame_start = -1; return 0; } +static void +_Pickler_WriteFrameHeader(PicklerObject *self, char *qdata, size_t frame_len) +{ + qdata[0] = (unsigned char)(frame_len & 0xff); + qdata[1] = (unsigned char)((frame_len >> 8) & 0xff); + qdata[2] = (unsigned char)((frame_len >> 16) & 0xff); + qdata[3] = (unsigned char)((frame_len >> 24) & 0xff); + qdata[4] = (unsigned char)((frame_len >> 32) & 0xff); + qdata[5] = (unsigned char)((frame_len >> 40) & 0xff); + qdata[6] = (unsigned char)((frame_len >> 48) & 0xff); + qdata[7] = (unsigned char)((frame_len >> 56) & 0xff); +} + +static int +_Pickler_CommitFrame(PicklerObject *self) +{ + size_t frame_len; + char *qdata; + + if (!self->framing || self->frame_start == -1) + return 0; + frame_len = self->output_len - self->frame_start - 8; + qdata = PyBytes_AS_STRING(self->output_buffer) + self->frame_start; + _Pickler_WriteFrameHeader(self, qdata, frame_len); + self->frame_start = -1; + return 0; +} + +static int +_Pickler_OpcodeBoundary(PicklerObject *self) +{ + Py_ssize_t frame_len; + + if (!self->framing || self->frame_start == -1) + return 0; + frame_len = self->output_len - self->frame_start - 8; + if (frame_len >= FRAME_SIZE_TARGET) + return _Pickler_CommitFrame(self); + else + return 0; +} + static PyObject * _Pickler_GetString(PicklerObject *self) { PyObject *output_buffer = self->output_buffer; assert(self->output_buffer != NULL); + + if (_Pickler_CommitFrame(self)) + return NULL; + self->output_buffer = NULL; /* Resize down to exact size */ if (_PyBytes_Resize(&output_buffer, self->output_len) < 0) @@ -696,6 +761,7 @@ assert(self->write != NULL); + /* This will commit the frame first */ output = _Pickler_GetString(self); if (output == NULL) return -1; @@ -706,57 +772,93 @@ } static Py_ssize_t -_Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t n) -{ - Py_ssize_t i, required; +_Pickler_Write(PicklerObject *self, const char *s, Py_ssize_t data_len) +{ + Py_ssize_t i, n, required; char *buffer; + int need_new_frame; assert(s != NULL); + need_new_frame = (self->framing && self->frame_start == -1); + + if (need_new_frame) + n = data_len + 8; + else + n = data_len; required = self->output_len + n; - if (required > self->max_output_len) { - if (self->write != NULL && required > MAX_WRITE_BUF_SIZE) { - /* XXX This reallocates a new buffer every time, which is a bit - wasteful. */ - if (_Pickler_FlushToFile(self) < 0) - return -1; - if (_Pickler_ClearBuffer(self) < 0) - return -1; - } - if (self->write != NULL && n > MAX_WRITE_BUF_SIZE) { - /* we already flushed above, so the buffer is empty */ - PyObject *result; - /* XXX we could spare an intermediate copy and pass - a memoryview instead */ - PyObject *output = PyBytes_FromStringAndSize(s, n); - if (s == NULL) + if (self->write != NULL && required > MAX_WRITE_BUF_SIZE) { + /* XXX This reallocates a new buffer every time, which is a bit + wasteful. */ + if (_Pickler_FlushToFile(self) < 0) + return -1; + if (_Pickler_ClearBuffer(self) < 0) + return -1; + /* The previous frame was just committed by _Pickler_FlushToFile */ + need_new_frame = self->framing; + if (need_new_frame) + n = data_len + 8; + else + n = data_len; + required = self->output_len + n; + } + if (self->write != NULL && n > MAX_WRITE_BUF_SIZE) { + /* For large pickle chunks, we write directly to the output + file instead of buffering. Note the buffer is empty at this + point (it was flushed above, since required >= n). */ + PyObject *output, *result; + if (need_new_frame) { + char frame_header[8]; + _Pickler_WriteFrameHeader(self, frame_header, (size_t) data_len); + output = PyBytes_FromStringAndSize(frame_header, 8); + if (output == NULL) return -1; result = _Pickler_FastCall(self, self->write, output); Py_XDECREF(result); - return (result == NULL) ? -1 : 0; - } - else { - if (self->output_len >= PY_SSIZE_T_MAX / 2 - n) { - PyErr_NoMemory(); - return -1; - } - self->max_output_len = (self->output_len + n) / 2 * 3; - if (_PyBytes_Resize(&self->output_buffer, self->max_output_len) < 0) + if (result == NULL) return -1; } + /* XXX we could spare an intermediate copy and pass + a memoryview instead */ + output = PyBytes_FromStringAndSize(s, data_len); + if (output == NULL) + return -1; + result = _Pickler_FastCall(self, self->write, output); + Py_XDECREF(result); + return (result == NULL) ? -1 : 0; + } + if (required > self->max_output_len) { + /* Make place in buffer for the pickle chunk */ + if (self->output_len >= PY_SSIZE_T_MAX / 2 - n) { + PyErr_NoMemory(); + return -1; + } + self->max_output_len = (self->output_len + n) / 2 * 3; + if (_PyBytes_Resize(&self->output_buffer, self->max_output_len) < 0) + return -1; } buffer = PyBytes_AS_STRING(self->output_buffer); - if (n < 8) { + if (need_new_frame) { + /* Setup new frame */ + Py_ssize_t frame_start = self->output_len; + self->frame_start = frame_start; + for (i = 0; i < 8; i++) { + /* Write an invalid value, for debugging */ + buffer[frame_start + i] = 0xFE; + } + self->output_len += 8; + } + if (data_len < 8) { /* This is faster than memcpy when the string is short. */ - for (i = 0; i < n; i++) { + for (i = 0; i < data_len; i++) { buffer[self->output_len + i] = s[i]; } } else { - memcpy(buffer + self->output_len, s, n); - } - self->output_len += n; - return n; + memcpy(buffer + self->output_len, s, data_len); + } + self->output_len += data_len; + return data_len; } static PicklerObject * @@ -774,6 +876,8 @@ self->write = NULL; self->proto = 0; self->bin = 0; + self->framing = 0; + self->frame_start = -1; self->fast = 0; self->fast_nesting = 0; self->fix_imports = 0; @@ -868,6 +972,7 @@ self->input_buffer = self->buffer.buf; self->input_len = self->buffer.len; self->next_read_idx = 0; + self->frame_end_idx = -1; self->prefetched_idx = self->input_len; return self->input_len; } @@ -932,7 +1037,7 @@ return -1; /* Prefetch some data without advancing the file pointer, if possible */ - if (self->peek) { + if (self->peek && !self->framing) { PyObject *len, *prefetched; len = PyLong_FromSsize_t(PREFETCH); if (len == NULL) { @@ -980,7 +1085,7 @@ Returns -1 (with an exception set) on failure. On success, return the number of chars read. */ static Py_ssize_t -_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n) +_Unpickler_ReadUnframed(UnpicklerObject *self, char **s, Py_ssize_t n) { Py_ssize_t num_read; @@ -1006,6 +1111,61 @@ } static Py_ssize_t +_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n) +{ + if (self->framing && + (self->frame_end_idx == -1 || + self->frame_end_idx <= self->next_read_idx)) { + /* Need to read new frame */ + char *dummy; + unsigned char *frame_start; + size_t frame_len; + if (_Unpickler_ReadUnframed(self, &dummy, 8) < 0) + return -1; + frame_start = (unsigned char *) dummy; + frame_len = (size_t) frame_start[0]; + frame_len |= (size_t) frame_start[1] << 8; + frame_len |= (size_t) frame_start[2] << 16; + frame_len |= (size_t) frame_start[3] << 24; +#if SIZEOF_SIZE_T >= 8 + frame_len |= (size_t) frame_start[4] << 32; + frame_len |= (size_t) frame_start[5] << 40; + frame_len |= (size_t) frame_start[6] << 48; + frame_len |= (size_t) frame_start[7] << 56; +#else + if (frame_start[4] || frame_start[5] || + frame_start[6] || frame_start[7]) { + PyErr_Format(PyExc_OverflowError, + "Frame size too large for 32-bit build"); + return -1; + } +#endif + if (frame_len > PY_SSIZE_T_MAX) { + PyErr_Format(UnpicklingError, "Invalid frame length"); + return -1; + } + if (frame_len < n) { + PyErr_Format(UnpicklingError, "Bad framing"); + return -1; + } + if (_Unpickler_ReadUnframed(self, &dummy /* unused */, + frame_len) < 0) + return -1; + /* Rewind to start of frame */ + self->frame_end_idx = self->next_read_idx; + self->next_read_idx -= frame_len; + } + if (self->framing) { + /* Check for bad input */ + if (n + self->next_read_idx > self->frame_end_idx) { + PyErr_Format(UnpicklingError, "Bad framing"); + return -1; + } + } + return _Unpickler_ReadUnframed(self, s, n); +} + +static Py_ssize_t _Unpickler_CopyLine(UnpicklerObject *self, char *line, Py_ssize_t len, char **result) { @@ -1150,6 +1310,7 @@ self->input_line = NULL; self->input_len = 0; self->next_read_idx = 0; + self->frame_end_idx = -1; self->prefetched_idx = 0; self->read = NULL; self->readline = NULL; @@ -1160,6 +1321,7 @@ self->num_marks = 0; self->marks_size = 0; self->proto = 0; + self->framing = 0; self->fix_imports = 0; memset(&self->buffer, 0, sizeof(Py_buffer)); self->memo_size = 32; @@ -1284,6 +1446,8 @@ if (self->fast) return 0; + if (_Pickler_OpcodeBoundary(self)) + goto error; x = PyMemoTable_Size(self->memo); if (PyMemoTable_Set(self->memo, obj, x) < 0) @@ -1328,44 +1492,87 @@ } static PyObject * -whichmodule(PyObject *global, PyObject *global_name) -{ - Py_ssize_t i, j; - static PyObject *module_str = NULL; - static PyObject *main_str = NULL; +getattribute(PyObject *obj, PyObject *name, int allow_qualname) { + PyObject *dotted_path; + Py_ssize_t i; + _Py_static_string(PyId_dot, "."); + _Py_static_string(PyId_locals, ""); + + dotted_path = PyUnicode_Split(name, _PyUnicode_FromId(&PyId_dot), -1); + if (dotted_path == NULL) { + return NULL; + } + assert(Py_SIZE(dotted_path) >= 1); + if (!allow_qualname && Py_SIZE(dotted_path) > 1) { + PyErr_Format(PyExc_AttributeError, + "Can't get qualified attribute %R on %R;" + "use protocols >= 4 to enable support", + name, obj); + Py_DECREF(dotted_path); + return NULL; + } + Py_INCREF(obj); + for (i = 0; i < Py_SIZE(dotted_path); i++) { + PyObject *subpath = PyList_GET_ITEM(dotted_path, i); + PyObject *tmp; + PyObject *result = PyUnicode_RichCompare( + subpath, _PyUnicode_FromId(&PyId_locals), Py_EQ); + int is_equal = (result == Py_True); + assert(PyBool_Check(result)); + Py_DECREF(result); + if (is_equal) { + PyErr_Format(PyExc_AttributeError, + "Can't get local attribute %R on %R", name, obj); + Py_DECREF(dotted_path); + Py_DECREF(obj); + return NULL; + } + tmp = PyObject_GetAttr(obj, subpath); + Py_DECREF(obj); + if (tmp == NULL) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + PyErr_Format(PyExc_AttributeError, + "Can't get attribute %R on %R", name, obj); + } + Py_DECREF(dotted_path); + return NULL; + } + obj = tmp; + } + Py_DECREF(dotted_path); + return obj; +} + +static PyObject * +whichmodule(PyObject *global, PyObject *global_name, int allow_qualname) +{ PyObject *module_name; PyObject *modules_dict; PyObject *module; PyObject *obj; - - if (module_str == NULL) { - module_str = PyUnicode_InternFromString("__module__"); - if (module_str == NULL) + Py_ssize_t i, j; + _Py_IDENTIFIER(__module__); + _Py_IDENTIFIER(modules); + _Py_IDENTIFIER(__main__); + + module_name = _PyObject_GetAttrId(global, &PyId___module__); + + if (module_name == NULL) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) return NULL; - main_str = PyUnicode_InternFromString("__main__"); - if (main_str == NULL) - return NULL; - } - - module_name = PyObject_GetAttr(global, module_str); - - /* In some rare cases (e.g., bound methods of extension types), - __module__ can be None. If it is so, then search sys.modules - for the module of global. */ - if (module_name == Py_None) { - Py_DECREF(module_name); - goto search; - } - - if (module_name) { - return module_name; - } - if (PyErr_ExceptionMatches(PyExc_AttributeError)) PyErr_Clear(); - else - return NULL; - - search: + } + else { + /* In some rare cases (e.g., bound methods of extension types), + __module__ can be None. If it is so, then search sys.modules for + the module of global. */ + if (module_name != Py_None) + return module_name; + Py_CLEAR(module_name); + } + assert(module_name == NULL); + modules_dict = _PySys_GetObjectId(&PyId_modules); if (modules_dict == NULL) { PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules"); @@ -1373,34 +1580,35 @@ } i = 0; - module_name = NULL; while ((j = PyDict_Next(modules_dict, &i, &module_name, &module))) { - if (PyObject_RichCompareBool(module_name, main_str, Py_EQ) == 1) + PyObject *result = PyUnicode_RichCompare( + module_name, _PyUnicode_FromId(&PyId___main__), Py_EQ); + int is_equal = (result == Py_True); + assert(PyBool_Check(result)); + Py_DECREF(result); + if (is_equal) continue; - - obj = PyObject_GetAttr(module, global_name); + if (module == Py_None) + continue; + + obj = getattribute(module, global_name, allow_qualname); if (obj == NULL) { - if (PyErr_ExceptionMatches(PyExc_AttributeError)) - PyErr_Clear(); - else + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) return NULL; + PyErr_Clear(); continue; } - if (obj != global) { + if (obj == global) { Py_DECREF(obj); - continue; + Py_INCREF(module_name); + return module_name; } - Py_DECREF(obj); - break; } /* If no module is found, use __main__. */ - if (!j) { - module_name = main_str; - } - + module_name = _PyUnicode_FromId(&PyId___main__); Py_INCREF(module_name); return module_name; } @@ -1744,22 +1952,17 @@ reduce_value = Py_BuildValue("(O())", (PyObject*)&PyBytes_Type); } else { - static PyObject *latin1 = NULL; PyObject *unicode_str = PyUnicode_DecodeLatin1(PyBytes_AS_STRING(obj), PyBytes_GET_SIZE(obj), "strict"); + _Py_IDENTIFIER(latin1); + if (unicode_str == NULL) return -1; - if (latin1 == NULL) { - latin1 = PyUnicode_InternFromString("latin1"); - if (latin1 == NULL) { - Py_DECREF(unicode_str); - return -1; - } - } reduce_value = Py_BuildValue("(O(OO))", - codecs_encode, unicode_str, latin1); + codecs_encode, unicode_str, + _PyUnicode_FromId(&PyId_latin1)); Py_DECREF(unicode_str); } @@ -1773,14 +1976,14 @@ } else { Py_ssize_t size; - char header[5]; + char header[9]; Py_ssize_t len; size = PyBytes_GET_SIZE(obj); if (size < 0) return -1; - if (size < 256) { + if (size <= 0xff) { header[0] = SHORT_BINBYTES; header[1] = (unsigned char)size; len = 2; @@ -1793,6 +1996,14 @@ header[4] = (unsigned char)((size >> 24) & 0xff); len = 5; } + else if (self->proto >= 4) { + int i; + header[0] = BINBYTES8; + for (i = 0; i < 8; i++) { + header[i+1] = (unsigned char)((size >> (8 * i)) & 0xff); + } + len = 8; + } else { PyErr_SetString(PyExc_OverflowError, "cannot serialize a bytes object larger than 4 GiB"); @@ -1882,26 +2093,39 @@ static int write_utf8(PicklerObject *self, char *data, Py_ssize_t size) { - char pdata[5]; - -#if SIZEOF_SIZE_T > 4 - if (size > 0xffffffffUL) { - /* string too large */ + char header[9]; + Py_ssize_t len; + + if (size <= 0xff && self->proto >= 4) { + header[0] = SHORT_BINUNICODE; + header[1] = (unsigned char)(size & 0xff); + len = 2; + } + else if (size <= 0xffffffffUL) { + header[0] = BINUNICODE; + header[1] = (unsigned char)(size & 0xff); + header[2] = (unsigned char)((size >> 8) & 0xff); + header[3] = (unsigned char)((size >> 16) & 0xff); + header[4] = (unsigned char)((size >> 24) & 0xff); + len = 5; + } + else if (self->proto >= 4) { + int i; + + header[0] = BINUNICODE8; + for (i = 0; i < 8; i++) { + header[i+1] = (unsigned char)((size >> (8 * i)) & 0xff); + } + len = 9; + } + else { PyErr_SetString(PyExc_OverflowError, "cannot serialize a string larger than 4GiB"); return -1; } -#endif - - pdata[0] = BINUNICODE; - pdata[1] = (unsigned char)(size & 0xff); - pdata[2] = (unsigned char)((size >> 8) & 0xff); - pdata[3] = (unsigned char)((size >> 16) & 0xff); - pdata[4] = (unsigned char)((size >> 24) & 0xff); - - if (_Pickler_Write(self, pdata, sizeof(pdata)) < 0) - return -1; - + + if (_Pickler_Write(self, header, len) < 0) + return -1; if (_Pickler_Write(self, data, size) < 0) return -1; @@ -2598,6 +2822,223 @@ } static int +save_set(PicklerObject *self, PyObject *obj) +{ + PyObject *item; + int i; + Py_ssize_t set_size, ppos = 0; + Py_hash_t hash; + + const char empty_set_op = EMPTY_SET; + const char mark_op = MARK; + const char additems_op = ADDITEMS; + + if (self->proto < 4) { + PyObject *items; + PyObject *reduce_value; + int status; + + items = PySequence_List(obj); + if (items == NULL) { + return -1; + } + reduce_value = Py_BuildValue("(O(O))", (PyObject*)&PySet_Type, items); + Py_DECREF(items); + if (reduce_value == NULL) { + return -1; + } + /* save_reduce() will memoize the object automatically. */ + status = save_reduce(self, reduce_value, obj); + Py_DECREF(reduce_value); + return status; + } + + if (_Pickler_Write(self, &empty_set_op, 1) < 0) + return -1; + + if (memo_put(self, obj) < 0) + return -1; + + set_size = PySet_GET_SIZE(obj); + if (set_size == 0) + return 0; /* nothing to do */ + + /* Write in batches of BATCHSIZE. */ + do { + i = 0; + if (_Pickler_Write(self, &mark_op, 1) < 0) + return -1; + while (_PySet_NextEntry(obj, &ppos, &item, &hash)) { + if (save(self, item, 0) < 0) + return -1; + if (++i == BATCHSIZE) + break; + } + if (_Pickler_Write(self, &additems_op, 1) < 0) + return -1; + if (PySet_GET_SIZE(obj) != set_size) { + PyErr_Format( + PyExc_RuntimeError, + "set changed size during iteration"); + return -1; + } + } while (i == BATCHSIZE); + + return 0; +} + +static int +save_frozenset(PicklerObject *self, PyObject *obj) +{ + PyObject *iter; + Py_ssize_t len; + + const char mark_op = MARK; + const char frozenset_op = FROZENSET; + const char empty_frozenset_op = EMPTY_FROZENSET; + + if (self->fast && !fast_save_enter(self, obj)) + return -1; + + if (self->proto < 4) { + PyObject *items; + PyObject *reduce_value; + int status; + + items = PySequence_List(obj); + if (items == NULL) { + return -1; + } + reduce_value = Py_BuildValue("(O(O))", (PyObject*)&PyFrozenSet_Type, + items); + Py_DECREF(items); + if (reduce_value == NULL) { + return -1; + } + /* save_reduce() will memoize the object automatically. */ + status = save_reduce(self, reduce_value, obj); + Py_DECREF(reduce_value); + return status; + } + + len = PySet_GET_SIZE(obj); + if (len == 0) { + if (_Pickler_Write(self, &empty_frozenset_op, 1) < 0) + return -1; + return 0; + } + + if (_Pickler_Write(self, &mark_op, 1) < 0) + return -1; + + iter = PyObject_GetIter(obj); + for (;;) { + PyObject *item; + + item = PyIter_Next(iter); + if (item == NULL) { + if (PyErr_Occurred()) { + Py_DECREF(iter); + return -1; + } + break; + } + if (save(self, item, 0) < 0) { + Py_DECREF(item); + Py_DECREF(iter); + return -1; + } + Py_DECREF(item); + } + Py_DECREF(iter); + + /* If the object is already in the memo, this means it is + recursive. In this case, throw away everything we put on the + stack, and fetch the object back from the memo. */ + if (PyMemoTable_Get(self->memo, obj)) { + const char pop_mark_op = POP_MARK; + + if (_Pickler_Write(self, &pop_mark_op, 1) < 0) + return -1; + if (memo_get(self, obj) < 0) + return -1; + return 0; + } + + if (_Pickler_Write(self, &frozenset_op, 1) < 0) + return -1; + if (memo_put(self, obj) < 0) + return -1; + + return 0; +} + +static int +fix_imports(PyObject **module_name, PyObject **global_name) +{ + PyObject *key; + PyObject *item; + + key = PyTuple_Pack(2, *module_name, *global_name); + if (key == NULL) + return -1; + item = PyDict_GetItemWithError(name_mapping_3to2, key); + Py_DECREF(key); + if (item) { + PyObject *fixed_module_name; + PyObject *fixed_global_name; + + if (!PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) { + PyErr_Format(PyExc_RuntimeError, + "_compat_pickle.REVERSE_NAME_MAPPING values " + "should be 2-tuples, not %.200s", + Py_TYPE(item)->tp_name); + return -1; + } + fixed_module_name = PyTuple_GET_ITEM(item, 0); + fixed_global_name = PyTuple_GET_ITEM(item, 1); + if (!PyUnicode_Check(fixed_module_name) || + !PyUnicode_Check(fixed_global_name)) { + PyErr_Format(PyExc_RuntimeError, + "_compat_pickle.REVERSE_NAME_MAPPING values " + "should be pairs of str, not (%.200s, %.200s)", + Py_TYPE(fixed_module_name)->tp_name, + Py_TYPE(fixed_global_name)->tp_name); + return -1; + } + + Py_CLEAR(*module_name); + Py_CLEAR(*global_name); + Py_INCREF(fixed_module_name); + Py_INCREF(fixed_global_name); + *module_name = fixed_module_name; + *global_name = fixed_global_name; + } + else if (PyErr_Occurred()) { + return -1; + } + + item = PyDict_GetItemWithError(import_mapping_3to2, *module_name); + if (item) { + if (!PyUnicode_Check(item)) { + PyErr_Format(PyExc_RuntimeError, + "_compat_pickle.REVERSE_IMPORT_MAPPING values " + "should be strings, not %.200s", + Py_TYPE(item)->tp_name); + return -1; + } + Py_CLEAR(*module_name); + Py_INCREF(item); + *module_name = item; + } + else if (PyErr_Occurred()) { + return -1; + } + + return 0; +} + +static int save_global(PicklerObject *self, PyObject *obj, PyObject *name) { PyObject *global_name = NULL; @@ -2605,20 +3046,32 @@ PyObject *module = NULL; PyObject *cls; int status = 0; + _Py_IDENTIFIER(__name__); + _Py_IDENTIFIER(__qualname__); const char global_op = GLOBAL; if (name) { + Py_INCREF(name); global_name = name; - Py_INCREF(global_name); } else { - global_name = _PyObject_GetAttrId(obj, &PyId___name__); - if (global_name == NULL) - goto error; - } - - module_name = whichmodule(obj, global_name); + if (self->proto >= 4) { + global_name = _PyObject_GetAttrId(obj, &PyId___qualname__); + if (global_name == NULL) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) + goto error; + PyErr_Clear(); + } + } + if (global_name == NULL) { + global_name = _PyObject_GetAttrId(obj, &PyId___name__); + if (global_name == NULL) + goto error; + } + } + + module_name = whichmodule(obj, global_name, self->proto >= 4); if (module_name == NULL) goto error; @@ -2637,11 +3090,11 @@ obj, module_name); goto error; } - cls = PyObject_GetAttr(module, global_name); + cls = getattribute(module, global_name, self->proto >= 4); if (cls == NULL) { PyErr_Format(PicklingError, - "Can't pickle %R: attribute lookup %S.%S failed", - obj, module_name, global_name); + "Can't pickle %R: attribute lookup %S on %S failed", + obj, global_name, module_name); goto error; } if (cls != obj) { @@ -2715,120 +3168,82 @@ goto error; } else { - /* Generate a normal global opcode if we are using a pickle - protocol <= 2, or if the object is not registered in the - extension registry. */ - PyObject *encoded; - PyObject *(*unicode_encoder)(PyObject *); - gen_global: - if (_Pickler_Write(self, &global_op, 1) < 0) - goto error; - - /* Since Python 3.0 now supports non-ASCII identifiers, we encode both - the module name and the global name using UTF-8. We do so only when - we are using the pickle protocol newer than version 3. This is to - ensure compatibility with older Unpickler running on Python 2.x. */ - if (self->proto >= 3) { - unicode_encoder = PyUnicode_AsUTF8String; + if (self->proto >= 4) { + const char stack_global_op = STACK_GLOBAL; + + save(self, module_name, 0); + save(self, global_name, 0); + + if (_Pickler_Write(self, &stack_global_op, 1) < 0) + goto error; } else { - unicode_encoder = PyUnicode_AsASCIIString; - } - - /* For protocol < 3 and if the user didn't request against doing so, - we convert module names to the old 2.x module names. */ - if (self->fix_imports) { - PyObject *key; - PyObject *item; - - key = PyTuple_Pack(2, module_name, global_name); - if (key == NULL) + /* Generate a normal global opcode if we are using a pickle + protocol < 4, or if the object is not registered in the + extension registry. */ + PyObject *encoded; + PyObject *(*unicode_encoder)(PyObject *); + + if (_Pickler_Write(self, &global_op, 1) < 0) goto error; - item = PyDict_GetItemWithError(name_mapping_3to2, key); - Py_DECREF(key); - if (item) { - if (!PyTuple_Check(item) || PyTuple_GET_SIZE(item) != 2) { - PyErr_Format(PyExc_RuntimeError, - "_compat_pickle.REVERSE_NAME_MAPPING values " - "should be 2-tuples, not %.200s", - Py_TYPE(item)->tp_name); + + /* For protocol < 3 and if the user didn't request against doing + so, we convert module names to the old 2.x module names. */ + if (self->proto < 3 && self->fix_imports) { + if (fix_imports(&module_name, &global_name) < 0) { goto error; } - Py_CLEAR(module_name); - Py_CLEAR(global_name); - module_name = PyTuple_GET_ITEM(item, 0); - global_name = PyTuple_GET_ITEM(item, 1); - if (!PyUnicode_Check(module_name) || - !PyUnicode_Check(global_name)) { - PyErr_Format(PyExc_RuntimeError, - "_compat_pickle.REVERSE_NAME_MAPPING values " - "should be pairs of str, not (%.200s, %.200s)", - Py_TYPE(module_name)->tp_name, - Py_TYPE(global_name)->tp_name); - goto error; - } - Py_INCREF(module_name); - Py_INCREF(global_name); } - else if (PyErr_Occurred()) { + + /* Since Python 3.0 now supports non-ASCII identifiers, we encode + both the module name and the global name using UTF-8. We do so + only when we are using the pickle protocol newer than version + 3. This is to ensure compatibility with older Unpickler running + on Python 2.x. */ + if (self->proto == 3) { + unicode_encoder = PyUnicode_AsUTF8String; + } + else { + unicode_encoder = PyUnicode_AsASCIIString; + } + encoded = unicode_encoder(module_name); + if (encoded == NULL) { + if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) + PyErr_Format(PicklingError, + "can't pickle module identifier '%S' using " + "pickle protocol %i", + module_name, self->proto); goto error; } - - item = PyDict_GetItemWithError(import_mapping_3to2, module_name); - if (item) { - if (!PyUnicode_Check(item)) { - PyErr_Format(PyExc_RuntimeError, - "_compat_pickle.REVERSE_IMPORT_MAPPING values " - "should be strings, not %.200s", - Py_TYPE(item)->tp_name); - goto error; - } - Py_CLEAR(module_name); - module_name = item; - Py_INCREF(module_name); - } - else if (PyErr_Occurred()) { + if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), + PyBytes_GET_SIZE(encoded)) < 0) { + Py_DECREF(encoded); goto error; } + Py_DECREF(encoded); + if(_Pickler_Write(self, "\n", 1) < 0) + goto error; + + /* Save the name of the module. */ + encoded = unicode_encoder(global_name); + if (encoded == NULL) { + if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) + PyErr_Format(PicklingError, + "can't pickle global identifier '%S' using " + "pickle protocol %i", + global_name, self->proto); + goto error; + } + if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), + PyBytes_GET_SIZE(encoded)) < 0) { + Py_DECREF(encoded); + goto error; + } + Py_DECREF(encoded); + if (_Pickler_Write(self, "\n", 1) < 0) + goto error; } - - /* Save the name of the module. */ - encoded = unicode_encoder(module_name); - if (encoded == NULL) { - if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) - PyErr_Format(PicklingError, - "can't pickle module identifier '%S' using " - "pickle protocol %i", module_name, self->proto); - goto error; - } - if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), - PyBytes_GET_SIZE(encoded)) < 0) { - Py_DECREF(encoded); - goto error; - } - Py_DECREF(encoded); - if(_Pickler_Write(self, "\n", 1) < 0) - goto error; - - /* Save the name of the module. */ - encoded = unicode_encoder(global_name); - if (encoded == NULL) { - if (PyErr_ExceptionMatches(PyExc_UnicodeEncodeError)) - PyErr_Format(PicklingError, - "can't pickle global identifier '%S' using " - "pickle protocol %i", global_name, self->proto); - goto error; - } - if (_Pickler_Write(self, PyBytes_AS_STRING(encoded), - PyBytes_GET_SIZE(encoded)) < 0) { - Py_DECREF(encoded); - goto error; - } - Py_DECREF(encoded); - if(_Pickler_Write(self, "\n", 1) < 0) - goto error; - /* Memoize the object. */ if (memo_put(self, obj) < 0) goto error; @@ -2927,14 +3342,9 @@ get_class(PyObject *obj) { PyObject *cls; - static PyObject *str_class; - - if (str_class == NULL) { - str_class = PyUnicode_InternFromString("__class__"); - if (str_class == NULL) - return NULL; - } - cls = PyObject_GetAttr(obj, str_class); + _Py_IDENTIFIER(__class__); + + cls = _PyObject_GetAttrId(obj, &PyId___class__); if (cls == NULL) { if (PyErr_ExceptionMatches(PyExc_AttributeError)) { PyErr_Clear(); @@ -2957,12 +3367,12 @@ PyObject *listitems = Py_None; PyObject *dictitems = Py_None; Py_ssize_t size; - - int use_newobj = self->proto >= 2; + int use_newobj = 0, use_newobj_ex = 0; const char reduce_op = REDUCE; const char build_op = BUILD; const char newobj_op = NEWOBJ; + const char newobj_ex_op = NEWOBJ_EX; size = PyTuple_Size(args); if (size < 2 || size > 5) { @@ -3007,33 +3417,75 @@ return -1; } - /* Protocol 2 special case: if callable's name is __newobj__, use - NEWOBJ. */ - if (use_newobj) { - static PyObject *newobj_str = NULL; + if (self->proto >= 2) { PyObject *name; - - if (newobj_str == NULL) { - newobj_str = PyUnicode_InternFromString("__newobj__"); - if (newobj_str == NULL) - return -1; - } + _Py_IDENTIFIER(__name__); name = _PyObject_GetAttrId(callable, &PyId___name__); if (name == NULL) { - if (PyErr_ExceptionMatches(PyExc_AttributeError)) - PyErr_Clear(); - else + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { return -1; - use_newobj = 0; + } + PyErr_Clear(); + } + else if (self->proto >= 4) { + _Py_IDENTIFIER(__newobj_ex__); + use_newobj_ex = PyUnicode_Check(name) && + PyUnicode_Compare( + name, _PyUnicode_FromId(&PyId___newobj_ex__)) == 0; + Py_DECREF(name); } else { + _Py_IDENTIFIER(__newobj__); use_newobj = PyUnicode_Check(name) && - PyUnicode_Compare(name, newobj_str) == 0; + PyUnicode_Compare( + name, _PyUnicode_FromId(&PyId___newobj__)) == 0; Py_DECREF(name); } } - if (use_newobj) { + + if (use_newobj_ex) { + PyObject *cls; + PyObject *args; + PyObject *kwargs; + + if (Py_SIZE(argtup) != 3) { + PyErr_Format(PicklingError, + "length of the NEWOBJ_EX argument tuple must be " + "exactly 3, not %zd", Py_SIZE(argtup)); + return -1; + } + + cls = PyTuple_GET_ITEM(argtup, 0); + if (!PyType_Check(cls)) { + PyErr_Format(PicklingError, + "first item from NEWOBJ_EX argument tuple must " + "be a class, not %.200s", Py_TYPE(cls)->tp_name); + return -1; + } + args = PyTuple_GET_ITEM(argtup, 1); + if (!PyTuple_Check(args)) { + PyErr_Format(PicklingError, + "second item from NEWOBJ_EX argument tuple must " + "be a tuple, not %.200s", Py_TYPE(args)->tp_name); + return -1; + } + kwargs = PyTuple_GET_ITEM(argtup, 2); + if (!PyDict_Check(kwargs)) { + PyErr_Format(PicklingError, + "third item from NEWOBJ_EX argument tuple must " + "be a dict, not %.200s", Py_TYPE(kwargs)->tp_name); + return -1; + } + + if (save(self, cls, 0) < 0 || + save(self, args, 0) < 0 || + save(self, kwargs, 0) < 0 || + _Pickler_Write(self, &newobj_ex_op, 1) < 0) { + return -1; + } + } + else if (use_newobj) { PyObject *cls; PyObject *newargtup; PyObject *obj_class; @@ -3117,8 +3569,23 @@ the caller do not want to memoize the object. Not particularly useful, but that is to mimic the behavior save_reduce() in pickle.py when obj is None. */ - if (obj && memo_put(self, obj) < 0) - return -1; + if (obj != NULL) { + /* If the object is already in the memo, this means it is + recursive. In this case, throw away everything we put on the + stack, and fetch the object back from the memo. */ + if (PyMemoTable_Get(self->memo, obj)) { + const char pop_op = POP; + + if (_Pickler_Write(self, &pop_op, 1) < 0) + return -1; + if (memo_get(self, obj) < 0) + return -1; + + return 0; + } + else if (memo_put(self, obj) < 0) + return -1; + } if (listitems && batch_list(self, listitems) < 0) return -1; @@ -3136,6 +3603,34 @@ } static int +save_method(PicklerObject *self, PyObject *obj) +{ + PyObject *method_self = PyCFunction_GET_SELF(obj); + + if (method_self == NULL || PyModule_Check(method_self)) { + return save_global(self, obj, NULL); + } + else { + PyObject *builtins; + PyObject *getattr; + PyObject *reduce_value; + int status = -1; + _Py_IDENTIFIER(getattr); + + builtins = PyEval_GetBuiltins(); + getattr = _PyDict_GetItemId(builtins, &PyId_getattr); + reduce_value = \ + Py_BuildValue("O(Os)", getattr, method_self, + ((PyCFunctionObject *)obj)->m_ml->ml_name); + if (reduce_value != NULL) { + status = save_reduce(self, reduce_value, obj); + Py_DECREF(reduce_value); + } + return status; + } +} + +static int save(PicklerObject *self, PyObject *obj, int pers_save) { PyTypeObject *type; @@ -3213,6 +3708,14 @@ status = save_dict(self, obj); goto done; } + else if (type == &PySet_Type) { + status = save_set(self, obj); + goto done; + } + else if (type == &PyFrozenSet_Type) { + status = save_frozenset(self, obj); + goto done; + } else if (type == &PyList_Type) { status = save_list(self, obj); goto done; @@ -3236,7 +3739,7 @@ } } else if (type == &PyCFunction_Type) { - status = save_global(self, obj, NULL); + status = save_method(self, obj); goto done; } @@ -3269,18 +3772,9 @@ goto done; } else { - static PyObject *reduce_str = NULL; - static PyObject *reduce_ex_str = NULL; - - /* Cache the name of the reduce methods. */ - if (reduce_str == NULL) { - reduce_str = PyUnicode_InternFromString("__reduce__"); - if (reduce_str == NULL) - goto error; - reduce_ex_str = PyUnicode_InternFromString("__reduce_ex__"); - if (reduce_ex_str == NULL) - goto error; - } + _Py_IDENTIFIER(__reduce__); + _Py_IDENTIFIER(__reduce_ex__); + /* XXX: If the __reduce__ method is defined, __reduce_ex__ is automatically defined as __reduce__. While this is convenient, this @@ -3291,7 +3785,7 @@ don't actually have to check for a __reduce__ method. */ /* Check for a __reduce_ex__ method. */ - reduce_func = PyObject_GetAttr(obj, reduce_ex_str); + reduce_func = _PyObject_GetAttrId(obj, &PyId___reduce_ex__); if (reduce_func != NULL) { PyObject *proto; proto = PyLong_FromLong(self->proto); @@ -3305,7 +3799,7 @@ else goto error; /* Check for a __reduce__ method. */ - reduce_func = PyObject_GetAttr(obj, reduce_str); + reduce_func = _PyObject_GetAttrId(obj, &PyId___reduce__); if (reduce_func != NULL) { reduce_value = PyObject_Call(reduce_func, empty_tuple, NULL); } @@ -3338,6 +3832,8 @@ status = -1; } done: + if (status == 0) + status = _Pickler_OpcodeBoundary(self); Py_LeaveRecursiveCall(); Py_XDECREF(reduce_func); Py_XDECREF(reduce_value); @@ -3358,6 +3854,8 @@ header[1] = (unsigned char)self->proto; if (_Pickler_Write(self, header, 2) < 0) return -1; + if (self->proto >= 4) + self->framing = 1; } if (save(self, obj, 0) < 0 || @@ -3478,9 +3976,9 @@ "This takes a binary file for writing a pickle data stream.\n" "\n" "The optional protocol argument tells the pickler to use the\n" -"given protocol; supported protocols are 0, 1, 2, 3. The default\n" -"protocol is 3; a backward-incompatible protocol designed for\n" -"Python 3.0.\n" +"given protocol; supported protocols are 0, 1, 2, 3 and 4. The\n" +"default protocol is 3; a backward-incompatible protocol designed for\n" +"Python 3.\n" "\n" "Specifying a negative protocol version selects the highest\n" "protocol version supported. The higher the protocol used, the\n" @@ -3493,8 +3991,8 @@ "meets this interface.\n" "\n" "If fix_imports is True and protocol is less than 3, pickle will try to\n" -"map the new Python 3.x names to the old module names used in Python\n" -"2.x, so that the pickle data stream is readable with Python 2.x.\n"); +"map the new Python 3 names to the old module names used in Python 2,\n" +"so that the pickle data stream is readable with Python 2.\n"); static int Pickler_init(PicklerObject *self, PyObject *args, PyObject *kwds) @@ -3987,17 +4485,15 @@ * as a C Py_ssize_t, or -1 if it's higher than PY_SSIZE_T_MAX. */ static Py_ssize_t -calc_binsize(char *bytes, int size) +calc_binsize(char *bytes, int nbytes) { unsigned char *s = (unsigned char *)bytes; + int i; size_t x = 0; - assert(size == 4); - - x = (size_t) s[0]; - x |= (size_t) s[1] << 8; - x |= (size_t) s[2] << 16; - x |= (size_t) s[3] << 24; + for (i = 0; i < nbytes; i++) { + x |= (size_t) s[i] << (8 * i); + } if (x > PY_SSIZE_T_MAX) return -1; @@ -4011,21 +4507,21 @@ * of x-platform bugs. */ static long -calc_binint(char *bytes, int size) +calc_binint(char *bytes, int nbytes) { unsigned char *s = (unsigned char *)bytes; - int i = size; + int i; long x = 0; - for (i = 0; i < size; i++) { - x |= (long)s[i] << (i * 8); + for (i = 0; i < nbytes; i++) { + x |= (long)s[i] << (8 * i); } /* Unlike BININT1 and BININT2, BININT (more accurately BININT4) * is signed, so on a box with longs bigger than 4 bytes we need * to extend a BININT's sign bit to the full width. */ - if (SIZEOF_LONG > 4 && size == 4) { + if (SIZEOF_LONG > 4 && nbytes == 4) { x |= -(x & (1L << 31)); } @@ -4233,26 +4729,27 @@ } static int -load_binbytes(UnpicklerObject *self) +load_counted_binbytes(UnpicklerObject *self, int nbytes) { PyObject *bytes; - Py_ssize_t x; + Py_ssize_t size; char *s; - if (_Unpickler_Read(self, &s, 4) < 0) - return -1; - - x = calc_binsize(s, 4); - if (x < 0) { + if (_Unpickler_Read(self, &s, nbytes) < 0) + return -1; + + size = calc_binsize(s, nbytes); + if (size < 0) { PyErr_Format(PyExc_OverflowError, "BINBYTES exceeds system's maximum size of %zd bytes", PY_SSIZE_T_MAX); return -1; } - if (_Unpickler_Read(self, &s, x) < 0) - return -1; - bytes = PyBytes_FromStringAndSize(s, x); + if (_Unpickler_Read(self, &s, size) < 0) + return -1; + + bytes = PyBytes_FromStringAndSize(s, size); if (bytes == NULL) return -1; @@ -4261,74 +4758,27 @@ } static int -load_short_binbytes(UnpicklerObject *self) -{ - PyObject *bytes; - Py_ssize_t x; +load_counted_binstring(UnpicklerObject *self, int nbytes) +{ + PyObject *str; + Py_ssize_t size; char *s; - if (_Unpickler_Read(self, &s, 1) < 0) - return -1; - - x = (unsigned char)s[0]; - - if (_Unpickler_Read(self, &s, x) < 0) - return -1; - - bytes = PyBytes_FromStringAndSize(s, x); - if (bytes == NULL) - return -1; - - PDATA_PUSH(self->stack, bytes, -1); - return 0; -} - -static int -load_binstring(UnpicklerObject *self) -{ - PyObject *str; - Py_ssize_t x; - char *s; - - if (_Unpickler_Read(self, &s, 4) < 0) - return -1; - - x = calc_binint(s, 4); - if (x < 0) { - PyErr_SetString(UnpicklingError, - "BINSTRING pickle has negative byte count"); - return -1; - } - - if (_Unpickler_Read(self, &s, x) < 0) - return -1; - + if (_Unpickler_Read(self, &s, nbytes) < 0) + return -1; + + size = calc_binsize(s, nbytes); + if (size < 0) { + PyErr_Format(UnpicklingError, + "BINSTRING exceeds system's maximum size of %zd bytes", + PY_SSIZE_T_MAX); + return -1; + } + + if (_Unpickler_Read(self, &s, size) < 0) + return -1; /* Convert Python 2.x strings to unicode. */ - str = PyUnicode_Decode(s, x, self->encoding, self->errors); - if (str == NULL) - return -1; - - PDATA_PUSH(self->stack, str, -1); - return 0; -} - -static int -load_short_binstring(UnpicklerObject *self) -{ - PyObject *str; - Py_ssize_t x; - char *s; - - if (_Unpickler_Read(self, &s, 1) < 0) - return -1; - - x = (unsigned char)s[0]; - - if (_Unpickler_Read(self, &s, x) < 0) - return -1; - - /* Convert Python 2.x strings to unicode. */ - str = PyUnicode_Decode(s, x, self->encoding, self->errors); + str = PyUnicode_Decode(s, size, self->encoding, self->errors); if (str == NULL) return -1; @@ -4357,16 +4807,16 @@ } static int -load_binunicode(UnpicklerObject *self) +load_counted_binunicode(UnpicklerObject *self, int nbytes) { PyObject *str; Py_ssize_t size; char *s; - if (_Unpickler_Read(self, &s, 4) < 0) - return -1; - - size = calc_binsize(s, 4); + if (_Unpickler_Read(self, &s, nbytes) < 0) + return -1; + + size = calc_binsize(s, nbytes); if (size < 0) { PyErr_Format(PyExc_OverflowError, "BINUNICODE exceeds system's maximum size of %zd bytes", @@ -4374,7 +4824,6 @@ return -1; } - if (_Unpickler_Read(self, &s, size) < 0) return -1; @@ -4446,6 +4895,28 @@ } static int +load_empty_set(UnpicklerObject *self) +{ + PyObject *set; + + if ((set = PySet_New(NULL)) == NULL) + return -1; + PDATA_PUSH(self->stack, set, -1); + return 0; +} + +static int +load_empty_frozenset(UnpicklerObject *self) +{ + PyObject *set; + + if ((set = PyFrozenSet_New(NULL)) == NULL) + return -1; + PDATA_PUSH(self->stack, set, -1); + return 0; +} + +static int load_list(UnpicklerObject *self) { PyObject *list; @@ -4487,6 +4958,29 @@ return 0; } +static int +load_frozenset(UnpicklerObject *self) +{ + PyObject *items; + PyObject *frozenset; + Py_ssize_t i; + + if ((i = marker(self)) < 0) + return -1; + + items = Pdata_poptuple(self->stack, i); + if (items == NULL) + return -1; + + frozenset = PyFrozenSet_New(items); + Py_DECREF(items); + if (frozenset == NULL) + return -1; + + PDATA_PUSH(self->stack, frozenset, -1); + return 0; +} + static PyObject * instantiate(PyObject *cls, PyObject *args) { @@ -4638,6 +5132,57 @@ } static int +load_newobj_ex(UnpicklerObject *self) +{ + PyObject *cls, *args, *kwargs; + PyObject *obj; + + PDATA_POP(self->stack, kwargs); + if (kwargs == NULL) { + return -1; + } + PDATA_POP(self->stack, args); + if (args == NULL) { + Py_DECREF(kwargs); + return -1; + } + PDATA_POP(self->stack, cls); + if (cls == NULL) { + Py_DECREF(kwargs); + Py_DECREF(args); + return -1; + } + + if (!PyType_Check(cls)) { + Py_DECREF(kwargs); + Py_DECREF(args); + Py_DECREF(cls); + PyErr_Format(UnpicklingError, + "NEWOBJ_EX class argument must be a type, not %.200s", + Py_TYPE(cls)->tp_name); + return -1; + } + + if (((PyTypeObject *)cls)->tp_new == NULL) { + Py_DECREF(kwargs); + Py_DECREF(args); + Py_DECREF(cls); + PyErr_SetString(UnpicklingError, + "NEWOBJ_EX class argument doesn't have __new__"); + return -1; + } + obj = ((PyTypeObject *)cls)->tp_new((PyTypeObject *)cls, args, kwargs); + Py_DECREF(kwargs); + Py_DECREF(args); + Py_DECREF(cls); + if (obj == NULL) { + return -1; + } + PDATA_PUSH(self->stack, obj, -1); + return 0; +} + +static int load_global(UnpicklerObject *self) { PyObject *global = NULL; @@ -4674,6 +5219,31 @@ } static int +load_stack_global(UnpicklerObject *self) +{ + PyObject *global; + PyObject *module_name; + PyObject *global_name; + + PDATA_POP(self->stack, global_name); + PDATA_POP(self->stack, module_name); + if (module_name == NULL || !PyUnicode_CheckExact(module_name) || + global_name == NULL || !PyUnicode_CheckExact(global_name)) { + PyErr_SetString(UnpicklingError, "STACK_GLOBAL requires str"); + Py_XDECREF(global_name); + Py_XDECREF(module_name); + return -1; + } + global = find_class(self, module_name, global_name); + Py_DECREF(global_name); + Py_DECREF(module_name); + if (global == NULL) + return -1; + PDATA_PUSH(self->stack, global, -1); + return 0; +} + +static int load_persid(UnpicklerObject *self) { PyObject *pid; @@ -5132,6 +5702,59 @@ } static int +load_additems(UnpicklerObject *self) +{ + PyObject *set; + Py_ssize_t mark, len, i; + + mark = marker(self); + len = Py_SIZE(self->stack); + if (mark > len || mark <= 0) + return stack_underflow(); + if (len == mark) /* nothing to do */ + return 0; + + set = self->stack->data[mark - 1]; + + if (PySet_Check(set)) { + PyObject *items; + int status; + + items = Pdata_poptuple(self->stack, mark); + if (items == NULL) + return -1; + + status = _PySet_Update(set, items); + Py_DECREF(items); + return status; + } + else { + PyObject *add_func; + _Py_IDENTIFIER(add); + + add_func = _PyObject_GetAttrId(set, &PyId_add); + if (add_func == NULL) + return -1; + for (i = mark; i < len; i++) { + PyObject *result; + PyObject *item; + + item = self->stack->data[i]; + result = _Unpickler_FastCall(self, add_func, item); + if (result == NULL) { + Pdata_clear(self->stack, i + 1); + Py_SIZE(self->stack) = mark; + return -1; + } + Py_DECREF(result); + } + Py_SIZE(self->stack) = mark; + } + + return 0; +} + +static int load_build(UnpicklerObject *self) { PyObject *state, *inst, *slotstate; @@ -5325,6 +5948,7 @@ i = (unsigned char)s[0]; if (i <= HIGHEST_PROTOCOL) { self->proto = i; + self->framing = (self->proto >= 4); return 0; } @@ -5340,6 +5964,8 @@ char *s; self->num_marks = 0; + self->proto = 0; + self->framing = 0; if (Py_SIZE(self->stack)) Pdata_clear(self->stack, 0); @@ -5365,13 +5991,16 @@ OP_ARG(LONG4, load_counted_long, 4) OP(FLOAT, load_float) OP(BINFLOAT, load_binfloat) - OP(BINBYTES, load_binbytes) - OP(SHORT_BINBYTES, load_short_binbytes) - OP(BINSTRING, load_binstring) - OP(SHORT_BINSTRING, load_short_binstring) + OP_ARG(SHORT_BINBYTES, load_counted_binbytes, 1) + OP_ARG(BINBYTES, load_counted_binbytes, 4) + OP_ARG(BINBYTES8, load_counted_binbytes, 8) + OP_ARG(SHORT_BINSTRING, load_counted_binstring, 1) + OP_ARG(BINSTRING, load_counted_binstring, 4) OP(STRING, load_string) OP(UNICODE, load_unicode) - OP(BINUNICODE, load_binunicode) + OP_ARG(SHORT_BINUNICODE, load_counted_binunicode, 1) + OP_ARG(BINUNICODE, load_counted_binunicode, 4) + OP_ARG(BINUNICODE8, load_counted_binunicode, 8) OP_ARG(EMPTY_TUPLE, load_counted_tuple, 0) OP_ARG(TUPLE1, load_counted_tuple, 1) OP_ARG(TUPLE2, load_counted_tuple, 2) @@ -5381,10 +6010,16 @@ OP(LIST, load_list) OP(EMPTY_DICT, load_empty_dict) OP(DICT, load_dict) + OP(EMPTY_SET, load_empty_set) + OP(ADDITEMS, load_additems) + OP(EMPTY_FROZENSET, load_empty_frozenset) + OP(FROZENSET, load_frozenset) OP(OBJ, load_obj) OP(INST, load_inst) OP(NEWOBJ, load_newobj) + OP(NEWOBJ_EX, load_newobj_ex) OP(GLOBAL, load_global) + OP(STACK_GLOBAL, load_stack_global) OP(APPEND, load_append) OP(APPENDS, load_appends) OP(BUILD, load_build) @@ -5485,6 +6120,7 @@ PyObject *modules_dict; PyObject *module; PyObject *module_name, *global_name; + _Py_IDENTIFIER(modules); if (!PyArg_UnpackTuple(args, "find_class", 2, 2, &module_name, &global_name)) @@ -5556,11 +6192,11 @@ module = PyImport_Import(module_name); if (module == NULL) return NULL; - global = PyObject_GetAttr(module, global_name); + global = getattribute(module, global_name, self->proto >= 4); Py_DECREF(module); } else { - global = PyObject_GetAttr(module, global_name); + global = getattribute(module, global_name, self->proto >= 4); } return global; } @@ -5723,6 +6359,7 @@ self->arg = NULL; self->proto = 0; + self->framing = 0; return 0; } diff -r a75b88048339 -r 8434af450da0 Objects/classobject.c --- a/Objects/classobject.c Thu Nov 14 16:16:29 2013 -0800 +++ b/Objects/classobject.c Fri Nov 15 03:07:56 2013 -0800 @@ -69,6 +69,30 @@ return (PyObject *)im; } +static PyObject * +method_reduce(PyMethodObject *im) +{ + PyObject *self = PyMethod_GET_SELF(im); + PyObject *func = PyMethod_GET_FUNCTION(im); + PyObject *builtins; + PyObject *getattr; + PyObject *funcname; + _Py_IDENTIFIER(getattr); + + funcname = _PyObject_GetAttrId(func, &PyId___name__); + if (funcname == NULL) { + return NULL; + } + builtins = PyEval_GetBuiltins(); + getattr = _PyDict_GetItemId(builtins, &PyId_getattr); + return Py_BuildValue("O(ON)", getattr, self, funcname); +} + +static PyMethodDef method_methods[] = { + {"__reduce__", (PyCFunction)method_reduce, METH_NOARGS, NULL}, + {NULL, NULL} +}; + /* Descriptors for PyMethod attributes */ /* im_func and im_self are stored in the PyMethod object */ @@ -367,7 +391,7 @@ offsetof(PyMethodObject, im_weakreflist), /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + method_methods, /* tp_methods */ method_memberlist, /* tp_members */ method_getset, /* tp_getset */ 0, /* tp_base */ diff -r a75b88048339 -r 8434af450da0 Objects/descrobject.c --- a/Objects/descrobject.c Thu Nov 14 16:16:29 2013 -0800 +++ b/Objects/descrobject.c Fri Nov 15 03:07:56 2013 -0800 @@ -398,6 +398,24 @@ return descr->d_qualname; } +static PyObject * +descr_reduce(PyDescrObject *descr) +{ + PyObject *builtins; + PyObject *getattr; + _Py_IDENTIFIER(getattr); + + builtins = PyEval_GetBuiltins(); + getattr = _PyDict_GetItemId(builtins, &PyId_getattr); + return Py_BuildValue("O(OO)", getattr, PyDescr_TYPE(descr), + PyDescr_NAME(descr)); +} + +static PyMethodDef descr_methods[] = { + {"__reduce__", (PyCFunction)descr_reduce, METH_NOARGS, NULL}, + {NULL, NULL} +}; + static PyMemberDef descr_members[] = { {"__objclass__", T_OBJECT, offsetof(PyDescrObject, d_type), READONLY}, {"__name__", T_OBJECT, offsetof(PyDescrObject, d_name), READONLY}, @@ -494,7 +512,7 @@ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + descr_methods, /* tp_methods */ descr_members, /* tp_members */ method_getset, /* tp_getset */ 0, /* tp_base */ @@ -532,7 +550,7 @@ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + descr_methods, /* tp_methods */ descr_members, /* tp_members */ method_getset, /* tp_getset */ 0, /* tp_base */ @@ -569,7 +587,7 @@ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + descr_methods, /* tp_methods */ descr_members, /* tp_members */ member_getset, /* tp_getset */ 0, /* tp_base */ @@ -643,7 +661,7 @@ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + descr_methods, /* tp_methods */ descr_members, /* tp_members */ wrapperdescr_getset, /* tp_getset */ 0, /* tp_base */ @@ -1085,6 +1103,23 @@ wp->self); } +static PyObject * +wrapper_reduce(wrapperobject *wp) +{ + PyObject *builtins; + PyObject *getattr; + _Py_IDENTIFIER(getattr); + + builtins = PyEval_GetBuiltins(); + getattr = _PyDict_GetItemId(builtins, &PyId_getattr); + return Py_BuildValue("O(OO)", getattr, wp->self, PyDescr_NAME(wp->descr)); +} + +static PyMethodDef wrapper_methods[] = { + {"__reduce__", (PyCFunction)wrapper_reduce, METH_NOARGS, NULL}, + {NULL, NULL} +}; + static PyMemberDef wrapper_members[] = { {"__self__", T_OBJECT, offsetof(wrapperobject, self), READONLY}, {0} @@ -1193,7 +1228,7 @@ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ - 0, /* tp_methods */ + wrapper_methods, /* tp_methods */ wrapper_members, /* tp_members */ wrapper_getsets, /* tp_getset */ 0, /* tp_base */ diff -r a75b88048339 -r 8434af450da0 Objects/typeobject.c --- a/Objects/typeobject.c Thu Nov 14 16:16:29 2013 -0800 +++ b/Objects/typeobject.c Fri Nov 15 03:07:56 2013 -0800 @@ -3405,149 +3405,428 @@ return cached_copyreg_module; } -static PyObject * -slotnames(PyObject *cls) -{ - PyObject *clsdict; +Py_LOCAL(PyObject *) +_PyType_GetSlotNames(PyTypeObject *cls) +{ PyObject *copyreg; PyObject *slotnames; _Py_IDENTIFIER(__slotnames__); _Py_IDENTIFIER(_slotnames); - clsdict = ((PyTypeObject *)cls)->tp_dict; - slotnames = _PyDict_GetItemId(clsdict, &PyId___slotnames__); - if (slotnames != NULL && PyList_Check(slotnames)) { + assert(PyType_Check(cls)); + + /* Get the slot names from the cache in the class if possible. */ + slotnames = _PyDict_GetItemIdWithError(cls->tp_dict, &PyId___slotnames__); + if (slotnames != NULL) { + if (slotnames != Py_None && !PyList_Check(slotnames)) { + PyErr_Format(PyExc_TypeError, + "%.200s.__slotnames__ should be a list or None, " + "not %.200s", + cls->tp_name, Py_TYPE(slotnames)->tp_name); + return NULL; + } Py_INCREF(slotnames); return slotnames; } + else { + if (PyErr_Occurred()) { + return NULL; + } + /* The class does not have the slot names cached yet. */ + } copyreg = import_copyreg(); if (copyreg == NULL) return NULL; - slotnames = _PyObject_CallMethodId(copyreg, &PyId__slotnames, "O", cls); + /* Use _slotnames function from the copyreg module to find the slots + by this class and its bases. This function will cache the result + in __slotnames__. */ + slotnames = _PyObject_CallMethodIdObjArgs(copyreg, &PyId__slotnames, + cls, NULL); Py_DECREF(copyreg); - if (slotnames != NULL && - slotnames != Py_None && - !PyList_Check(slotnames)) - { + if (slotnames == NULL) + return NULL; + + if (slotnames != Py_None && !PyList_Check(slotnames)) { PyErr_SetString(PyExc_TypeError, - "copyreg._slotnames didn't return a list or None"); + "copyreg._slotnames didn't return a list or None"); Py_DECREF(slotnames); - slotnames = NULL; + return NULL; } return slotnames; } -static PyObject * -reduce_2(PyObject *obj) -{ - PyObject *cls, *getnewargs; - PyObject *args = NULL, *args2 = NULL; - PyObject *getstate = NULL, *state = NULL, *names = NULL; - PyObject *slots = NULL, *listitems = NULL, *dictitems = NULL; - PyObject *copyreg = NULL, *newobj = NULL, *res = NULL; - Py_ssize_t i, n; - _Py_IDENTIFIER(__getnewargs__); +Py_LOCAL(PyObject *) +_PyObject_GetState(PyObject *obj) +{ + PyObject *state; + PyObject *getstate; _Py_IDENTIFIER(__getstate__); - _Py_IDENTIFIER(__newobj__); - - cls = (PyObject *) Py_TYPE(obj); - - getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__); - if (getnewargs != NULL) { - args = PyObject_CallObject(getnewargs, NULL); - Py_DECREF(getnewargs); - if (args != NULL && !PyTuple_Check(args)) { - PyErr_Format(PyExc_TypeError, - "__getnewargs__ should return a tuple, " - "not '%.200s'", Py_TYPE(args)->tp_name); - goto end; + + getstate = _PyObject_GetAttrId(obj, &PyId___getstate__); + if (getstate == NULL) { + PyObject *slotnames; + + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return NULL; } - } - else { PyErr_Clear(); - args = PyTuple_New(0); - } - if (args == NULL) - goto end; - - getstate = _PyObject_GetAttrId(obj, &PyId___getstate__); - if (getstate != NULL) { + + { + PyObject **dict; + dict = _PyObject_GetDictPtr(obj); + /* It is possible that the object's dict is not initialized + yet. In this case, we will return None for the state. + We also return None if the dict is empty to make the behavior + consistent regardless whether the dict was initialized or not. + This make unit testing easier. */ + if (dict != NULL && *dict != NULL && PyDict_Size(*dict) > 0) { + state = *dict; + } + else { + state = Py_None; + } + Py_INCREF(state); + } + + slotnames = _PyType_GetSlotNames(Py_TYPE(obj)); + if (slotnames == NULL) { + Py_DECREF(state); + return NULL; + } + + assert(slotnames == Py_None || PyList_Check(slotnames)); + if (slotnames != Py_None && Py_SIZE(slotnames) > 0) { + PyObject *slots; + Py_ssize_t slotnames_size, i; + + slots = PyDict_New(); + if (slots == NULL) { + Py_DECREF(slotnames); + Py_DECREF(state); + return NULL; + } + + slotnames_size = Py_SIZE(slotnames); + for (i = 0; i < slotnames_size; i++) { + PyObject *name, *value; + + name = PyList_GET_ITEM(slotnames, i); + value = PyObject_GetAttr(obj, name); + if (value == NULL) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + goto error; + } + /* It is not an error if the attribute is not present. */ + PyErr_Clear(); + } + else { + int err = PyDict_SetItem(slots, name, value); + Py_DECREF(value); + if (err) { + goto error; + } + } + + /* The list is stored on the class so it may mutates while we + iterate over it */ + if (slotnames_size != Py_SIZE(slotnames)) { + PyErr_Format(PyExc_RuntimeError, + "__slotsname__ changed size during iteration"); + goto error; + } + + /* We handle errors within the loop here. */ + if (0) { + error: + Py_DECREF(slotnames); + Py_DECREF(slots); + Py_DECREF(state); + return NULL; + } + } + + /* If we found some slot attributes, pack them in a tuple along + the orginal attribute dictionary. */ + if (PyDict_Size(slots) > 0) { + PyObject *state2; + + state2 = PyTuple_Pack(2, state, slots); + Py_DECREF(state); + if (state2 == NULL) { + Py_DECREF(slotnames); + Py_DECREF(slots); + return NULL; + } + state = state2; + } + Py_DECREF(slots); + } + Py_DECREF(slotnames); + } + else { /* getstate != NULL */ state = PyObject_CallObject(getstate, NULL); Py_DECREF(getstate); if (state == NULL) - goto end; + return NULL; + } + + return state; +} + +Py_LOCAL(int) +_PyObject_GetNewArguments(PyObject *obj, PyObject **args, PyObject **kwargs) +{ + PyObject *getnewargs, *getnewargs_ex; + _Py_IDENTIFIER(__getnewargs_ex__); + _Py_IDENTIFIER(__getnewargs__); + + if (args == NULL || kwargs == NULL) { + PyErr_BadInternalCall(); + return -1; + } + + /* We first attempt to fetch the arguments for __new__ by calling + __getnewargs_ex__ on the object. */ + getnewargs_ex = _PyObject_GetAttrId(obj, &PyId___getnewargs_ex__); + if (getnewargs_ex != NULL) { + PyObject *newargs = PyObject_CallObject(getnewargs_ex, NULL); + Py_DECREF(getnewargs_ex); + if (newargs == NULL) { + return -1; + } + if (!PyTuple_Check(newargs)) { + PyErr_Format(PyExc_TypeError, + "__getnewargs_ex__ should return a tuple, " + "not '%.200s'", Py_TYPE(newargs)->tp_name); + Py_DECREF(newargs); + return -1; + } + if (Py_SIZE(newargs) != 2) { + PyErr_Format(PyExc_ValueError, + "__getnewargs_ex__ should return a tuple of " + "length 2, not %zd", Py_SIZE(newargs)); + Py_DECREF(newargs); + return -1; + } + *args = PyTuple_GET_ITEM(newargs, 0); + Py_INCREF(*args); + *kwargs = PyTuple_GET_ITEM(newargs, 1); + Py_INCREF(*kwargs); + Py_DECREF(newargs); + + /* XXX We should perhaps allow None to be passed here. */ + if (!PyTuple_Check(*args)) { + PyErr_Format(PyExc_TypeError, + "first item of the tuple returned by " + "__getnewargs_ex__ must be a tuple, not '%.200s'", + Py_TYPE(*args)->tp_name); + Py_CLEAR(*args); + Py_CLEAR(*kwargs); + return -1; + } + if (!PyDict_Check(*kwargs)) { + PyErr_Format(PyExc_TypeError, + "second item of the tuple returned by " + "__getnewargs_ex__ must be a dict, not '%.200s'", + Py_TYPE(*kwargs)->tp_name); + Py_CLEAR(*args); + Py_CLEAR(*kwargs); + return -1; + } + return 0; + } else { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return -1; + } + PyErr_Clear(); + } + + /* The object does not have __getnewargs_ex__ so we fallback on using + __getnewargs__ instead. */ + getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__); + if (getnewargs != NULL) { + *args = PyObject_CallObject(getnewargs, NULL); + Py_DECREF(getnewargs); + if (*args == NULL) { + return -1; + } + if (!PyTuple_Check(*args)) { + PyErr_Format(PyExc_TypeError, + "__getnewargs__ should return a tuple, " + "not '%.200s'", Py_TYPE(*args)->tp_name); + Py_CLEAR(*args); + return -1; + } + *kwargs = NULL; + return 0; + } else { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return -1; + } + PyErr_Clear(); + } + + /* The object does not have __getnewargs_ex__ and __getnewargs__. This may + means __new__ does not takes any arguments on this object, or that the + object does not implement the reduce protocol for pickling or + copying. */ + *args = NULL; + *kwargs = NULL; + return 0; +} + +Py_LOCAL(int) +_PyObject_GetItemsIter(PyObject *obj, PyObject **listitems, + PyObject **dictitems) +{ + if (listitems == NULL || dictitems == NULL) { + PyErr_BadInternalCall(); + return -1; + } + + if (!PyList_Check(obj)) { + *listitems = Py_None; + Py_INCREF(*listitems); } else { - PyObject **dict; - PyErr_Clear(); - dict = _PyObject_GetDictPtr(obj); - if (dict && *dict) - state = *dict; - else - state = Py_None; - Py_INCREF(state); - names = slotnames(cls); - if (names == NULL) - goto end; - if (names != Py_None && PyList_GET_SIZE(names) > 0) { - assert(PyList_Check(names)); - slots = PyDict_New(); - if (slots == NULL) - goto end; - n = 0; - /* Can't pre-compute the list size; the list - is stored on the class so accessible to other - threads, which may be run by DECREF */ - for (i = 0; i < PyList_GET_SIZE(names); i++) { - PyObject *name, *value; - name = PyList_GET_ITEM(names, i); - value = PyObject_GetAttr(obj, name); - if (value == NULL) - PyErr_Clear(); - else { - int err = PyDict_SetItem(slots, name, - value); - Py_DECREF(value); - if (err) - goto end; - n++; - } - } - if (n) { - state = Py_BuildValue("(NO)", state, slots); - if (state == NULL) - goto end; - } + *listitems = PyObject_GetIter(obj); + if (listitems == NULL) + return -1; + } + + if (!PyDict_Check(obj)) { + *dictitems = Py_None; + Py_INCREF(*dictitems); + } + else { + PyObject *items; + _Py_IDENTIFIER(items); + + items = _PyObject_CallMethodIdObjArgs(obj, &PyId_items, NULL); + if (items == NULL) { + Py_CLEAR(*listitems); + return -1; } - } - - if (!PyList_Check(obj)) { - listitems = Py_None; - Py_INCREF(listitems); - } - else { - listitems = PyObject_GetIter(obj); - if (listitems == NULL) - goto end; - } - - if (!PyDict_Check(obj)) { - dictitems = Py_None; - Py_INCREF(dictitems); - } - else { - _Py_IDENTIFIER(items); - PyObject *items = _PyObject_CallMethodId(obj, &PyId_items, ""); - if (items == NULL) - goto end; - dictitems = PyObject_GetIter(items); + *dictitems = PyObject_GetIter(items); Py_DECREF(items); - if (dictitems == NULL) - goto end; - } + if (*dictitems == NULL) { + Py_CLEAR(*listitems); + return -1; + } + } + + assert(*listitems != NULL && *dictitems != NULL); + + return 0; +} + +static PyObject * +reduce_4(PyObject *obj) +{ + PyObject *args = NULL, *kwargs = NULL; + PyObject *copyreg; + PyObject *newobj, *newargs, *state, *listitems, *dictitems; + PyObject *result; + _Py_IDENTIFIER(__newobj_ex__); + + if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) { + return NULL; + } + if (args == NULL) { + args = PyTuple_New(0); + if (args == NULL) + return NULL; + } + if (kwargs == NULL) { + kwargs = PyDict_New(); + if (kwargs == NULL) + return NULL; + } + + copyreg = import_copyreg(); + if (copyreg == NULL) { + Py_DECREF(args); + Py_DECREF(kwargs); + return NULL; + } + newobj = _PyObject_GetAttrId(copyreg, &PyId___newobj_ex__); + Py_DECREF(copyreg); + if (newobj == NULL) { + Py_DECREF(args); + Py_DECREF(kwargs); + return NULL; + } + newargs = PyTuple_Pack(3, Py_TYPE(obj), args, kwargs); + Py_DECREF(args); + Py_DECREF(kwargs); + if (newargs == NULL) { + Py_DECREF(newobj); + return NULL; + } + state = _PyObject_GetState(obj); + if (state == NULL) { + Py_DECREF(newobj); + Py_DECREF(newargs); + return NULL; + } + if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0) { + Py_DECREF(newobj); + Py_DECREF(newargs); + Py_DECREF(state); + return NULL; + } + + result = PyTuple_Pack(5, newobj, newargs, state, listitems, dictitems); + Py_DECREF(newobj); + Py_DECREF(newargs); + Py_DECREF(state); + Py_DECREF(listitems); + Py_DECREF(dictitems); + return result; +} + +static PyObject * +reduce_2(PyObject *obj) +{ + PyObject *cls; + PyObject *args = NULL, *args2 = NULL, *kwargs = NULL; + PyObject *state = NULL, *listitems = NULL, *dictitems = NULL; + PyObject *copyreg = NULL, *newobj = NULL, *res = NULL; + Py_ssize_t i, n; + _Py_IDENTIFIER(__newobj__); + + if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) { + return NULL; + } + if (args == NULL) { + assert(kwargs == NULL); + args = PyTuple_New(0); + if (args == NULL) { + return NULL; + } + } + else if (kwargs != NULL) { + if (PyDict_Size(kwargs) > 0) { + PyErr_SetString(PyExc_ValueError, + "must use protocol 4 or greater to copy this " + "object; since __getnewargs_ex__ returned " + "keyword arguments."); + Py_DECREF(args); + Py_DECREF(kwargs); + return NULL; + } + Py_CLEAR(kwargs); + } + + state = _PyObject_GetState(obj); + if (state == NULL) + goto end; + + if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0) + goto end; copyreg = import_copyreg(); if (copyreg == NULL) @@ -3560,6 +3839,7 @@ args2 = PyTuple_New(n+1); if (args2 == NULL) goto end; + cls = (PyObject *) Py_TYPE(obj); Py_INCREF(cls); PyTuple_SET_ITEM(args2, 0, cls); for (i = 0; i < n; i++) { @@ -3573,9 +3853,7 @@ end: Py_XDECREF(args); Py_XDECREF(args2); - Py_XDECREF(slots); Py_XDECREF(state); - Py_XDECREF(names); Py_XDECREF(listitems); Py_XDECREF(dictitems); Py_XDECREF(copyreg); @@ -3603,7 +3881,9 @@ { PyObject *copyreg, *res; - if (proto >= 2) + if (proto >= 4) + return reduce_4(self); + else if (proto >= 2) return reduce_2(self); copyreg = import_copyreg();