diff --git a/Lib/_pyio.py b/Lib/_pyio.py --- a/Lib/_pyio.py +++ b/Lib/_pyio.py @@ -1612,19 +1612,19 @@ class TextIOWrapper(TextIOBase): if not isinstance(s, str): raise TypeError("can't write %s to text stream" % s.__class__.__name__) + if self._decoder is not None or self._snapshot is not None: + # undo readahead + pos = self.tell() + self.seek(pos) length = len(s) haslf = (self._writetranslate or self._line_buffering) and "\n" in s if haslf and self._writetranslate and self._writenl != "\n": s = s.replace("\n", self._writenl) encoder = self._encoder or self._get_encoder() - # XXX What if we were just reading? b = encoder.encode(s) self.buffer.write(b) if self._line_buffering and (haslf or "\r" in s): self.flush() - self._snapshot = None - if self._decoder: - self._decoder.reset() return length def _get_encoder(self): diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -2297,6 +2297,28 @@ class TextIOWrapperTest(unittest.TestCas with self.assertRaises(AttributeError): txt.buffer = buf + def test_interlaced_read_write(self): + with self.BytesIO('abcd'.encode('utf-8')) as raw: + with self.TextIOWrapper(raw, encoding='utf-8') as f: + f.write("1") + # read() must call writer.flush() + self.assertEqual(f.read(1), 'b') + # write() must rewind the raw stream + f.write('3') + self.assertEqual(f.read(), 'd') + f.flush() + self.assertEqual(raw.getvalue(), b'1b3d') + + with self.BytesIO(b'abc') as raw: + with self.TextIOWrapper(raw, encoding='utf-8') as f: + self.assertEqual(f.read(1), 'a') + # write() must undo reader readahead + f.write("2") + self.assertEqual(f.read(1), 'c') + f.flush() + self.assertEqual(raw.getvalue(), b'a2c') + + class CTextIOWrapperTest(TextIOWrapperTest): def test_initialization(self): diff --git a/Modules/_io/textio.c b/Modules/_io/textio.c --- a/Modules/_io/textio.c +++ b/Modules/_io/textio.c @@ -1269,6 +1269,21 @@ textiowrapper_write(textio *self, PyObje if (self->encoder == NULL) return _unsupported("not writable"); + if (self->decoder != NULL || self->snapshot != NULL) { + /* undo readahead */ + PyObject *pos; + pos = PyObject_CallMethodObjArgs((PyObject*)self, + _PyIO_str_tell, NULL); + if (pos == NULL) + return NULL; + ret = PyObject_CallMethodObjArgs((PyObject*)self, + _PyIO_str_seek, pos, NULL); + Py_DECREF(pos); + if (ret == NULL) + return NULL; + Py_DECREF(ret); + } + Py_INCREF(text); textlen = PyUnicode_GetSize(text); @@ -1293,7 +1308,6 @@ textiowrapper_write(textio *self, PyObje PyUnicode_GET_SIZE(text), '\r'))) needflush = 1; - /* XXX What if we were just reading? */ if (self->encodefunc != NULL) { b = (*self->encodefunc)((PyObject *) self, text); self->encoding_start_of_stream = 0; @@ -1331,15 +1345,6 @@ textiowrapper_write(textio *self, PyObje Py_DECREF(ret); } - Py_CLEAR(self->snapshot); - - if (self->decoder) { - ret = PyObject_CallMethod(self->decoder, "reset", NULL); - if (ret == NULL) - return NULL; - Py_DECREF(ret); - } - return PyLong_FromSsize_t(textlen); }