diff -r aaa68dce117e Lib/_pyio.py
--- a/Lib/_pyio.py Thu Aug 09 21:38:23 2012 +0200
+++ b/Lib/_pyio.py Thu Aug 09 22:36:05 2012 +0200
@@ -1542,6 +1542,107 @@ class TextIOWrapper(TextIOBase):
             result += " mode={0!r}".format(mode)
         return result + " encoding={0!r}>".format(self.encoding)
 
+    def set_encoding(self, encoding, errors=None):
+        """Change the encoding and the error handler of the stream.
+
+        encoding must be a str (ValueError is raised otherwise) and is
+        resolved with codecs.lookup().  errors defaults to 'strict'.
+        Nothing is done if neither the normalized codec name nor the
+        error handler changes.
+
+        The stream is flushed before the switch.  Text that was already
+        decoded but not yet consumed is decoded again from the buffered
+        bytes with the new encoding; UnsupportedOperation is raised if
+        such text exists and the stream is not seekable.
+        """
+        if not isinstance(encoding, str):
+            raise ValueError("invalid encoding: %r" % encoding)
+
+        if errors is None:
+            errors = 'strict'
+
+        old_encoding = codecs.lookup(self._encoding).name
+        encoding = codecs.lookup(encoding).name
+        if encoding == old_encoding and errors == self._errors:
+            # no change
+            return
+
+        pending_decoded_text = (
+            self._decoded_chars
+            and self._decoded_chars_used != len(self._decoded_chars))
+        if pending_decoded_text and not self.seekable():
+            raise UnsupportedOperation(
+                "It is not possible to set the encoding "
+                "of a non seekable file after the first read")
+
+        # flush write buffer
+        self.flush()
+
+        # reset attributes
+        old_decoder = self._decoder or self._get_decoder()
+        old_b2cratio = self._b2cratio
+        self._encoding = encoding
+        self._errors = errors
+        self._encoder = None
+        self._decoder = None
+        self._b2cratio = 0.0
+
+        if pending_decoded_text:
+            # compute the length in bytes of the characters already read
+            new_decoder = self._get_decoder()
+            dec_flags, input_chunk = self._snapshot
+            used = self._decoded_chars[:self._decoded_chars_used]
+            # start from the bytes/chars ratio when one is known
+            if old_b2cratio > 0.0:
+                byteslen = round(old_b2cratio * self._decoded_chars_used)
+                direction = 0
+            else:
+                byteslen = 1
+                direction = 1
+            while True:
+                # every attempt decodes again from the snapshot state
+                old_decoder.setstate((b'', dec_flags))
+                try:
+                    decoded = old_decoder.decode(input_chunk[:byteslen])
+                except UnicodeDecodeError:
+                    if direction:
+                        byteslen += direction
+                    else:
+                        byteslen += 1
+                else:
+                    if len(decoded) == len(used):
+                        assert decoded == used
+                        break
+                    # the first miss decides which way to move
+                    if not direction:
+                        if len(decoded) > len(used):
+                            direction = -1
+                        else:
+                            direction = 1
+                    byteslen += direction
+                if not (1 <= byteslen <= len(input_chunk)):
+                    raise AssertionError(
+                        "failed to compute the length in bytes "
+                        "of the read buffer")
+
+            # decode the tail of the read buffer using the new decoder
+            input_chunk = input_chunk[byteslen:]
+            decoded_chars = new_decoder.decode(input_chunk, False)
+            self._snapshot = (dec_flags, input_chunk)
+            self._set_decoded_chars(decoded_chars)
+            if decoded_chars:
+                self._b2cratio = len(input_chunk) / len(decoded_chars)
+
+        # don't write a BOM in the middle of a file
+        if self._seekable and self.writable():
+            position = self.buffer.tell()
+            if position != 0:
+                try:
+                    self._get_encoder().setstate(0)
+                except LookupError:
+                    # Sometimes the encoder doesn't exist
+                    pass
+
     @property
     def encoding(self):
         return self._encoding
diff -r aaa68dce117e Lib/test/test_io.py
--- a/Lib/test/test_io.py Thu Aug 09 21:38:23 2012 +0200
+++ b/Lib/test/test_io.py Thu Aug 09 22:36:05 2012 +0200
@@ -2361,7 +2361,6 @@ class TextIOWrapperTest(unittest.TestCas
         for charset in ('utf-8-sig', 'utf-16', 'utf-32'):
             with self.open(filename, 'w', encoding=charset) as f:
                 f.write('aaa')
-                pos = f.tell()
             with self.open(filename, 'rb') as f:
                 self.assertEqual(f.read(), 'aaa'.encode(charset))
 
@@ -2460,6 +2459,83 @@ class TextIOWrapperTest(unittest.TestCas
         txt.write('5')
         self.assertEqual(b''.join(raw._write_stack), b'123\n45')
 
+    def test_set_encoding_read(self):
+        # latin1 -> utf8
+        # (latin1 can decode utf-8 encoded string)
+        data = 'abc\xe9\n'.encode('latin1') + 'd\xe9f\n'.encode('utf8')
+        raw = self.BytesIO(data)
+        txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n')
+        self.assertEqual(txt.readline(), 'abc\xe9\n')
+        txt.set_encoding('utf-8')
+        self.assertEqual(txt.readline(), 'd\xe9f\n')
+
+        # utf-16-be -> utf-32-be
+        # (utf-16 can decode utf-32 encoded string)
+        data = 'abc\n'.encode('utf-16-be') + 'def\n'.encode('utf-32-be')
+        raw = self.BytesIO(data)
+        txt = self.TextIOWrapper(raw, encoding='utf-16-be', newline='\n')
+        self.assertEqual(txt.readline(), 'abc\n')
+        txt.set_encoding('utf-32-be')
+        self.assertEqual(txt.readline(), 'def\n')
+
+        # ascii/replace -> latin1/strict
+        data = 'abc\n'.encode('ascii') + 'd\xe9f\n'.encode('latin1')
+        raw = self.BytesIO(data)
+        txt = self.TextIOWrapper(raw, encoding='ascii', errors='replace', newline='\n')
+        self.assertEqual(txt.readline(), 'abc\n')
+        txt.set_encoding('latin1', 'strict')
+        self.assertEqual(txt.readline(), 'd\xe9f\n')
+
+        # latin1 -> utf8 -> ascii -> utf8
+        # (latin1 can decode utf-8 encoded string)
+        data = 'abc\xe9\n'.encode('latin1') + 'd\xe9f\n'.encode('utf8') + 'ghi\n'.encode('utf8')
+        raw = self.BytesIO(data)
+        txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n')
+        self.assertEqual(txt.readline(), 'abc\xe9\n')
+        txt.set_encoding('utf-8')
+        self.assertEqual(txt.readline(), 'd\xe9f\n')
+        txt.set_encoding('ascii')
+        self.assertEqual(txt.readline(), 'ghi\n')
+        # no assertion: only check that one more switch does not raise
+        txt.set_encoding('utf-8')
+
+    def test_set_encoding_read_non_seekable(self):
+        # ascii -> latin1 without reading before setting the new encoding
+        data = 'abc\xe9'.encode('latin1')
+        raw = self.MockUnseekableIO(data)
+        txt = self.TextIOWrapper(raw, encoding='ascii', newline='\n')
+        txt.set_encoding('latin1')
+        self.assertEqual(txt.readline(), 'abc\xe9')
+
+        # setting the encoding after a read must fail
+        data = 'xabc\xe9\n'.encode('latin1') + 'yd\xe9f\n'.encode('utf8')
+        raw = self.MockUnseekableIO(data)
+        txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n')
+        self.assertEqual(txt.readline(), 'xabc\xe9\n')
+        self.assertRaises(self.UnsupportedOperation, txt.set_encoding, 'utf-8')
+
+    def test_set_encoding_write(self):
+        # latin -> utf8
+        raw = self.BytesIO()
+        txt = self.TextIOWrapper(raw, encoding='latin1', newline='\n')
+        txt.write('abc\xe9\n')
+        txt.set_encoding('utf-8')
+        self.assertEqual(raw.getvalue(), b'abc\xe9\n')
+        txt.write('d\xe9f\n')
+        txt.flush()
+        self.assertEqual(raw.getvalue(), b'abc\xe9\nd\xc3\xa9f\n')
+
+        # ascii -> utf-8-sig: ensure that no BOM is written in the middle of
+        # the file
+        raw = self.BytesIO()
+        txt = self.TextIOWrapper(raw, encoding='ascii', newline='\n')
+        txt.write('abc\n')
+        txt.set_encoding('utf-8-sig')
+        txt.write('d\xe9f\n')
+        txt.flush()
+        self.assertEqual(raw.getvalue(), b'abc\nd\xc3\xa9f\n')
+
+
 class CTextIOWrapperTest(TextIOWrapperTest):
 
     def test_initialization(self):