diff -r 7ba3e71a49f7 Python/marshal.c --- a/Python/marshal.c Thu Jun 09 11:11:45 2011 +0100 +++ b/Python/marshal.c Thu Jun 09 14:43:26 2011 +0100 @@ -53,6 +53,8 @@ #define WFERR_NOMEMORY 3 typedef struct { + PyObject * readable; + int peeked; FILE *fp; int error; /* see WFERR_* values */ int depth; @@ -461,23 +463,70 @@ typedef WFILE RFILE; /* Same struct with different invariants */ -#define rs_byte(p) (((p)->ptr < (p)->end) ? (unsigned char)*(p)->ptr++ : EOF) - -#define r_byte(p) ((p)->fp ? getc((p)->fp) : rs_byte(p)) - static int r_string(char *s, int n, RFILE *p) { - if (p->fp != NULL) - /* The result fits into int because it must be <=n. */ - return (int)fread(s, 1, n, p->fp); - if (p->end - p->ptr < n) - n = (int)(p->end - p->ptr); - memcpy(s, p->ptr, n); - p->ptr += n; + char * ptr; + + if (!p->readable) { + if (p->fp != NULL) + /* The result fits into int because it must be <=n. */ + n = (int) fread(s, 1, n, p->fp); + else { + if (p->end - p->ptr < n) + n = (int)(p->end - p->ptr); + memcpy(s, p->ptr, n); + p->ptr += n; + } + } + else { + PyObject *data = PyObject_CallMethod(p->readable, "read", "i", n); + n = 0; + if (data != NULL) { + if (PyBytes_Check(data)) { + n = PyBytes_GET_SIZE(data); + if (n > 0) { + ptr = PyBytes_AS_STRING(data); + memcpy(s, ptr, n); + } + } + else { + // TODO error handling in this case + } + Py_DECREF(data); + } + } return n; } + +#define rs_byte(p) (((p)->ptr < (p)->end) ? (unsigned char)*(p)->ptr++ : EOF) + +//#define r_byte(p) ((p)->fp ? getc((p)->fp) : rs_byte(p)) + +static int +r_byte(RFILE *p) +{ + int c = EOF; + unsigned char ch; + int n; + + if (!p->readable) + c = p->fp ? getc(p->fp) : rs_byte(p); + else { + if (p->peeked != EOF) { + c = p->peeked; + p->peeked = EOF; + } + else { + n = r_string((char *) &ch, 1, p); + if (n > 0) + c = ch; + } + } + return c; +} + static int r_short(RFILE *p) { @@ -493,18 +542,26 @@ r_long(RFILE *p) { register long x; - register FILE *fp = p->fp; - if (fp) { - x = getc(fp); - x |= (long)getc(fp) << 8; - x |= (long)getc(fp) << 16; - x |= (long)getc(fp) << 24; + if (p->readable) { + x = r_byte(p); + x |= (long)r_byte(p) << 8; + x |= (long)r_byte(p) << 16; + x |= (long)r_byte(p) << 24; } else { - x = rs_byte(p); - x |= (long)rs_byte(p) << 8; - x |= (long)rs_byte(p) << 16; - x |= (long)rs_byte(p) << 24; + register FILE *fp = p->fp; + if (fp) { + x = getc(fp); + x |= (long)getc(fp) << 8; + x |= (long)getc(fp) << 16; + x |= (long)getc(fp) << 24; + } + else { + x = rs_byte(p); + x |= (long)rs_byte(p) << 8; + x |= (long)rs_byte(p) << 16; + x |= (long)rs_byte(p) << 24; + } } #if SIZEOF_LONG > 4 /* Sign extension for 64-bit machines */ @@ -1049,6 +1106,7 @@ { RFILE rf; assert(fp); + rf.readable = NULL; rf.fp = fp; rf.current_filename = NULL; rf.end = rf.ptr = NULL; @@ -1060,6 +1118,7 @@ { RFILE rf; rf.fp = fp; + rf.readable = NULL; rf.current_filename = NULL; rf.ptr = rf.end = NULL; return r_long(&rf); @@ -1121,6 +1180,7 @@ RFILE rf; PyObject *result; rf.fp = fp; + rf.readable = NULL; rf.current_filename = NULL; rf.depth = 0; rf.ptr = rf.end = NULL; @@ -1134,6 +1194,7 @@ RFILE rf; PyObject *result; rf.fp = NULL; + rf.readable = NULL; rf.current_filename = NULL; rf.ptr = str; rf.end = str + len; @@ -1149,6 +1210,7 @@ PyObject *res = NULL; wf.fp = NULL; + wf.readable = NULL; wf.str = PyBytes_FromStringAndSize((char *)NULL, 50); if (wf.str == NULL) return NULL; @@ -1224,26 +1286,29 @@ static PyObject * marshal_load(PyObject *self, PyObject *f) { - /* XXX Quick hack -- need to do this differently */ PyObject *data, *result; RFILE rf; - data = PyObject_CallMethod(f, "read", ""); + char *p; + int n; + + // Read one byte, so we can give a meaningful message early + // if something unexpected occurs. However, the stream may + // not be seekable, so we keep the peeked byte around, + // and r_byte will return it the first time. + data = PyObject_CallMethod(f, "read", "i", 1); if (data == NULL) return NULL; rf.fp = NULL; rf.current_filename = NULL; if (PyBytes_Check(data)) { - rf.ptr = PyBytes_AS_STRING(data); - rf.end = rf.ptr + PyBytes_GET_SIZE(data); - } - else if (PyBytes_Check(data)) { - rf.ptr = PyBytes_AS_STRING(data); - rf.end = rf.ptr + PyBytes_GET_SIZE(data); + n = PyBytes_GET_SIZE(data); + p = PyBytes_AS_STRING(data); + rf.peeked = (n > 0) ? *p : EOF; + rf.readable = f; } else { PyErr_Format(PyExc_TypeError, - "f.read() returned neither string " - "nor bytes but %.100s", + "f.read() returned not bytes but %.100s", data->ob_type->tp_name); Py_DECREF(data); return NULL; @@ -1300,6 +1365,7 @@ s = p.buf; n = p.len; rf.fp = NULL; + rf.readable = NULL; rf.current_filename = NULL; rf.ptr = s; rf.end = s + n;