diff --git a/Doc/library/io.rst b/Doc/library/io.rst --- a/Doc/library/io.rst +++ b/Doc/library/io.rst @@ -848,6 +848,18 @@ Whether line buffering is enabled. + .. method:: set_encoding(encoding, errors=None) + + Change the encoding of a stream to *encoding* and the error + handler to *errors* (or ``'strict'`` if *errors* is ``None``). + + For non-seekable streams it may not be possible to change the + encoding if some data has already been read from the stream. + + Changing the encoding of a seekable stream may invalidate any + previous position markers obtained from :meth:`tell`. + + .. versionadded:: 3.5 .. class:: StringIO(initial_value='', newline=None) diff --git a/Lib/_pyio.py b/Lib/_pyio.py --- a/Lib/_pyio.py +++ b/Lib/_pyio.py @@ -1561,6 +1561,99 @@ result += " mode={0!r}".format(mode) return result + " encoding={0!r}>".format(self.encoding) + def set_encoding(self, encoding, errors=None): + '''Change the encoding of the stream. + + For non-seekable streams it may not be possible to change the encoding + if some data has already been read from the stream. + + Changing the encoding of a seekable stream may invalidate any previous + position markers obtained from `tell`. 
+ ''' + + if not isinstance(encoding, str): + raise ValueError("invalid encoding: %r" % encoding) + + if errors is None: + errors = 'strict' + + old_encoding = codecs.lookup(self._encoding).name + encoding = codecs.lookup(encoding).name + if encoding == old_encoding and errors == self._errors: + # no change + return + + pending_decoded_text = ( + self._decoded_chars + and self._decoded_chars_used != len(self._decoded_chars)) + if pending_decoded_text and not self.seekable(): + raise UnsupportedOperation( + "It is not possible to set the encoding " + "of a non seekable file after the first read") + + # flush write buffer + self.flush() + + # reset attributes + old_decoder = self._decoder or self._get_decoder() + old_b2cratio = self._b2cratio + self._encoding = encoding + self._errors = errors + self._encoder = None + self._decoder = None + self._b2cratio = 0.0 + + if pending_decoded_text: + # compute the length in bytes of the characters already read + new_decoder = self._get_decoder() + dec_flags, input_chunk = self._snapshot + used = self._decoded_chars[:self._decoded_chars_used] + if old_b2cratio > 0.0: + byteslen = round(old_b2cratio * self._decoded_chars_used) + direction = 0 + else: + byteslen = 1 + direction = 1 + while True: + old_decoder.setstate((b'', dec_flags)) + try: + decoded = old_decoder.decode(input_chunk[:byteslen]) + except UnicodeDecodeError: + if direction: + byteslen += direction + else: + byteslen += 1 + else: + if len(decoded) == len(used): + assert decoded == used + break + if not direction: + if len(decoded) > len(used): + direction = -1 + else: + direction = 1 + byteslen += direction + if not(1 <= byteslen <= len(input_chunk)): + raise AssertionError("failed to compute the length in bytes of the read buffer") + + # decode the tail of the read buffer using the new decoder + input_chunk = input_chunk[byteslen:] + decoded_chars = new_decoder.decode(input_chunk, False) + self._snapshot = (dec_flags, input_chunk) + 
self._set_decoded_chars(decoded_chars) + if decoded_chars: + self._b2cratio = len(input_chunk) / len(decoded_chars) + + # don't write a BOM in the middle of a file + if self._seekable and self.writable(): + position = self.buffer.tell() + if position != 0: + try: + self._get_encoder().setstate(0) + except LookupError: + # Sometimes the encoder doesn't exist + pass + @property def encoding(self): return self._encoding diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -2628,6 +2628,97 @@ self.assertEqual("ok", out.decode().strip()) + def test_set_encoding_same_codec(self): + data = 'foobar\n'.encode('latin1') + raw = self.BytesIO(data) + txt = self.TextIOWrapper(raw, encoding='latin1') + self.assertEqual(txt.encoding, 'latin1') + + # Just an alias, shouldn't change anything + txt.set_encoding('ISO-8859-1') + self.assertEqual(txt.encoding, 'latin1') + + # This is an actual change + txt.set_encoding('iso-8859-15') + self.assertEqual(txt.encoding, 'iso-8859-15') + + def test_set_encoding_read(self): + # latin1 -> utf8 + # (latin1 can decode utf-8 encoded string) + data = 'abc\xe9\n'.encode('latin1') + 'd\xe9f\n'.encode('utf8') + raw = self.BytesIO(data) + txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n') + self.assertEqual(txt.readline(), 'abc\xe9\n') + txt.set_encoding('utf-8') + self.assertEqual(txt.readline(), 'd\xe9f\n') + + # utf-16-be -> utf-32-be + # (utf-16 can decode utf-32 encoded string) + data = 'abc\n'.encode('utf-16-be') + 'def\n'.encode('utf-32-be') + raw = self.BytesIO(data) + txt = self.TextIOWrapper(raw, encoding='utf-16-be', newline='\n') + self.assertEqual(txt.readline(), 'abc\n') + txt.set_encoding('utf-32-be') + self.assertEqual(txt.readline(), 'def\n') + + # ascii/replace -> latin1/strict + data = 'abc\n'.encode('ascii') + 'd\xe9f\n'.encode('latin1') + raw = self.BytesIO(data) + txt = self.TextIOWrapper(raw, encoding='ascii', errors='replace', newline='\n') + 
self.assertEqual(txt.readline(), 'abc\n') + txt.set_encoding('latin1', 'strict') + self.assertEqual(txt.readline(), 'd\xe9f\n') + + # latin1 -> utf8 -> ascii -> utf8 + # (latin1 can decode utf-8 encoded string) + data = 'abc\xe9\n'.encode('latin1') + 'd\xe9f\n'.encode('utf8') + 'ghi\n'.encode('utf8') + raw = self.BytesIO(data) + txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n') + self.assertEqual(txt.readline(), 'abc\xe9\n') + txt.set_encoding('utf-8') + self.assertEqual(txt.readline(), 'd\xe9f\n') + txt.set_encoding('ascii') + self.assertEqual(txt.readline(), 'ghi\n') + txt.set_encoding('utf-8') + + + def test_set_encoding_read_non_seekable(self): + # ascii -> latin1 without reading before setting the new encoding + data = 'abc\xe9'.encode('latin1') + raw = self.MockUnseekableIO(data) + txt = self.TextIOWrapper(raw, encoding='ascii', newline='\n') + txt.set_encoding('latin1') + self.assertEqual(txt.readline(), 'abc\xe9') + + # setting the encoding after data has been read must fail + data = 'xabc\xe9\n'.encode('latin1') + 'yd\xe9f\n'.encode('utf8') + raw = self.MockUnseekableIO(data) + txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n') + self.assertEqual(txt.readline(), 'xabc\xe9\n') + self.assertRaises(self.UnsupportedOperation, txt.set_encoding, 'utf-8') + + def test_set_encoding_write(self): + # latin1 -> utf8 + raw = self.BytesIO() + txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n') + txt.write('abc\xe9\n') + txt.set_encoding('utf-8') + self.assertEqual(raw.getvalue(), b'abc\xe9\n') + txt.write('d\xe9f\n') + txt.flush() + self.assertEqual(raw.getvalue(), b'abc\xe9\nd\xc3\xa9f\n') + + # ascii -> utf-8-sig: ensure that no BOM is written in the middle of + # the file + raw = self.BytesIO() + txt = self.TextIOWrapper(raw, encoding='ascii', newline='\n') + txt.write('abc\n') + txt.set_encoding('utf-8-sig') + txt.write('d\xe9f\n') + txt.flush() + self.assertEqual(raw.getvalue(), b'abc\nd\xc3\xa9f\n') + + class 
CTextIOWrapperTest(TextIOWrapperTest): io = io diff --git a/Modules/_io/textio.c b/Modules/_io/textio.c --- a/Modules/_io/textio.c +++ b/Modules/_io/textio.c @@ -11,6 +11,11 @@ #include "structmember.h" #include "_iomodule.h" +/*[clinic input] +module io +[clinic start generated code]*/ +/*[clinic end generated code: checksum=da39a3ee5e6b4b0d3255bfef95601890afd80709]*/ + _Py_IDENTIFIER(close); _Py_IDENTIFIER(_dealloc_warn); _Py_IDENTIFIER(decode); @@ -648,6 +653,10 @@ /* TextIOWrapper */ +/*[clinic input] +class io.TextIOWrapper "textio *" "&PyTextIOWrapper_Type" +[clinic start generated code]*/ +/*[clinic end generated code: checksum=da39a3ee5e6b4b0d3255bfef95601890afd80709]*/ PyDoc_STRVAR(textiowrapper_doc, "Character and line based layer over a BufferedIOBase object, buffer.\n" @@ -842,6 +851,95 @@ {NULL, NULL} }; +static int +_textiowrapper_set_decoder(textio *self, const char *encoding, + const char *errors) +{ + PyObject *res; + int r; + + res = _PyObject_CallMethodId(self->buffer, &PyId_readable, NULL); + if (res == NULL) + return -1; + + r = PyObject_IsTrue(res); + Py_DECREF(res); + if (r == -1) + return -1; + + if (r != 1) + return 0; + + Py_CLEAR(self->decoder); + self->decoder = PyCodec_IncrementalDecoder(encoding, errors); + if (self->decoder == NULL) + return -1; + + if (self->readuniversal) { + PyObject *incrementalDecoder = PyObject_CallFunction( + (PyObject *)&PyIncrementalNewlineDecoder_Type, + "Oi", self->decoder, (int)self->readtranslate); + if (incrementalDecoder == NULL) + return -1; + Py_CLEAR(self->decoder); + self->decoder = incrementalDecoder; + } + + return 0; +} + + +static int +_textiowrapper_set_encoder(textio *self, const char *encoding, + const char *errors) +{ + PyObject *res, *ci; + int r; + + res = _PyObject_CallMethodId(self->buffer, &PyId_writable, NULL); + if (res == NULL) + return -1; + + r = PyObject_IsTrue(res); + Py_DECREF(res); + if (r == -1) + return -1; + + if (r != 1) + return 0; + + Py_CLEAR(self->encoder); + 
self->encoder = PyCodec_IncrementalEncoder(encoding, errors); + if (self->encoder == NULL) + return -1; + + /* Get the normalized name of the codec */ + ci = _PyCodec_Lookup(encoding); + if (ci == NULL) + return -1; + + res = _PyObject_GetAttrId(ci, &PyId_name); + Py_DECREF(ci); + if (res == NULL) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) + PyErr_Clear(); + else + return -1; + } + else if (PyUnicode_Check(res)) { + encodefuncentry *e = encodefuncs; + while (e->name != NULL) { + if (!PyUnicode_CompareWithASCIIString(res, e->name)) { + self->encodefunc = e->encodefunc; + break; + } + e++; + } + } + Py_XDECREF(res); + + return 0; +} static int textiowrapper_init(textio *self, PyObject *args, PyObject *kwds) @@ -990,72 +1088,18 @@ self->writenl = "\r\n"; #endif - /* Build the decoder object */ - res = _PyObject_CallMethodId(buffer, &PyId_readable, NULL); - if (res == NULL) - goto error; - r = PyObject_IsTrue(res); - Py_DECREF(res); - if (r == -1) - goto error; - if (r == 1) { - self->decoder = PyCodec_IncrementalDecoder( - encoding, errors); - if (self->decoder == NULL) - goto error; - - if (self->readuniversal) { - PyObject *incrementalDecoder = PyObject_CallFunction( - (PyObject *)&PyIncrementalNewlineDecoder_Type, - "Oi", self->decoder, (int)self->readtranslate); - if (incrementalDecoder == NULL) - goto error; - Py_CLEAR(self->decoder); - self->decoder = incrementalDecoder; - } - } - - /* Build the encoder object */ - res = _PyObject_CallMethodId(buffer, &PyId_writable, NULL); - if (res == NULL) - goto error; - r = PyObject_IsTrue(res); - Py_DECREF(res); - if (r == -1) - goto error; - if (r == 1) { - PyObject *ci; - self->encoder = PyCodec_IncrementalEncoder( - encoding, errors); - if (self->encoder == NULL) - goto error; - /* Get the normalized named of the codec */ - ci = _PyCodec_Lookup(encoding); - if (ci == NULL) - goto error; - res = _PyObject_GetAttrId(ci, &PyId_name); - Py_DECREF(ci); - if (res == NULL) { - if 
(PyErr_ExceptionMatches(PyExc_AttributeError)) - PyErr_Clear(); - else - goto error; - } - else if (PyUnicode_Check(res)) { - encodefuncentry *e = encodefuncs; - while (e->name != NULL) { - if (!PyUnicode_CompareWithASCIIString(res, e->name)) { - self->encodefunc = e->encodefunc; - break; - } - e++; - } - } - Py_XDECREF(res); - } - self->buffer = buffer; Py_INCREF(buffer); + + /* Build the decoder object */ + if (_textiowrapper_set_decoder(self, encoding, errors) != 0) + goto error; + + /* Build the encoder object */ + if (_textiowrapper_set_encoder(self, encoding, errors) != 0) + goto error; + + if (Py_TYPE(buffer) == &PyBufferedReader_Type || Py_TYPE(buffer) == &PyBufferedWriter_Type || @@ -1279,6 +1323,7 @@ return 0; } + static PyObject * textiowrapper_write(textio *self, PyObject *args) { @@ -1387,6 +1432,220 @@ self->decoded_chars_used = 0; } + +static PyObject* +_textiowrapper_canonical_codec_name(PyObject *codec_name) +{ + char *c_name = NULL; + PyObject *codec_obj = NULL; + PyObject *canonical_name = NULL; + + c_name = PyUnicode_AsUTF8(codec_name); + if (c_name == NULL) + goto err_out; + + codec_obj = _PyCodec_Lookup(c_name); + if (codec_obj == NULL) + goto err_out; + + canonical_name = PyObject_GetAttrString(codec_obj, "name"); + if (canonical_name == NULL) + goto err_out; + + return canonical_name; + + err_out: + Py_CLEAR(codec_obj); + Py_CLEAR(canonical_name); + return NULL; +} + + +/*[clinic input] +io.TextIOWrapper.set_encoding as textiowrapper_set_encoding + + encoding: object + Name of new encoding to use + errors: object = NULL + Error handler to use. + +Change the encoding of the stream. + +For non-seekable streams it may not be possible to change the encoding if some +data has already been read from the stream. + +Changing the encoding of a seekable stream may invalidate any previous +position markers obtained from `tell`. 
+[clinic start generated code]*/ + +PyDoc_STRVAR(textiowrapper_set_encoding__doc__, +"set_encoding(self, encoding, errors=None)\n" +"Change the encoding of the stream.\n" +"\n" +" encoding\n" +" Name of new encoding to use\n" +" errors\n" +" Error handler to use.\n" +"\n" +"For non-seekable streams it may not be possible to change the\n" +"encoding if some data has already been read from the stream."); + +#define TEXTIOWRAPPER_SET_ENCODING_METHODDEF \ + {"set_encoding", (PyCFunction)textiowrapper_set_encoding, METH_VARARGS|METH_KEYWORDS, textiowrapper_set_encoding__doc__}, + +static PyObject * +textiowrapper_set_encoding_impl(textio *self, PyObject *encoding, PyObject *errors); + +static PyObject * +textiowrapper_set_encoding(textio *self, PyObject *args, PyObject *kwargs) +{ + PyObject *return_value = NULL; + static char *_keywords[] = {"encoding", "errors", NULL}; + PyObject *encoding; + PyObject *errors = NULL; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, + "O|O:set_encoding", _keywords, + &encoding, &errors)) + goto exit; + return_value = textiowrapper_set_encoding_impl(self, encoding, errors); + +exit: + return return_value; +} + +static PyObject * +textiowrapper_set_encoding_impl(textio *self, PyObject *encoding, PyObject *errors) +/*[clinic end generated code: checksum=200293dda638c928acdccc3f366e2fa9d2e734c7]*/ +{ + char pending_decoded_text = 0; + PyObject *old_decoder = NULL; + double old_b2cratio; + int str_len, res; + PyObject *encoding_cname, *old_encoding_cname; // canonical name + + CHECK_INITIALIZED(self); + + if (errors == NULL) + errors = PyUnicode_FromString("strict"); + + /* Get the normalized names of the old and new codecs */ + encoding_cname = _textiowrapper_canonical_codec_name(encoding); + if (encoding_cname == NULL) + return NULL; + old_encoding_cname = _textiowrapper_canonical_codec_name(self->encoding); + if (old_encoding_cname == NULL) { + Py_CLEAR(encoding_cname); + return NULL; + } + + /* Compare with current codec and error 
handler */ + res = (PyUnicode_Compare(encoding_cname, old_encoding_cname) == 0 + && PyUnicode_Compare(errors, self->errors) == 0); + Py_CLEAR(encoding_cname); + Py_CLEAR(old_encoding_cname); + if (res) + // No change + Py_RETURN_NONE; + + /* Check if something is in the read buffer */ + str_len = PyUnicode_GetLength(self->decoded_chars); + if (str_len < 0) + return NULL; + if (self->decoded_chars && self->decoded_chars_used != str_len) + pending_decoded_text = 1; + + if (pending_decoded_text && !self->seekable) { + _unsupported("It is not possible to set the encoding " + "of a non seekable file after the first read"); + return NULL; + } + + // Flush write buffer + if (_textiowrapper_writeflush(self) != 0) + return NULL; + + // Reset attributes + old_decoder = self->decoder; + Py_XINCREF(old_decoder); + + old_b2cratio = self->b2cratio; + self->b2cratio = 0; + + Py_CLEAR(self->encoding); + self->encoding = encoding; + Py_INCREF(self->encoding); + + Py_CLEAR(self->errors); + self->errors = errors; + Py_INCREF(self->errors); + + // Create new encoder & decoder + if (_textiowrapper_set_decoder(self, PyUnicode_AsUTF8(encoding), + PyUnicode_AsUTF8(errors)) != 0) + return NULL; + if (_textiowrapper_set_encoder(self, PyUnicode_AsUTF8(encoding), + PyUnicode_AsUTF8(errors)) != 0) + return NULL; + + if (pending_decoded_text) { + // Compute the length in bytes of the characters already read + PyObject *dec_flags, *input_chunk, *decoded_chars, *substr, *res; + size_t byteslen; + ssize_t direction; + + if (!PyArg_ParseTuple(self->snapshot, "OO", &dec_flags, &input_chunk)) + goto fail; + if (old_b2cratio > 0) { + byteslen = old_b2cratio * self->decoded_chars_used; + direction = 0; + } + else { + byteslen = 1; + direction = 1; + } + while(1) { + res = _PyObject_CallMethodId(old_decoder, &PyId_setstate, + "((yi))", "", dec_flags); + if (res == NULL) + goto fail; + + substr = PyUnicode_Substring(input_chunk, 0, byteslen); + if (substr == NULL) + goto fail; + + /* FIXME: Probably 
makes sense to factor out the next lines + into a separate function that is used by both _set_encoding() + and _read_chunk() */ + if (Py_TYPE(self->decoder) == &PyIncrementalNewlineDecoder_Type) { + decoded_chars = _PyIncrementalNewlineDecoder_decode( + old_decoder, substr, 0); + } + else { + decoded_chars = PyObject_CallMethodObjArgs(old_decoder, + _PyIO_str_decode, substr, Py_False, NULL); + } + + if (check_decoded(decoded_chars) < 0) + // TODO: Handle UnicodeDecodeError here + goto fail; + + // TODO: Check if we have decoded up to correct position, + // if not, update byteslen and direction. + + + } + // TODO: decode the tail of the read buffer using the new decoder + } + + // FIXME: Handle BOM + + Py_RETURN_NONE; + + fail: + return NULL; +} + static PyObject * textiowrapper_get_decoded_chars(textio *self, Py_ssize_t n) { @@ -2736,6 +2995,8 @@ {"seek", (PyCFunction)textiowrapper_seek, METH_VARARGS}, {"tell", (PyCFunction)textiowrapper_tell, METH_NOARGS}, {"truncate", (PyCFunction)textiowrapper_truncate, METH_VARARGS}, + + TEXTIOWRAPPER_SET_ENCODING_METHODDEF {NULL, NULL} };