diff --git a/Lib/pickle.py b/Lib/pickle.py --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -176,6 +176,92 @@ FROZENSET = b'\x92' # build froz __all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)]) + +class _Framer: + + _FRAME_SIZE_TARGET = 64 * 1024 + + def __init__(self, file_write): + self.file_write = file_write + self.current_frame = None + + def _commit_frame(self): + f = self.current_frame + with f.getbuffer() as data: + n = len(data) + write = self.file_write + write(pack("= self._FRAME_SIZE_TARGET: + self._commit_frame() + return f.write(data) + +class _Unframer: + + def __init__(self, file_read, file_readline, file_tell=None): + self.file_read = file_read + self.file_readline = file_readline + self.file_tell = file_tell + self.framing_enabled = False + self.current_frame = None + self.frame_start = None + + def read(self, n): + if n == 0: + return b'' + _file_read = self.file_read + if not self.framing_enabled: + return _file_read(n) + f = self.current_frame + if f is not None: + data = f.read(n) + if data: + if len(data) < n: + raise UnpicklingError("pickle exhausted before end of frame") + return data + frame_size, = unpack(" sys.maxsize: + raise ValueError("frame size > sys.maxsize: %d" % frame_size) + if self.file_tell is not None: + self.frame_start = self.file_tell() + f = self.current_frame = io.BytesIO(_file_read(frame_size)) + self.readline = f.readline + data = f.read(n) + assert len(data) == n, (len(data), n) + return data + + def readline(self): + if not self.framing_enabled: + return self.file_readline() + else: + return self.current_frame.readline() + + def tell(self): + if self.file_tell is None: + return None + elif self.current_frame is None: + return self.file_tell() + else: + return self.frame_start + self.current_frame.tell() + + # Pickling machinery class _Pickler: @@ -209,7 +295,7 @@ class _Pickler: elif not 0 <= protocol <= HIGHEST_PROTOCOL: raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL) try: - self.write = file.write + self._file_write = file.write except AttributeError: raise TypeError("file must have a 'write' attribute") self.memo = {} @@ -233,13 +319,22 @@ class _Pickler: """Write a pickled representation of obj to the open file.""" # Check whether Pickler was initialized correctly. This is # only needed to mimic the behavior of _pickle.Pickler.dump(). - if not hasattr(self, "write"): + if not hasattr(self, "_file_write"): raise PicklingError("Pickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) if self.proto >= 2: - self.write(PROTO + pack("= 4: + framer = _Framer(self._file_write) + framer.start_framing() + self.write = framer.write + else: + framer = None + self.write = self._file_write self.save(obj) self.write(STOP) + if framer is not None: + framer.end_framing() def memoize(self, obj): """Store an object in the memo.""" @@ -840,8 +935,8 @@ class _Unpickler: instances pickled by Python 2.x; these default to 'ASCII' and 'strict', respectively. """ - self.readline = file.readline - self.read = file.read + self._file_readline = file.readline + self._file_read = file.read self.memo = {} self.encoding = encoding self.errors = errors @@ -855,12 +950,16 @@ class _Unpickler: """ # Check whether Unpickler was initialized correctly. This is # only needed to mimic the behavior of _pickle.Unpickler.dump(). - if not hasattr(self, "read"): + if not hasattr(self, "_file_read"): raise UnpicklingError("Unpickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) + self._unframer = _Unframer(self._file_read, self._file_readline) + self.read = self._unframer.read + self.readline = self._unframer.readline self.mark = object() # any new unique object self.stack = [] self.append = self.stack.append + self.proto = 0 read = self.read dispatch = self.dispatch try: @@ -898,6 +997,8 @@ class _Unpickler: if not 0 <= proto <= HIGHEST_PROTOCOL: raise ValueError("unsupported pickle protocol: %d" % proto) self.proto = proto + if proto >= 4: + self._unframer.framing_enabled = True dispatch[PROTO[0]] = load_proto def load_persid(self): diff --git a/Lib/pickletools.py b/Lib/pickletools.py --- a/Lib/pickletools.py +++ b/Lib/pickletools.py @@ -11,6 +11,7 @@ dis(pickle, out=None, memo=None, indentl ''' import codecs +import io import pickle import re import sys @@ -2178,42 +2179,19 @@ del assure_pickle_consistency ############################################################################## # A pickle opcode generator. -def genops(pickle): - """Generate all the opcodes in a pickle. - - 'pickle' is a file-like object, or string, containing the pickle. - - Each opcode in the pickle is generated, from the current pickle position, - stopping after a STOP opcode is delivered. A triple is generated for - each opcode: - - opcode, arg, pos - - opcode is an OpcodeInfo record, describing the current opcode. - - If the opcode has an argument embedded in the pickle, arg is its decoded - value, as a Python object. If the opcode doesn't have an argument, arg - is None. - - If the pickle has a tell() method, pos was the value of pickle.tell() - before reading the current opcode. If the pickle is a bytes object, - it's wrapped in a BytesIO object, and the latter's tell() result is - used. Else (the pickle doesn't have a tell(), and it's not obvious how - to query its current position) pos is None. - """ - - if isinstance(pickle, bytes_types): - import io - pickle = io.BytesIO(pickle) - - if hasattr(pickle, "tell"): - getpos = pickle.tell - else: - getpos = lambda: None +def _genops(data, yield_end_pos=False): + if isinstance(data, bytes_types): + data = io.BytesIO(data) + + unframer = pickle._Unframer(data.read, data.readline, + getattr(data, "tell", None)) + getpos = unframer.tell while True: - pos = getpos() - code = pickle.read(1) + code = unframer.read(1) + # So that the opcode's actual pos is announced, not the frame start + arg_pos = getpos() + pos = arg_pos - 1 if arg_pos is not None else None opcode = code2op.get(code.decode("latin-1")) if opcode is None: if code == b"": @@ -2225,38 +2203,81 @@ def genops(pickle): if opcode.arg is None: arg = None else: - arg = opcode.arg.reader(pickle) - yield opcode, arg, pos + arg = opcode.arg.reader(unframer) + if yield_end_pos: + yield opcode, arg, pos, getpos() + else: + yield opcode, arg, pos if code == b'.': assert opcode.name == 'STOP' break + elif code == b'\x80': + assert opcode.name == 'PROTO' + if arg >= 4: + unframer.framing_enabled = True + +def genops(pickle): + """Generate all the opcodes in a pickle. + + 'pickle' is a file-like object, or string, containing the pickle. + + Each opcode in the pickle is generated, from the current pickle position, + stopping after a STOP opcode is delivered. A triple is generated for + each opcode: + + opcode, arg, pos + + opcode is an OpcodeInfo record, describing the current opcode. + + If the opcode has an argument embedded in the pickle, arg is its decoded + value, as a Python object. If the opcode doesn't have an argument, arg + is None. + + If the pickle has a tell() method, pos was the value of pickle.tell() + before reading the current opcode. If the pickle is a bytes object, + it's wrapped in a BytesIO object, and the latter's tell() result is + used. Else (the pickle doesn't have a tell(), and it's not obvious how + to query its current position) pos is None. + """ + return _genops(pickle) ############################################################################## # A pickle optimizer. def optimize(p): 'Optimize a pickle string by removing unused PUT opcodes' - gets = set() # set of args used by a GET opcode - puts = [] # (arg, startpos, stoppos) for the PUT opcodes - prevpos = None # set to pos if previous opcode was a PUT - for opcode, arg, pos in genops(p): - if prevpos is not None: - puts.append((prevarg, prevpos, pos)) - prevpos = None + not_a_put = object() + gets = { not_a_put } # set of args used by a GET opcode + opcodes = [] # (startpos, stoppos, putid) + proto = 0 + for opcode, arg, pos, end_pos in _genops(p, yield_end_pos=True): if 'PUT' in opcode.name: - prevarg, prevpos = arg, pos - elif 'GET' in opcode.name: - gets.add(arg) - - # Copy the pickle string except for PUTS without a corresponding GET - s = [] - i = 0 - for arg, start, stop in puts: - j = stop if (arg in gets) else start - s.append(p[i:j]) - i = stop - s.append(p[i:]) - return b''.join(s) + opcodes.append((pos, end_pos, arg)) + else: + if 'GET' in opcode.name: + gets.add(arg) + elif opcode.name == 'PROTO': + assert pos == 0 + proto = arg + opcodes.append((pos, end_pos, not_a_put)) + prevpos, prevarg = pos, None + + # Copy the opcodes except for PUTS without a corresponding GET + out = io.BytesIO() + opcodes = iter(opcodes) + if proto >= 2: + # Write the PROTO header before any framing + start, stop, _ = next(opcodes) + out.write(p[start:stop]) + buf = pickle._Framer(out.write) + if proto >= 4: + buf.start_framing() + for start, stop, putid in opcodes: + if putid in gets: + buf.write(p[start:stop]) + if proto >= 4: + buf.end_framing() + return out.getvalue() ############################################################################## # A symbolic pickle disassembler. diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -776,7 +776,12 @@ class AbstractPickleTests(unittest.TestC for proto in protocols: expected = build_none if proto >= 2: - expected = pickle.PROTO + bytes([proto]) + expected + build_proto = pickle.PROTO + bytes([proto]) + if proto >= 4: + # A 2-bytes frame + expected = build_proto + bytes([2] + [0] * 7) + expected + else: + expected = build_proto + expected p = self.dumps(None, proto) self.assertEqual(p, expected) @@ -1208,7 +1213,9 @@ class AbstractPickleTests(unittest.TestC sizes = [len(self.dumps(2**n, proto)) for n in range(70)] # the size function is monotonic self.assertEqual(sorted(sizes), sizes) - if proto >= 2: + if proto >= 4: + self.assertLessEqual(sizes[-1], 22) + elif proto >= 2: self.assertLessEqual(sizes[-1], 14) def check_negative_32b_binXXX(self, dumped): @@ -1710,22 +1717,23 @@ class AbstractPicklerUnpicklerObjectTest def _check_multiple_unpicklings(self, ioclass): for proto in protocols: - data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] - f = ioclass() - pickler = self.pickler_class(f, protocol=proto) - pickler.dump(data1) - pickled = f.getvalue() + with self.subTest(proto=proto): + data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len] + f = ioclass() + pickler = self.pickler_class(f, protocol=proto) + pickler.dump(data1) + pickled = f.getvalue() - N = 5 - f = ioclass(pickled * N) - unpickler = self.unpickler_class(f) - for i in range(N): - if f.seekable(): - pos = f.tell() - self.assertEqual(unpickler.load(), data1) - if f.seekable(): - self.assertEqual(f.tell(), pos + len(pickled)) - self.assertRaises(EOFError, unpickler.load) + N = 5 + f = ioclass(pickled * N) + unpickler = self.unpickler_class(f) + for i in range(N): + if f.seekable(): + pos = f.tell() + self.assertEqual(unpickler.load(), data1) + if f.seekable(): + self.assertEqual(f.tell(), pos + len(pickled)) + self.assertRaises(EOFError, unpickler.load) def test_multiple_unpicklings_seekable(self): self._check_multiple_unpicklings(io.BytesIO) diff --git a/Modules/_pickle.c b/Modules/_pickle.c --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -112,7 +112,9 @@ enum { MAX_WRITE_BUF_SIZE = 64 * 1024, /* Prefetch size when unpickling (disabled on unpeekable streams) */ - PREFETCH = 8192 * 16 + PREFETCH = 8192 * 16, + + FRAME_SIZE_TARGET = 64 * 1024 }; /* Exception classes for pickle. These should override the ones defined in @@ -338,6 +340,8 @@ typedef struct PicklerObject { Py_ssize_t max_output_len; /* Allocation size of output_buffer. */ int proto; /* Pickle protocol number, >= 0 */ int bin; /* Boolean, true if proto > 0 */ + int framing; + Py_ssize_t frame_start; Py_ssize_t buf_size; /* Size of the current buffered pickle data */ int fast; /* Enable fast mode if set to a true value. The fast mode disable the usage of memo, @@ -368,7 +372,9 @@ typedef struct UnpicklerObject { char *input_line; Py_ssize_t input_len; Py_ssize_t next_read_idx; + Py_ssize_t frame_end_idx; Py_ssize_t prefetched_idx; /* index of first prefetched byte */ + PyObject *read; /* read() method of the input stream. */ PyObject *readline; /* readline() method of the input stream. */ PyObject *peek; /* peek() method of the input stream, or NULL */ @@ -386,6 +392,7 @@ typedef struct UnpicklerObject { int proto; /* Protocol of the pickle loaded. */ int fix_imports; /* Indicate whether Unpickler should fix the name of globals pickled by Python 2.x. */ + int framing; } UnpicklerObject; /* Forward declarations */ @@ -678,15 +685,56 @@ static int if (self->output_buffer == NULL) return -1; self->output_len = 0; + self->frame_start = -1; return 0; } +static int +_Pickler_CommitFrame(PicklerObject *self) +{ + size_t frame_len; + char *qdata; + + if (!self->framing || self->frame_start == -1) + return 0; + frame_len = self->output_len - self->frame_start - 8; + qdata = PyBytes_AS_STRING(self->output_buffer) + self->frame_start; + qdata[0] = (unsigned char)(frame_len & 0xff); + qdata[1] = (unsigned char)((frame_len >> 8) & 0xff); + qdata[2] = (unsigned char)((frame_len >> 16) & 0xff); + qdata[3] = (unsigned char)((frame_len >> 24) & 0xff); + qdata[4] = (unsigned char)((frame_len >> 32) & 0xff); + qdata[5] = (unsigned char)((frame_len >> 40) & 0xff); + qdata[6] = (unsigned char)((frame_len >> 48) & 0xff); + qdata[7] = (unsigned char)((frame_len >> 56) & 0xff); + self->frame_start = -1; + return 0; +} + +static int +_Pickler_OpcodeBoundary(PicklerObject *self) +{ + Py_ssize_t frame_len; + + if (!self->framing || self->frame_start == -1) + return 0; + frame_len = self->output_len - self->frame_start - 8; + if (frame_len >= FRAME_SIZE_TARGET) + return _Pickler_CommitFrame(self); + else + return 0; +} + static PyObject * _Pickler_GetString(PicklerObject *self) { PyObject *output_buffer = self->output_buffer; assert(self->output_buffer != NULL); + + if (_Pickler_CommitFrame(self)) + return NULL; + self->output_buffer = NULL; /* Resize down to exact size */ if (_PyBytes_Resize(&output_buffer, self->output_len) < 0) @@ -701,6 +749,7 @@ static int assert(self->write != NULL); + /* This will commit the frame first */ output = _Pickler_GetString(self); if (output == NULL) return -1; @@ -751,6 +800,17 @@ static Py_ssize_t } } buffer = PyBytes_AS_STRING(self->output_buffer); + if (self->framing) { + if (self->frame_start == -1) { + /* Setup new frame */ + Py_ssize_t frame_start = self->output_len; + self->frame_start = frame_start; + for (i = 0; i < 8; i++) { + buffer[frame_start + i] =0; + } + self->output_len += 8; + } + } if (n < 8) { /* This is faster than memcpy when the string is short. */ for (i = 0; i < n; i++) { @@ -779,6 +839,8 @@ static PicklerObject * self->write = NULL; self->proto = 0; self->bin = 0; + self->framing = 0; + self->frame_start = -1; self->fast = 0; self->fast_nesting = 0; self->fix_imports = 0; @@ -876,6 +938,7 @@ static Py_ssize_t self->input_buffer = self->buffer.buf; self->input_len = self->buffer.len; self->next_read_idx = 0; + self->frame_end_idx = -1; self->prefetched_idx = self->input_len; return self->input_len; } @@ -937,7 +1000,7 @@ static Py_ssize_t return -1; /* Prefetch some data without advancing the file pointer, if possible */ - if (self->peek) { + if (self->peek && !self->framing) { PyObject *len, *prefetched; len = PyLong_FromSsize_t(PREFETCH); if (len == NULL) { @@ -985,7 +1048,7 @@ static Py_ssize_t Returns -1 (with an exception set) on failure. On success, return the number of chars read. */ static Py_ssize_t -_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n) +_Unpickler_ReadUnframed(UnpicklerObject *self, char **s, Py_ssize_t n) { Py_ssize_t num_read; @@ -1011,6 +1074,52 @@ static Py_ssize_t } static Py_ssize_t +_Unpickler_Read(UnpicklerObject *self, char **s, Py_ssize_t n) +{ + if (self->framing && + (self->frame_end_idx == -1 || + self->frame_end_idx <= self->next_read_idx)) { + /* Need to read new frame */ + char *dummy; + unsigned char *frame_start; + size_t frame_len; + if (_Unpickler_ReadUnframed(self, &dummy, 8) < 0) + return -1; + frame_start = (unsigned char *) dummy; + frame_len = (size_t) frame_start[0]; + frame_len |= (size_t) frame_start[1] << 8; + frame_len |= (size_t) frame_start[2] << 16; + frame_len |= (size_t) frame_start[3] << 24; + frame_len |= (size_t) frame_start[4] << 32; + frame_len |= (size_t) frame_start[5] << 40; + frame_len |= (size_t) frame_start[6] << 48; + frame_len |= (size_t) frame_start[7] << 56; + if (frame_len > PY_SSIZE_T_MAX) { + PyErr_Format(UnpicklingError, "Invalid frame length"); + return -1; + } + if (frame_len < n) { + PyErr_Format(UnpicklingError, "Bad framing"); + return -1; + } + if (_Unpickler_ReadUnframed(self, &dummy /* unused */, + frame_len) < 0) + return -1; + /* Rewind to start of frame */ + self->frame_end_idx = self->next_read_idx; + self->next_read_idx -= frame_len; + } + if (self->framing) { + /* Check for bad input */ + if (n + self->next_read_idx > self->frame_end_idx) { + PyErr_Format(UnpicklingError, "Bad framing"); + return -1; + } + } + return _Unpickler_ReadUnframed(self, s, n); +} + +static Py_ssize_t _Unpickler_CopyLine(UnpicklerObject *self, char *line, Py_ssize_t len, char **result) { @@ -1165,6 +1274,7 @@ static UnpicklerObject * self->input_line = NULL; self->input_len = 0; self->next_read_idx = 0; + self->frame_end_idx = -1; self->prefetched_idx = 0; self->read = NULL; self->readline = NULL; @@ -1175,6 +1285,7 @@ static UnpicklerObject * self->num_marks = 0; self->marks_size = 0; self->proto = 0; + self->framing = 0; self->fix_imports = 0; return self; @@ -1290,6 +1401,8 @@ memo_put(PicklerObject *self, PyObject * if (self->fast) return 0; + if (_Pickler_OpcodeBoundary(self)) + goto error; x = PyMemoTable_Size(self->memo); if (PyMemoTable_Set(self->memo, obj, x) < 0) @@ -3544,6 +3657,8 @@ save(PicklerObject *self, PyObject *obj, status = -1; } done: + if (status == 0) + status = _Pickler_OpcodeBoundary(self); Py_LeaveRecursiveCall(); Py_XDECREF(reduce_func); Py_XDECREF(reduce_value); @@ -3564,6 +3679,8 @@ dump(PicklerObject *self, PyObject *obj) header[1] = (unsigned char)self->proto; if (_Pickler_Write(self, header, 2) < 0) return -1; + if (self->proto >= 4) + self->framing = 1; } if (save(self, obj, 0) < 0 || @@ -5578,6 +5695,7 @@ load_proto(UnpicklerObject *self) i = (unsigned char)s[0]; if (i <= HIGHEST_PROTOCOL) { self->proto = i; + self->framing = (self->proto >= 4); return 0; } @@ -5593,6 +5711,8 @@ load(UnpicklerObject *self) char *s; self->num_marks = 0; + self->proto = 0; + self->framing = 0; if (Py_SIZE(self->stack)) Pdata_clear(self->stack, 0); @@ -5981,6 +6101,7 @@ Unpickler_init(UnpicklerObject *self, Py self->arg = NULL; self->proto = 0; + self->framing = 0; return 0; }