# HG changeset patch
# Parent 63c5531cfdf7b433a170fbff643cd751ea9e0319
Issue #15216: Add TextIOWrapper.set_encoding()
Patch by Victor Stinner, Nikolaus Rath, and Martin Panter.

diff -r 63c5531cfdf7 Doc/library/io.rst
--- a/Doc/library/io.rst	Sat Jan 07 09:33:28 2017 +0300
+++ b/Doc/library/io.rst	Wed Jan 11 09:55:09 2017 +0000
@@ -840,7 +840,6 @@
       Write the string *s* to the stream and return the number of characters
       written.
 
-
 .. class:: TextIOWrapper(buffer, encoding=None, errors=None, newline=None, \
                          line_buffering=False, write_through=False)
 
@@ -901,13 +900,29 @@
    locale encoding using :func:`locale.setlocale`, use the current locale
    encoding instead of the user preferred encoding.
 
-   :class:`TextIOWrapper` provides one attribute in addition to those of
+   :class:`TextIOWrapper` provides these members in addition to those of
    :class:`TextIOBase` and its parents:
 
    .. attribute:: line_buffering
 
       Whether line buffering is enabled.
 
+   .. method:: set_encoding(encoding=None, errors=None[, newline])
+
+      Change the encoding, error handler, and newline handler.
+      If *encoding* is None or *newline* is unspecified, the existing
+      setting is retained.  If *errors* is None, the default depends on
+      *encoding*: if *encoding* is also None, the existing error handler
+      is retained, otherwise it is reset to ``'strict'``.
+
+      For non-seekable streams, it may not be possible to change the
+      encoding if some data has already been read from the stream.
+
+      Changing the encoding of a seekable stream may invalidate any
+      previous position markers obtained from :meth:`~TextIOBase.tell`.
+
+      .. versionadded:: 3.7
+
 
 .. class:: StringIO(initial_value='', newline='\\n')
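The snippet below is an illustrative usage sketch of the documented behaviour; it is not
part of the patch and assumes a Python build with this change applied. It mirrors the
"latin1 -> utf-8" scenario exercised by the tests further down: a stream whose first line
is Latin-1 and whose remainder is UTF-8.

    import io

    raw = io.BytesIO('header: caf\xe9\n'.encode('latin-1') +
                     'body: d\xe9j\xe0 vu\n'.encode('utf-8'))
    txt = io.TextIOWrapper(raw, encoding='latin-1', newline='\n')

    header = txt.readline()      # decoded as Latin-1
    txt.set_encoding('utf-8')    # keep reading the same stream as UTF-8
    body = txt.readline()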
+ """ + old_encoding = codecs.lookup(self._encoding).name + if encoding is None: + encoding = old_encoding + if errors is None: + errors = self._errors + else: + if not isinstance(encoding, str): + raise ValueError("invalid encoding: %r" % encoding) + + if errors is None: + errors = 'strict' + + encoding = codecs.lookup(encoding).name + if newline is Ellipsis: + newline = self._readnl + if encoding == old_encoding and errors == self._errors \ + and newline == self._readnl: + # no change + return + + pending_decoded_text = ( + self._decoded_chars + and self._decoded_chars_used != len(self._decoded_chars)) + if pending_decoded_text and not self.seekable(): + raise UnsupportedOperation( + "It is not possible to set the encoding " + "of a non seekable file after the first read") + + # flush write buffer + self.flush() + + # reset attributes + old_decoder = self._decoder or self._get_decoder() + old_b2cratio = self._b2cratio + self._encoding = encoding + self._errors = errors + self._encoder = None + self._decoder = None + self._b2cratio = 0.0 + self._set_newline(newline) + + if pending_decoded_text: + # compute the length in bytes of the characters already read + new_decoder = self._get_decoder() + dec_flags, input_chunk = self._snapshot + used = self._decoded_chars[:self._decoded_chars_used] + if old_b2cratio > 0.0: + byteslen = round(old_b2cratio * self._decoded_chars_used) + direction = 0 + else: + byteslen = 1 + direction = 1 + while True: + old_decoder.setstate((b'', dec_flags)) + try: + decoded = old_decoder.decode(input_chunk[:byteslen]) + except UnicodeDecodeError: + if direction: + byteslen += direction + else: + byteslen += 1 + else: + if len(decoded) == len(used): + assert decoded == used + break + if not direction: + if len(decoded) > len(used): + direction = -1 + else: + direction = 1 + byteslen += direction + if not(1 <= byteslen <= len(input_chunk)): + raise AssertionError("failed to compute the length " + "in bytes of the read buffer") + + # decode the tail of the read buffer using the new decoder + input_chunk = input_chunk[byteslen:] + decoded_chars = new_decoder.decode(input_chunk, False) + self._snapshot = (0, input_chunk) # New decoder starts with flags == 0 + self._set_decoded_chars(decoded_chars) + if decoded_chars: + self._b2cratio = len(input_chunk) / len(decoded_chars) + + # don't write a BOM in the middle of a file + if self._seekable and self.writable(): + position = self.buffer.tell() + if position != 0: + try: + self._get_encoder().setstate(0) + except LookupError: + # Sometimes the encoder doesn't exist + pass + + def _set_newline(self, newline): + self._readuniversal = not newline + self._readtranslate = newline is None + self._readnl = newline + self._writetranslate = newline != '' + self._writenl = newline or os.linesep + @property def encoding(self): return self._encoding diff -r 63c5531cfdf7 Lib/test/test_io.py --- a/Lib/test/test_io.py Sat Jan 07 09:33:28 2017 +0300 +++ b/Lib/test/test_io.py Wed Jan 11 09:55:09 2017 +0000 @@ -3222,6 +3222,178 @@ F.tell = lambda x: 0 t = self.TextIOWrapper(F(), encoding='utf-8') + def test_set_encoding_same_codec(self): + data = 'foobar\n'.encode('latin1') + raw = self.BytesIO(data) + txt = self.TextIOWrapper(raw, encoding='latin1') + self.assertEqual(txt.encoding, 'latin1') + + # Just an alias, shouldn't change anything + txt.set_encoding('ISO-8859-1') + self.assertEqual(txt.encoding, 'latin1') + + # This is an actual change + txt.set_encoding('iso8859-15') + self.assertEqual(txt.encoding, 'iso8859-15') + + def 
diff -r 63c5531cfdf7 Lib/test/test_io.py
--- a/Lib/test/test_io.py	Sat Jan 07 09:33:28 2017 +0300
+++ b/Lib/test/test_io.py	Wed Jan 11 09:55:09 2017 +0000
@@ -3222,6 +3222,178 @@
         F.tell = lambda x: 0
         t = self.TextIOWrapper(F(), encoding='utf-8')
 
+    def test_set_encoding_same_codec(self):
+        data = 'foobar\n'.encode('latin1')
+        raw = self.BytesIO(data)
+        txt = self.TextIOWrapper(raw, encoding='latin1')
+        self.assertEqual(txt.encoding, 'latin1')
+
+        # Just an alias, shouldn't change anything
+        txt.set_encoding('ISO-8859-1')
+        self.assertEqual(txt.encoding, 'latin1')
+
+        # This is an actual change
+        txt.set_encoding('iso8859-15')
+        self.assertEqual(txt.encoding, 'iso8859-15')
+
+    def test_set_encoding_read(self):
+        # latin1 -> utf8
+        # (latin1 can decode a utf-8 encoded string)
+        data = 'abc\xe9\n'.encode('latin1') + 'd\xe9f\n'.encode('utf8')
+        raw = self.BytesIO(data)
+        txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n')
+        self.assertEqual(txt.readline(), 'abc\xe9\n')
+        txt.set_encoding('utf-8')
+        self.assertEqual(txt.readline(), 'd\xe9f\n')
+
+        # utf-16-be -> utf-32-be
+        # (utf-16 can decode a utf-32 encoded string)
+        data = 'abc\n'.encode('utf-16-be') + 'def\n'.encode('utf-32-be')
+        raw = self.BytesIO(data)
+        txt = self.TextIOWrapper(raw, encoding='utf-16-be', newline='\n')
+        self.assertEqual(txt.readline(), 'abc\n')
+        txt.set_encoding('utf-32-be')
+        self.assertEqual(txt.readline(), 'def\n')
+
+        # ascii/replace -> latin1/strict
+        data = 'abc\n'.encode('ascii') + 'd\xe9f\n'.encode('latin1')
+        raw = self.BytesIO(data)
+        txt = self.TextIOWrapper(raw, encoding='ascii', errors='replace', newline='\n')
+        self.assertEqual(txt.readline(), 'abc\n')
+        txt.set_encoding('latin1', 'strict')
+        self.assertEqual(txt.readline(), 'd\xe9f\n')
+
+        # latin1 -> utf8 -> ascii -> utf8
+        # (latin1 can decode a utf-8 encoded string)
+        data = 'abc\xe9\n'.encode('latin1') + 'd\xe9f\n'.encode('utf8') + 'ghi\n'.encode('utf8')
+        raw = self.BytesIO(data)
+        txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n')
+        self.assertEqual(txt.readline(), 'abc\xe9\n')
+        txt.set_encoding('utf-8')
+        self.assertEqual(txt.readline(), 'd\xe9f\n')
+        txt.set_encoding('ascii')
+        self.assertEqual(txt.readline(), 'ghi\n')
+        txt.set_encoding('utf-8')
+
+    def test_set_encoding_read_non_seekable(self):
+        # ascii -> latin1 without reading before setting the new encoding
+        data = 'abc\xe9'.encode('latin1')
+        raw = self.MockUnseekableIO(data)
+        txt = self.TextIOWrapper(raw, encoding='ascii', newline='\n')
+        txt.set_encoding('latin1')
+        self.assertEqual(txt.readline(), 'abc\xe9')
+
+        # setting the encoding after a read must fail
+        data = 'xabc\xe9\n'.encode('latin1') + 'yd\xe9f\n'.encode('utf8')
+        raw = self.MockUnseekableIO(data)
+        txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n')
+        self.assertEqual(txt.readline(), 'xabc\xe9\n')
+        self.assertRaises(self.UnsupportedOperation, txt.set_encoding, 'utf-8')
+
+    def test_set_encoding_write_fromascii(self):
+        # ascii has a specific encodefunc in the C implementation,
+        # but utf-8-sig does not.  Make sure that we get rid of the
+        # cached encodefunc when we switch encoders.
+        raw = self.BytesIO()
+        txt = self.TextIOWrapper(raw, encoding='ascii', newline='\n')
+        txt.write('foo\n')
+        txt.set_encoding('utf-8-sig')
+        txt.write('\xe9\n')
+        txt.flush()
+        self.assertEqual(raw.getvalue(), b'foo\n\xc3\xa9\n')
+
+    def test_set_encoding_write(self):
+        # latin1 -> utf8
+        raw = self.BytesIO()
+        txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n')
+        txt.write('abc\xe9\n')
+        txt.set_encoding('utf-8')
+        self.assertEqual(raw.getvalue(), b'abc\xe9\n')
+        txt.write('d\xe9f\n')
+        txt.flush()
+        self.assertEqual(raw.getvalue(), b'abc\xe9\nd\xc3\xa9f\n')
+
+        # ascii -> utf-8-sig: ensure that no BOM is written in the middle of
+        # the file
+        raw = self.BytesIO()
+        txt = self.TextIOWrapper(raw, encoding='ascii', newline='\n')
+        txt.write('abc\n')
+        txt.set_encoding('utf-8-sig')
+        txt.write('d\xe9f\n')
+        txt.flush()
+        self.assertEqual(raw.getvalue(), b'abc\nd\xc3\xa9f\n')
+
+    def test_set_encoding_write_non_seekable(self):
+        raw = self.BytesIO()
+        raw.seekable = lambda: False
+        raw.seek = None
+        txt = self.TextIOWrapper(raw, encoding='ascii', newline='\n')
+        txt.write('abc\n')
+        txt.set_encoding('utf-8-sig')
+        txt.write('d\xe9f\n')
+        txt.flush()
+
+        # If the raw stream is not seekable, there'll be a BOM
+        self.assertEqual(raw.getvalue(), b'abc\n\xef\xbb\xbfd\xc3\xa9f\n')
+
+    def test_set_encoding_defaults(self):
+        txt = self.TextIOWrapper(self.BytesIO(), 'ascii', 'replace', '\n')
+        txt.set_encoding(None, None)
+        self.assertEqual(txt.encoding, 'ascii')
+        self.assertEqual(txt.errors, 'replace')
+        txt.write('LF\n')
+
+        txt.set_encoding(newline='\r\n')
+        self.assertEqual(txt.encoding, 'ascii')
+        self.assertEqual(txt.errors, 'replace')
+
+        txt.set_encoding(errors='ignore')
+        self.assertEqual(txt.encoding, 'ascii')
+        txt.write('CRLF\n')
+
+        txt.set_encoding(encoding='utf-8', newline=None)
+        self.assertEqual(txt.errors, 'strict')
+        txt.seek(0)
+        self.assertEqual(txt.read(), 'LF\nCRLF\n')
+
+        self.assertEqual(txt.detach().getvalue(), b'LF\nCRLF\r\n')
+
+    def test_set_encoding_newline(self):
+        raw = self.BytesIO(b'CR\rEOF')
+        txt = self.TextIOWrapper(raw, 'ascii', newline='\n')
+        txt.set_encoding(newline=None)
+        self.assertEqual(txt.readline(), 'CR\n')
+        raw = self.BytesIO(b'CR\rEOF')
+        txt = self.TextIOWrapper(raw, 'ascii', newline='\n')
+        txt.set_encoding(newline='')
+        self.assertEqual(txt.readline(), 'CR\r')
+        raw = self.BytesIO(b'CR\rLF\nEOF')
+        txt = self.TextIOWrapper(raw, 'ascii', newline='\r')
+        txt.set_encoding(newline='\n')
+        self.assertEqual(txt.readline(), 'CR\rLF\n')
+        raw = self.BytesIO(b'LF\nCR\rEOF')
+        txt = self.TextIOWrapper(raw, 'ascii', newline='\n')
+        txt.set_encoding(newline='\r')
+        self.assertEqual(txt.readline(), 'LF\nCR\r')
+        raw = self.BytesIO(b'CR\rCRLF\r\nEOF')
+        txt = self.TextIOWrapper(raw, 'ascii', newline='\r')
+        txt.set_encoding(newline='\r\n')
+        self.assertEqual(txt.readline(), 'CR\rCRLF\r\n')
+
+        txt = self.TextIOWrapper(self.BytesIO(), 'ascii', newline='\r')
+        txt.set_encoding(newline=None)
+        txt.write('linesep\n')
+        txt.set_encoding(newline='')
+        txt.write('LF\n')
+        txt.set_encoding(newline='\n')
+        txt.write('LF\n')
+        txt.set_encoding(newline='\r')
+        txt.write('CR\n')
+        txt.set_encoding(newline='\r\n')
+        txt.write('CRLF\n')
+        expected = 'linesep' + os.linesep + 'LF\nLF\nCR\rCRLF\r\n'
+        self.assertEqual(txt.detach().getvalue().decode('ascii'), expected)
+
 
 class MemviewBytesIO(io.BytesIO):
     '''A BytesIO object whose read method returns memoryviews
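The utf-8-sig tests above rely on set_encoding() resetting the new encoder's state when
the stream is not at position 0, so that no BOM is written in the middle of a seekable
file while an unseekable stream still gets one. The effect can be reproduced with a bare
incremental encoder; this is only an illustration of the codec behaviour the tests depend
on, not part of the patch.

    import codecs

    enc = codecs.getincrementalencoder('utf-8-sig')()
    print(enc.encode('abc'))   # b'\xef\xbb\xbfabc' -- a fresh encoder emits the BOM

    enc = codecs.getincrementalencoder('utf-8-sig')()
    enc.setstate(0)            # what set_encoding() does when not at the start
    print(enc.encode('def'))   # b'def' -- no BOM in the middle of a file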
diff -r 63c5531cfdf7 Modules/_io/textio.c
--- a/Modules/_io/textio.c	Sat Jan 07 09:33:28 2017 +0300
+++ b/Modules/_io/textio.c	Wed Jan 11 09:55:09 2017 +0000
@@ -639,7 +639,7 @@
     PyObject *decoder;
     PyObject *readnl;
     PyObject *errors;
-    const char *writenl; /* utf-8 encoded, NULL stands for \n */
+    const char *writenl; /* ASCII-encoded; NULL stands for \n */
     char line_buffering;
     char write_through;
     char readuniversal;
@@ -785,6 +785,157 @@
     {NULL, NULL}
 };
 
+static int
+validate_newline(const char *newline)
+{
+    if (newline && newline[0] != '\0'
+        && !(newline[0] == '\n' && newline[1] == '\0')
+        && !(newline[0] == '\r' && newline[1] == '\0')
+        && !(newline[0] == '\r' && newline[1] == '\n' && newline[2] == '\0')) {
+        PyErr_Format(PyExc_ValueError,
+                     "illegal newline value: %s", newline);
+        return -1;
+    }
+    return 0;
+}
+
+static int
+set_newline(textio *self, const char *newline)
+{
+    PyObject *old = self->readnl;
+    if (newline == NULL) {
+        self->readnl = NULL;
+    }
+    else {
+        self->readnl = PyUnicode_FromString(newline);
+        if (self->readnl == NULL) {
+            self->readnl = old;
+            return -1;
+        }
+    }
+    self->readuniversal = (newline == NULL || newline[0] == '\0');
+    self->readtranslate = (newline == NULL);
+    self->writetranslate = (newline == NULL || newline[0] != '\0');
+    if (!self->readuniversal && self->readnl != NULL) {
+        assert(PyUnicode_KIND(self->readnl) == PyUnicode_1BYTE_KIND);
+        self->writenl = (const char *)PyUnicode_1BYTE_DATA(self->readnl);
+        if (strcmp(self->writenl, "\n") == 0) {
+            self->writenl = NULL;
+        }
+    }
+    else {
+#ifdef MS_WINDOWS
+        self->writenl = "\r\n";
+#else
+        self->writenl = NULL;
+#endif
+    }
+    Py_XDECREF(old);
+    return 0;
+}
+
+static int
+_textiowrapper_set_decoder(textio *self, PyObject *codec_info,
+                           const char *errors)
+{
+    PyObject *res;
+    int r;
+
+    res = _PyObject_CallMethodId(self->buffer, &PyId_readable, NULL);
+    if (res == NULL)
+        return -1;
+
+    r = PyObject_IsTrue(res);
+    Py_DECREF(res);
+    if (r == -1)
+        return -1;
+
+    if (r != 1)
+        return 0;
+
+    Py_CLEAR(self->decoder);
+    self->decoder = _PyCodecInfo_GetIncrementalDecoder(codec_info, errors);
+    if (self->decoder == NULL)
+        return -1;
+
+    if (self->readuniversal) {
+        PyObject *incrementalDecoder = PyObject_CallFunction(
+            (PyObject *)&PyIncrementalNewlineDecoder_Type,
+            "Oi", self->decoder, (int)self->readtranslate);
+        if (incrementalDecoder == NULL)
+            return -1;
+        Py_CLEAR(self->decoder);
+        self->decoder = incrementalDecoder;
+    }
+
+    return 0;
+}
+
+static PyObject*
+_textiowrapper_decode(PyObject *decoder, PyObject *bytes, int eof)
+{
+    PyObject *chars;
+
+    if (Py_TYPE(decoder) == &PyIncrementalNewlineDecoder_Type)
+        chars = _PyIncrementalNewlineDecoder_decode(decoder, bytes, eof);
+    else
+        chars = PyObject_CallMethodObjArgs(decoder, _PyIO_str_decode, bytes,
                                           eof ? Py_True : Py_False, NULL);
+
+    if (check_decoded(chars) < 0)
+        // check_decoded already decreases refcount
+        return NULL;
+
+    return chars;
+}
+
+static int
+_textiowrapper_set_encoder(textio *self, PyObject *codec_info,
+                           const char *errors)
+{
+    PyObject *res;
+    int r;
+
+    res = _PyObject_CallMethodId(self->buffer, &PyId_writable, NULL);
+    if (res == NULL)
+        return -1;
+
+    r = PyObject_IsTrue(res);
+    Py_DECREF(res);
+    if (r == -1)
+        return -1;
+
+    if (r != 1)
+        return 0;
+
+    Py_CLEAR(self->encoder);
+    self->encodefunc = NULL;
+    self->encoder = _PyCodecInfo_GetIncrementalEncoder(codec_info, errors);
+    if (self->encoder == NULL)
+        return -1;
+
+    /* Get the normalized name of the codec */
+    res = _PyObject_GetAttrId(codec_info, &PyId_name);
+    if (res == NULL) {
+        if (PyErr_ExceptionMatches(PyExc_AttributeError))
+            PyErr_Clear();
+        else
+            return -1;
+    }
+    else if (PyUnicode_Check(res)) {
+        const encodefuncentry *e = encodefuncs;
+        while (e->name != NULL) {
+            if (_PyUnicode_EqualToASCIIString(res, e->name)) {
+                self->encodefunc = e->encodefunc;
+                break;
+            }
+            e++;
+        }
+    }
+    Py_XDECREF(res);
+
+    return 0;
+}
 
 /*[clinic input]
 _io.TextIOWrapper.__init__
@@ -840,12 +991,7 @@
     self->ok = 0;
     self->detached = 0;
 
-    if (newline && newline[0] != '\0'
-        && !(newline[0] == '\n' && newline[1] == '\0')
-        && !(newline[0] == '\r' && newline[1] == '\0')
-        && !(newline[0] == '\r' && newline[1] == '\n' && newline[2] == '\0')) {
-        PyErr_Format(PyExc_ValueError,
-                     "illegal newline value: %s", newline);
+    if (validate_newline(newline) < 0) {
         return -1;
     }
 
@@ -953,92 +1099,26 @@
         goto error;
     self->chunk_size = 8192;
 
-    self->readuniversal = (newline == NULL || newline[0] == '\0');
     self->line_buffering = line_buffering;
     self->write_through = write_through;
-    self->readtranslate = (newline == NULL);
-    if (newline) {
-        self->readnl = PyUnicode_FromString(newline);
-        if (self->readnl == NULL)
-            goto error;
+    if (set_newline(self, newline) < 0) {
+        goto error;
     }
-    self->writetranslate = (newline == NULL || newline[0] != '\0');
-    if (!self->readuniversal && self->readnl) {
-        self->writenl = PyUnicode_AsUTF8(self->readnl);
-        if (self->writenl == NULL)
-            goto error;
-        if (!strcmp(self->writenl, "\n"))
-            self->writenl = NULL;
-    }
-#ifdef MS_WINDOWS
-    else
-        self->writenl = "\r\n";
-#endif
-
+
+    self->buffer = buffer;
+    Py_INCREF(buffer);
+
     /* Build the decoder object */
-    res = _PyObject_CallMethodId(buffer, &PyId_readable, NULL);
-    if (res == NULL)
+    if (_textiowrapper_set_decoder(self, codec_info, errors) != 0)
         goto error;
-    r = PyObject_IsTrue(res);
-    Py_DECREF(res);
-    if (r == -1)
+
+    /* Build the encoder object */
+    if (_textiowrapper_set_encoder(self, codec_info, errors) != 0)
         goto error;
-    if (r == 1) {
-        self->decoder = _PyCodecInfo_GetIncrementalDecoder(codec_info,
-                                                           errors);
-        if (self->decoder == NULL)
-            goto error;
-
-        if (self->readuniversal) {
-            PyObject *incrementalDecoder = PyObject_CallFunction(
-                (PyObject *)&PyIncrementalNewlineDecoder_Type,
-                "Oi", self->decoder, (int)self->readtranslate);
-            if (incrementalDecoder == NULL)
-                goto error;
-            Py_XSETREF(self->decoder, incrementalDecoder);
-        }
-    }
-
-    /* Build the encoder object */
-    res = _PyObject_CallMethodId(buffer, &PyId_writable, NULL);
-    if (res == NULL)
-        goto error;
-    r = PyObject_IsTrue(res);
-    Py_DECREF(res);
-    if (r == -1)
-        goto error;
-    if (r == 1) {
-        self->encoder = _PyCodecInfo_GetIncrementalEncoder(codec_info,
-                                                           errors);
-        if (self->encoder == NULL)
-            goto error;
-        /* Get the normalized named of the codec */
-        res = _PyObject_GetAttrId(codec_info, &PyId_name);
-        if (res == NULL) {
-            if (PyErr_ExceptionMatches(PyExc_AttributeError))
-                PyErr_Clear();
-            else
-                goto error;
-        }
-        else if (PyUnicode_Check(res)) {
-            const encodefuncentry *e = encodefuncs;
-            while (e->name != NULL) {
-                if (_PyUnicode_EqualToASCIIString(res, e->name)) {
-                    self->encodefunc = e->encodefunc;
-                    break;
-                }
-                e++;
-            }
-        }
-        Py_XDECREF(res);
-    }
 
     /* Finished sorting out the codec details */
     Py_CLEAR(codec_info);
 
-    self->buffer = buffer;
-    Py_INCREF(buffer);
-
     if (Py_TYPE(buffer) == &PyBufferedReader_Type ||
         Py_TYPE(buffer) == &PyBufferedWriter_Type ||
         Py_TYPE(buffer) == &PyBufferedRandom_Type) {
@@ -1370,6 +1450,337 @@
         self->decoded_chars_used = 0;
 }
 
+static PyObject*
+_textiowrapper_canonical_codec_name(PyObject *codec_name)
+{
+    char *c_name = NULL;
+    PyObject *codec_obj = NULL;
+    PyObject *canonical_name = NULL;
+
+    c_name = PyUnicode_AsUTF8(codec_name);
+    if (c_name == NULL)
+        goto err_out;
+
+    codec_obj = _PyCodec_Lookup(c_name);
+    if (codec_obj == NULL)
+        goto err_out;
+
+    canonical_name = PyObject_GetAttrString(codec_obj, "name");
+    Py_CLEAR(codec_obj);
+    if (canonical_name == NULL)
+        goto err_out;
+
+    return canonical_name;
+
+  err_out:
+    Py_CLEAR(canonical_name);
+    return NULL;
+}
+
+PyDoc_STRVAR(set_encoding_doc,
+"set_encoding(encoding=None, errors=None[, newline])\n"
+"\n"
+"Change the encoding of the stream.\n"
+"\n"
+"  encoding\n"
+"    Name of new encoding to use.\n"
+"  errors\n"
+"    New error handler to use.\n"
+"  newline\n"
+"    New newline handler.\n"
+"\n"
+"For non-seekable streams, it may not be possible to change the encoding if some\n"
+"data has already been read from the stream.\n"
+"\n"
+"Changing the encoding of a seekable stream may invalidate any previous\n"
+"position markers obtained from `tell`.");
+
+static PyObject *
+set_encoding(PyObject *selfobj, PyObject *posargs, PyObject *kwargs)
+{
+    PyObject *encoding = Py_None;
+    const char *errors = NULL;
+    const char *newline = (const char *)&newline;  /* Unique non-NULL value */
+
+    static char *keywords[] = {"encoding", "errors", "newline", NULL};
+    if (!PyArg_ParseTupleAndKeywords(
+            posargs, kwargs, "|Ozz:set_encoding", keywords,
+            &encoding, &errors, &newline)) {
+        return NULL;
+    }
+
+    char pending_decoded_text = 0;
+    PyObject *old_decoder = NULL;
+    double old_b2cratio;
+    char res;
+    PyObject *encoding_cname, *old_encoding_cname;  // canonical names
+
+    textio *self = (textio *)selfobj;
+    CHECK_INITIALIZED(self);
+
+    /* Use existing settings where new settings are not specified */
+    if (encoding == Py_None) {
+        encoding = self->encoding;
+        if (errors == NULL) {
+            errors = PyBytes_AS_STRING(self->errors);
+        }
+    }
+    else if (errors == NULL) {
+        errors = "strict";
+    }
+    if (newline == (const char *)&newline) {
+        if (self->readnl == NULL) {
+            newline = NULL;
+        }
+        else {
+            assert(PyUnicode_KIND(self->readnl) == PyUnicode_1BYTE_KIND);
+            newline = (const char *)PyUnicode_1BYTE_DATA(self->readnl);
+        }
+    }
+    else if (validate_newline(newline) < 0) {
+        return NULL;
+    }
+
+    /* Get the normalized names of the old and new codecs */
+    encoding_cname = _textiowrapper_canonical_codec_name(encoding);
+    if (encoding_cname == NULL)
+        return NULL;
+    old_encoding_cname = _textiowrapper_canonical_codec_name(self->encoding);
+    if (old_encoding_cname == NULL) {
+        Py_CLEAR(encoding_cname);
+        return NULL;
+    }
+
+    /* Compare with current codec and error handler */
+    res = (PyUnicode_Compare(encoding_cname, old_encoding_cname) == 0);
+    Py_CLEAR(encoding_cname);
+    Py_CLEAR(old_encoding_cname);
+    if (res && strcmp(PyBytes_AS_STRING(self->errors), errors) == 0 && (
+            (newline == NULL && self->readnl == NULL)
+            || (newline != NULL && self->readnl != NULL
+                && PyUnicode_CompareWithASCIIString(self->readnl, newline) == 0)
+            )) {
+        // No change
+        Py_RETURN_NONE;
+    }
+    /* Check if something is in the read buffer */
+    if (self->decoded_chars) {
+        Py_ssize_t strlen;
+        strlen = PyUnicode_GetLength(self->decoded_chars);
+        if (strlen < 0)
+            return NULL;
+        if (self->decoded_chars_used != strlen)
+            pending_decoded_text = 1;
+    }
+
+    if (pending_decoded_text && !self->seekable) {
+        _unsupported("It is not possible to set the encoding "
+                     "of a non seekable file after the first read");
+        return NULL;
+    }
+
+    // Flush write buffer
+    if (_textiowrapper_writeflush(self) != 0)
+        return NULL;
+
+    // Reset attributes
+    if (pending_decoded_text) {
+        old_decoder = self->decoder;
+        Py_INCREF(old_decoder);
+    }
+
+    old_b2cratio = self->b2cratio;
+    self->b2cratio = 0;
+
+    PyObject *old_encoding = self->encoding;
+    self->encoding = encoding;
+    Py_INCREF(self->encoding);
+    Py_DECREF(old_encoding);
+
+    if (errors != PyBytes_AS_STRING(self->errors)) {
+        PyObject *new = PyBytes_FromString(errors);
+        if (new == NULL) {
+            return NULL;
+        }
+        Py_XSETREF(self->errors, new);
+    }
+
+    if (set_newline(self, newline) < 0) {
+        return NULL;
+    }
+
+    // Create new encoder & decoder
+    PyObject *codec_info = _PyCodec_LookupTextEncoding(
+        PyUnicode_AsUTF8(encoding), "codecs.open()");
+    if (codec_info == NULL) {
+        return NULL;
+    }
+    if (_textiowrapper_set_decoder(self, codec_info, errors) != 0 ||
+        _textiowrapper_set_encoder(self, codec_info, errors) != 0) {
+        Py_DECREF(codec_info);
+        return NULL;
+    }
+    Py_DECREF(codec_info);
+
+    if (pending_decoded_text) {
+        // Compute the length in bytes of the characters already read
+        PyObject *dec_flags = NULL, *input_chunk = NULL,
+                 *decoded_chars = NULL, *decoded_bytes = NULL,
+                 *remaining_bytes = NULL, *res = NULL;
+        Py_ssize_t cons_input_len, nchars, input_len, direction;
+        char *c_input_chunk, err_out;
+
+        if (!PyArg_ParseTuple(self->snapshot, "OO", &dec_flags, &input_chunk))
+            goto fail;
+
+        if (PyBytes_AsStringAndSize(input_chunk, &c_input_chunk,
+                                    &input_len) != 0)
+            goto fail;
+
+        /* Estimate the number of bytes that need to be decoded to produce the
+           characters that have already been consumed. */
+        if (old_b2cratio > 0) {
+            cons_input_len = old_b2cratio * self->decoded_chars_used;
+            direction = 0;
+        }
+        else {
+            cons_input_len = 1;
+            direction = 1;
+        }
+        while (1) {
+            if (cons_input_len < 1 || cons_input_len > input_len) {
+                PyErr_SetString(PyExc_AssertionError, "failed to compute "
+                                "the length in bytes of the read buffer");
+                goto fail;
+            }
+
+            // Restore decoder state to value from beginning of chunk
+            res = _PyObject_CallMethodId(old_decoder, &PyId_setstate,
+                                         "((yO))", "", dec_flags);
+            if (res == NULL)
+                goto fail;
+            Py_CLEAR(res);
+
+            /* Extract first *cons_input_len* bytes from the input chunk for
+               decoding */
+            Py_CLEAR(decoded_bytes);
+            decoded_bytes = PyBytes_FromStringAndSize(c_input_chunk, cons_input_len);
+            if (decoded_bytes == NULL)
+                goto fail;
+
+            // Decode *cons_input_len* bytes from input chunk
+            Py_CLEAR(decoded_chars);
+            decoded_chars = _textiowrapper_decode(old_decoder, decoded_bytes, 0);
+            if (decoded_chars == NULL) {
+                if (PyErr_ExceptionMatches(PyExc_UnicodeError)) {
+                    // This substring can't be decoded, try to decode with an
+                    // additional byte
+                    cons_input_len += direction ? direction : 1;
+                    PyErr_Clear();
+                }
+                else
+                    goto fail;
+            }
+            else {
+                // Decoding was successful
+                Py_ssize_t decoded_len;
+                decoded_len = PyUnicode_GetLength(decoded_chars);
+                if (decoded_len == self->decoded_chars_used)
+                    /* We decoded exactly as many characters as were already
+                       consumed. New decoder should thus start from this
+                       position. */
+                    // assert decoded_chars == input_chunk[:cons_input_len]
+                    break;
+                /* If we got too many or too few bytes, update our guess
+                   for *cons_input_len* accordingly */
+                if (!direction)
+                    direction = (decoded_len > self->decoded_chars_used) ? -1 : 1;
+                cons_input_len += direction;
+            }
+        }
+        /* cons_input_len now covers exactly the characters already consumed */
+
+        // Decode the tail of the read buffer using the new decoder
+        input_len -= cons_input_len;
+        remaining_bytes = PyBytes_FromStringAndSize(c_input_chunk + cons_input_len,
+                                                    input_len);
+        if (remaining_bytes == NULL)
+            goto fail;
+
+        Py_CLEAR(decoded_chars);
+        decoded_chars = _textiowrapper_decode(self->decoder, remaining_bytes, 0);
+        if (decoded_chars == NULL)
+            goto fail;
+
+        nchars = PyUnicode_GetLength(decoded_chars);
+        textiowrapper_set_decoded_chars(self, decoded_chars);
+        decoded_chars = NULL;
+
+        if (nchars > 0)
+            self->b2cratio = (double) input_len / nchars;
+
+        // Decoder flags are zero for a fresh decoder
+        Py_CLEAR(self->snapshot);
+        self->snapshot = Py_BuildValue("iN", 0, remaining_bytes);
+        if (self->snapshot == NULL)
+            goto fail;
+
+        err_out = 0;
+        goto clear;
+      fail:
+        err_out = 1;
+      clear:
+        Py_CLEAR(decoded_bytes);
+        Py_CLEAR(decoded_chars);
+        Py_CLEAR(old_decoder);
+        if (err_out)
+            return NULL;
+    }
+
+    if (self->seekable) {
+        char writeable;
+        PyObject *res;
+
+        res = _PyObject_CallMethodId(self->buffer, &PyId_writable, NULL);
+        if (res == NULL)
+            return NULL;
+        writeable = PyObject_IsTrue(res);
+        Py_DECREF(res);
+
+        if (writeable) {
+            PyObject *posobj = NULL;
+            char cmp;
+            posobj = _PyObject_CallMethodId(self->buffer, &PyId_tell, NULL);
+            if (posobj == NULL)
+                return NULL;
+
+            /* We have a writable, seekable stream. Check if we're at the
+               beginning */
+            cmp = PyObject_RichCompareBool(posobj, _PyIO_zero, Py_EQ);
+            Py_DECREF(posobj);
+            if (cmp < 0)
+                return NULL;
+
+            // don't write a BOM in the middle of a file
+            if (cmp)
+                self->encoding_start_of_stream = 1;
+            else {
+                /* FIXME: How do we know that zero is the right state to not
+                   emit a BOM for any encoder? */
+                PyObject *res;
+                self->encoding_start_of_stream = 0;
+                res = PyObject_CallMethodObjArgs(self->encoder, _PyIO_str_setstate,
+                                                 _PyIO_zero, NULL);
+                if (res == NULL)
+                    return NULL;
+                Py_DECREF(res);
+            }
+        } /* writeable */
+    } /* seekable */
+
+    Py_RETURN_NONE;
+}
+
 static PyObject *
 textiowrapper_get_decoded_chars(textio *self, Py_ssize_t n)
 {
@@ -1483,18 +1894,12 @@
     nbytes = input_chunk_buf.len;
     eof = (nbytes == 0);
 
-    if (Py_TYPE(self->decoder) == &PyIncrementalNewlineDecoder_Type) {
-        decoded_chars = _PyIncrementalNewlineDecoder_decode(
-            self->decoder, input_chunk, eof);
-    }
-    else {
-        decoded_chars = PyObject_CallMethodObjArgs(self->decoder,
-            _PyIO_str_decode, input_chunk, eof ? Py_True : Py_False, NULL);
-    }
+
+    decoded_chars = _textiowrapper_decode(self->decoder, input_chunk, eof);
     PyBuffer_Release(&input_chunk_buf);
-
-    if (check_decoded(decoded_chars) < 0)
+    if (decoded_chars == NULL)
         goto fail;
+
     textiowrapper_set_decoded_chars(self, decoded_chars);
     nchars = PyUnicode_GET_LENGTH(decoded_chars);
     if (nchars > 0)
@@ -2843,6 +3248,8 @@
     {"__getstate__", (PyCFunction)textiowrapper_getstate, METH_NOARGS},
     _IO_TEXTIOWRAPPER_SEEK_METHODDEF
+    {"set_encoding", (PyCFunction)set_encoding, METH_KEYWORDS | METH_VARARGS,
+     set_encoding_doc},
     _IO_TEXTIOWRAPPER_TELL_METHODDEF
     _IO_TEXTIOWRAPPER_TRUNCATE_METHODDEF
     {NULL, NULL}
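For completeness, a small sketch of the newline-only reconfiguration that
test_set_encoding_newline exercises; this is not part of the patch and assumes a build
with the change applied.

    import io

    buf = io.BytesIO()
    txt = io.TextIOWrapper(buf, encoding='ascii', newline='\n')
    txt.write('unix\n')               # written as LF
    txt.set_encoding(newline='\r\n')  # encoding and error handler are kept
    txt.write('windows\n')            # '\n' is now translated to CRLF
    txt.flush()
    assert buf.getvalue() == b'unix\nwindows\r\n'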