diff --git a/Lib/_pyio.py b/Lib/_pyio.py --- a/Lib/_pyio.py +++ b/Lib/_pyio.py @@ -1051,6 +1051,16 @@ class BufferedReader(_BufferedIOMixin): self._reset_read_buf() return pos + def flush(self): + self._checkClosed() + if not self.writable(): + return + if self._read_buf: + # Undo readahead + with self._read_lock: + self.raw.seek(self._read_pos - len(self._read_buf), 1) + self._reset_read_buf() + class BufferedWriter(_BufferedIOMixin): """A buffer for a writeable sequential RawIO object. @@ -1190,18 +1200,23 @@ class BufferedRWPair(BufferedIOBase): def read(self, n=None): if n is None: n = -1 + self.writer.flush() return self.reader.read(n) def readinto(self, b): + self.writer.flush() return self.reader.readinto(b) def write(self, b): + self.reader.flush() return self.writer.write(b) def peek(self, n=0): + self.writer.flush() return self.reader.peek(n) def read1(self, n): + self.writer.flush() return self.reader.read1(n) def readable(self): @@ -1290,11 +1305,7 @@ class BufferedRandom(BufferedWriter, Buf return BufferedReader.read1(self, n) def write(self, b): - if self._read_buf: - # Undo readahead - with self._read_lock: - self.raw.seek(self._read_pos - len(self._read_buf), 1) - self._reset_read_buf() + BufferedReader.flush(self) return BufferedWriter.write(self, b) diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -1349,6 +1349,34 @@ class BufferedRWPairTest(unittest.TestCa pair = self.tp(SelectableIsAtty(True), SelectableIsAtty(True)) self.assertTrue(pair.isatty()) + def test_interlaced_read_write(self): + with self.BytesIO(b'abcdefgh') as raw: + with self.tp(raw, raw) as f: + f.write(b"1") + # read() must call writer.flush() + self.assertEqual(f.read(1), b'b') + # write() must rewind the raw stream + f.write(b'2') + # read1() must call writer.flush() + self.assertEqual(f.read1(1), b'd') + f.write(b'3') + # readinto() must call writer.flush() + buffer = bytearray(1) + f.readinto(buffer) + self.assertEqual(buffer, b'f') + f.write(b'4') + # peek() must call writer.flush() + self.assertEqual(f.peek(1), b'h') + f.flush() + self.assertEqual(raw.getvalue(), b'1b2d3f4h') + + with self.BytesIO(b'abc') as raw: + with self.tp(raw, raw) as f: + self.assertEqual(f.read(1), b'a') + # write() must undo reader readahead + f.write(b"2") + self.assertEqual(f.read(1), b'c') + class CBufferedRWPairTest(BufferedRWPairTest): tp = io.BufferedRWPair @@ -1537,6 +1565,34 @@ class BufferedRandomTest(BufferedReaderT BufferedReaderTest.test_misbehaved_io(self) BufferedWriterTest.test_misbehaved_io(self) + def test_interlaced_read_write(self): + with self.BytesIO(b'abcdefgh') as raw: + with self.tp(raw, 100) as f: + f.write(b"1") + # read() must call flush() + self.assertEqual(f.read(1), b'b') + # write() must rewind the raw stream + f.write(b'2') + # read1() must call flush() + self.assertEqual(f.read1(1), b'd') + f.write(b'3') + # readinto() must call flush() + buffer = bytearray(1) + f.readinto(buffer) + self.assertEqual(buffer, b'f') + f.write(b'4') + # peek() must call flush() + self.assertEqual(f.peek(1), b'h') + f.flush() + self.assertEqual(raw.getvalue(), b'1b2d3f4h') + + with self.BytesIO(b'abc') as raw: + with self.tp(raw, 100) as f: + self.assertEqual(f.read(1), b'a') + # write() must undo reader readahead + f.write(b"2") + self.assertEqual(f.read(1), b'c') + # You can't construct a BufferedRandom over a non-seekable stream. test_unseekable = None diff --git a/Modules/_io/bufferedio.c b/Modules/_io/bufferedio.c --- a/Modules/_io/bufferedio.c +++ b/Modules/_io/bufferedio.c @@ -753,6 +753,35 @@ _trap_eintr(void) */ static PyObject * +buffered_flush_unlocked(buffered *self) +{ + PyObject *res; + + if (self->writable) { + if (!VALID_WRITE_BUFFER(self) || self->write_pos == self->write_end) + Py_RETURN_NONE; + res = _bufferedwriter_flush_unlocked(self, 0); + if (res == NULL) + return NULL; + } + else { + Py_INCREF(Py_None); + res = Py_None; + } + + if (self->readable) { + /* Rewind the raw stream so that its position corresponds to + the current logical position. */ + Py_off_t n; + n = _buffered_raw_seek(self, -RAW_OFFSET(self), 1); + if (n == -1) + Py_CLEAR(res); + _bufferedreader_reset_buf(self); + } + return res; +} + +static PyObject * buffered_flush(buffered *self, PyObject *args) { PyObject *res; @@ -762,16 +791,7 @@ buffered_flush(buffered *self, PyObject if (!ENTER_BUFFERED(self)) return NULL; - res = _bufferedwriter_flush_unlocked(self, 0); - if (res != NULL && self->readable) { - /* Rewind the raw stream so that its position corresponds to - the current logical position. */ - Py_off_t n; - n = _buffered_raw_seek(self, -RAW_OFFSET(self), 1); - if (n == -1) - Py_CLEAR(res); - _bufferedreader_reset_buf(self); - } + res = buffered_flush_unlocked(self); LEAVE_BUFFERED(self) return res; @@ -826,18 +846,37 @@ buffered_read(buffered *self, PyObject * /* The number of bytes is unspecified, read until the end of stream */ if (!ENTER_BUFFERED(self)) return NULL; - res = _bufferedreader_read_all(self); + if (self->writable) { + res = buffered_flush_unlocked(self); + if (res != NULL) { + Py_DECREF(res); + res = _bufferedreader_read_all(self); + } + } + else + res = _bufferedreader_read_all(self); LEAVE_BUFFERED(self) } else { - res = _bufferedreader_read_fast(self, n); - if (res == Py_None) { + if (!self->writable) { + res = _bufferedreader_read_fast(self, n); + if (res != Py_None) + return res; Py_DECREF(res); - if (!ENTER_BUFFERED(self)) - return NULL; + } + + if (!ENTER_BUFFERED(self)) + return NULL; + if (self->writable) { + res = buffered_flush_unlocked(self); + if (res != NULL) { + Py_DECREF(res); + res = _bufferedreader_read_generic(self, n); + } + } + else res = _bufferedreader_read_generic(self, n); - LEAVE_BUFFERED(self) - } + LEAVE_BUFFERED(self) } return res; @@ -2218,30 +2257,105 @@ _forward_call(buffered *self, const char static PyObject * bufferedrwpair_read(rwpair *self, PyObject *args) { + Py_ssize_t n = -1; + PyObject *res; + if (!PyArg_ParseTuple(args, "|O&:read", &_PyIO_ConvertSsize_t, &n)) { + return NULL; + } + if (n < -1) { + PyErr_SetString(PyExc_ValueError, + "read length must be positive or -1"); + return NULL; + } + + res = buffered_flush(self->writer, NULL); + if (res == NULL) + return NULL; + Py_DECREF(res); + return _forward_call(self->reader, "read", args); } static PyObject * bufferedrwpair_peek(rwpair *self, PyObject *args) { + Py_ssize_t n = 0; + PyObject *res; + if (!PyArg_ParseTuple(args, "|n:peek", &n)) + return NULL; + + res = buffered_flush(self->writer, NULL); + if (res == NULL) + return NULL; + Py_DECREF(res); + return _forward_call(self->reader, "peek", args); } static PyObject * bufferedrwpair_read1(rwpair *self, PyObject *args) { + Py_ssize_t n; + PyObject *res; + + CHECK_INITIALIZED(self->reader) + if (!PyArg_ParseTuple(args, "n:read1", &n)) { + return NULL; + } + + if (n < 0) { + PyErr_SetString(PyExc_ValueError, + "read length must be positive"); + return NULL; + } + if (n == 0) + return PyBytes_FromStringAndSize(NULL, 0); + + res = buffered_flush(self->writer, NULL); + if (res == NULL) + return NULL; + Py_DECREF(res); + + res = buffered_flush(self->reader, NULL); + if (res == NULL) + return NULL; + Py_DECREF(res); + return _forward_call(self->reader, "read1", args); } static PyObject * bufferedrwpair_readinto(rwpair *self, PyObject *args) { + Py_buffer buf; + PyObject *res; + if (!PyArg_ParseTuple(args, "w*:readinto", &buf)) + return NULL; + PyBuffer_Release(&buf); + + res = buffered_flush(self->writer, NULL); + if (res == NULL) + return NULL; + Py_DECREF(res); + return _forward_call(self->reader, "readinto", args); } static PyObject * bufferedrwpair_write(rwpair *self, PyObject *args) { + Py_buffer buf; + PyObject *res; + if (!PyArg_ParseTuple(args, "y*:write", &buf)) { + return NULL; + } + PyBuffer_Release(&buf); + + res = buffered_flush(self->reader, NULL); + if (res == NULL) + return NULL; + Py_DECREF(res); + return _forward_call(self->writer, "write", args); }