diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py --- a/Lib/test/test_marshal.py +++ b/Lib/test/test_marshal.py @@ -8,6 +8,7 @@ import unittest import os import types +import struct class HelperMixin: def helper(self, sample, *extra): @@ -418,6 +419,31 @@ def test3To3(self): self._test(3) +class VersionTestCase(unittest.TestCase): + def get_code(self, version): + with open(__file__, "rb") as f: + code = f.read() + if __file__.endswith(".py"): + code = compile(code, __file__, "exec") + return marshal.dumps(code, version) + def get_token(self, version): + return b"v" + struct.pack('B', version) + + def testVersion(self): + self.assertTrue(self.get_code(4).startswith(self.get_token(4))) + for i in range(4): + self.assertFalse(self.get_code(i).startswith(b"v")) + + def testValid(self): + code = self.get_code(4) + marshal.loads(self.get_token(4) + code) + + def testInvalid(self): + # tokens naming an unsupported version (3 is too old, 5 too new) are rejected + code = self.get_code(4) + self.assertRaises(ValueError, marshal.loads, self.get_token(3) + code) + self.assertRaises(ValueError, marshal.loads, self.get_token(5) + code) + class InterningTestCase(unittest.TestCase, HelperMixin): strobj = "this is an interned string" strobj = sys.intern(strobj) @@ -447,6 +473,7 @@ BufferTestCase, BugsTestCase, LargeValuesTestCase, + VersionTestCase, ) if __name__ == "__main__": diff --git a/Python/marshal.c b/Python/marshal.c --- a/Python/marshal.c +++ b/Python/marshal.c @@ -4,7 +4,13 @@ even though dicts, lists, sets and frozensets, not commonly seen in code objects, are supported. Version 3 of this protocol properly supports circular links - and sharing. */ + and sharing. + All current versions (up to 4) are backwards compatible. + Version 4 introduces a version token, to identify the version + of the marshalled data. In the absence of such a token, we can assume + version 4, or lower. In the future, this will allow us to break + backwards compatibility. 
+ */ #define PY_SSIZE_T_CLEAN @@ -26,6 +32,7 @@ #define MAX_MARSHAL_STACK_DEPTH 2000 #endif +#define TYPE_VERSION 'v' #define TYPE_NULL '0' #define TYPE_NONE 'N' #define TYPE_FALSE 'F' @@ -41,6 +48,7 @@ #define TYPE_STRING 's' #define TYPE_INTERNED 't' #define TYPE_REF 'r' +#define TYPE_SHORT_REF 'R' #define TYPE_TUPLE '(' #define TYPE_LIST '[' #define TYPE_DICT '{' @@ -248,8 +256,13 @@ } /* we don't store "long" indices in the dict */ assert(0 <= w && w <= 0x7fffffff); - w_byte(TYPE_REF, p); - w_long(w, p); + if (p->version >= 4 && w < 256) { /* short refs are a version 4 opcode; older dumps keep TYPE_REF */ + w_byte(TYPE_SHORT_REF, p); + w_byte((unsigned char)w, p); + } else { + w_byte(TYPE_REF, p); + w_long(w, p); + } return 1; } else { int ok; @@ -547,6 +560,15 @@ } } +static void +w_version(WFILE *p) +{ + if (p->version >= 4) { + w_byte(TYPE_VERSION, p); + w_byte(p->version, p); + } +} + /* version currently has no effect for writing ints. */ void PyMarshal_WriteLongToFile(long x, FILE *fp, int version) @@ -573,6 +595,7 @@ } else wf.refs = NULL; wf.version = version; + w_version(&wf); w_object(x, &wf); Py_XDECREF(wf.refs); } @@ -670,6 +693,9 @@ if (ptr != NULL) c = *(unsigned char *) ptr; } + if (c == EOF) + PyErr_SetString(PyExc_EOFError, + "EOF read where object expected"); return c; } @@ -843,11 +869,8 @@ int flag, is_interned = 0; PyObject *retval = NULL; - if (code == EOF) { - PyErr_SetString(PyExc_EOFError, - "EOF read where object expected"); + if (code == EOF) return NULL; - } p->depth++; @@ -866,6 +889,20 @@ } while (0) switch (type) { + case TYPE_VERSION: + n = r_byte(p); + if (n == EOF) + break; + if (n < 4 || n > Py_MARSHAL_VERSION) { + /* this token first appears in version 4 */ + PyErr_Format(PyExc_ValueError, "invalid version token %d", (int)n); + break; + } + p->version = n; + /* the nested read keeps its depth charge, so a long run of + version tokens is bounded by the reader's depth limit */ + retval = r_object(p); + break; case TYPE_NULL: break; @@ -911,11 +948,8 @@ char buf[256], *ptr; double dx; n = r_byte(p); - if (n == EOF) { - PyErr_SetString(PyExc_EOFError, - "EOF read where 
object expected"); + if (n == EOF) break; - } ptr = r_string(n, p); if (ptr == NULL) break; @@ -949,11 +983,8 @@ char buf[256], *ptr; Py_complex c; n = r_byte(p); - if (n == EOF) { - PyErr_SetString(PyExc_EOFError, - "EOF read where object expected"); + if (n == EOF) break; - } ptr = r_string(n, p); if (ptr == NULL) break; @@ -963,11 +994,8 @@ if (c.real == -1.0 && PyErr_Occurred()) break; n = r_byte(p); - if (n == EOF) { - PyErr_SetString(PyExc_EOFError, - "EOF read where object expected"); + if (n == EOF) break; - } ptr = r_string(n, p); if (ptr == NULL) break; @@ -1042,11 +1070,8 @@ is_interned = 1; case TYPE_SHORT_ASCII: n = r_byte(p); - if (n == EOF) { - PyErr_SetString(PyExc_EOFError, - "EOF read where object expected"); + if (n == EOF) break; - } _read_ascii: { char *ptr; @@ -1096,6 +1121,8 @@ case TYPE_SMALL_TUPLE: - n = (unsigned char) r_byte(p); + n = r_byte(p); /* no cast: it would turn EOF into 255 and defeat the check below */ + if (n == EOF) + break; goto _read_tuple; case TYPE_TUPLE: n = r_long(p); @@ -1325,11 +1352,18 @@ retval = v; break; + case TYPE_SHORT_REF: + n = r_byte(p); + if (n == EOF) + break; + goto _read_ref; + case TYPE_REF: n = r_long(p); + if (n == -1 && PyErr_Occurred()) + break; + _read_ref: if (n < 0 || n >= PyList_GET_SIZE(p->refs)) { - if (n == -1 && PyErr_Occurred()) - break; PyErr_SetString(PyExc_ValueError, "bad marshal data (invalid reference)"); break; } @@ -1458,6 +1492,7 @@ rf.ptr = rf.end = NULL; rf.buf = NULL; rf.refs = PyList_New(0); + rf.version = 4; /* can assume this, it is backwards compatible */ if (rf.refs == NULL) return NULL; result = r_object(&rf); @@ -1480,6 +1515,7 @@ rf.buf = NULL; rf.depth = 0; rf.refs = PyList_New(0); + rf.version = 4; if (rf.refs == NULL) return NULL; result = r_object(&rf); @@ -1509,6 +1545,7 @@ return NULL; } else wf.refs = NULL; + w_version(&wf); w_object(x, &wf); Py_XDECREF(wf.refs); if (wf.str != NULL) { @@ -1601,6 +1638,7 @@ rf.current_filename = NULL; rf.ptr = rf.end = NULL; rf.buf = NULL; + rf.version = 4; if ((rf.refs = PyList_New(0)) != NULL) { result = 
read_object(&rf); Py_DECREF(rf.refs); @@ -1664,6 +1702,7 @@ rf.ptr = s; rf.end = s + n; rf.depth = 0; + rf.version = 4; /* no token read yet: assume version 4 or lower */ if ((rf.refs = PyList_New(0)) == NULL) return NULL; result = read_object(&rf);