# HG changeset patch # Parent 4a55b98314cd690ee7526cc5fd5d9cc1fec40972 Issue 13881: zlib and bz2 StreamWriter from incremental encoder Also fixes StreamWriter.writelines() in general for all byte codecs. diff -r 4a55b98314cd Doc/library/codecs.rst --- a/Doc/library/codecs.rst Mon Jan 12 21:03:41 2015 +0100 +++ b/Doc/library/codecs.rst Thu Jan 15 04:15:18 2015 +0000 @@ -664,8 +664,7 @@ .. method:: writelines(list) Writes the concatenated list of strings to the stream (possibly by reusing - the :meth:`write` method). The standard bytes-to-bytes codecs - do not support this method. + the :meth:`write` method). .. method:: reset() diff -r 4a55b98314cd Lib/codecs.py --- a/Lib/codecs.py Mon Jan 12 21:03:41 2015 +0100 +++ b/Lib/codecs.py Thu Jan 15 04:15:18 2015 +0000 @@ -379,7 +379,13 @@ """ Writes the concatenated list of strings to the stream using .write(). """ - self.write(''.join(list)) + if not list: + return + if isinstance(list[0], str): + join = ''.join + else: + join = b''.join + self.write(join(list)) def reset(self): @@ -411,6 +417,23 @@ def __exit__(self, type, value, tb): self.stream.close() +class _IncrementalBasedWriter(StreamWriter): + """Generic StreamWriter implementation. + + The _Encoder attribute must be set to the IncrementalEncoder + class to use. 
+ """ + + def __init__(self, stream, errors='strict'): + super().__init__(stream, errors) + self._encoder = self._Encoder(errors) + + def write(self, object): + self.stream.write(self._encoder.encode(object)) + + def reset(self): + self.stream.write(self._encoder.encode(final=True)) + ### class StreamReader(Codec): diff -r 4a55b98314cd Lib/encodings/bz2_codec.py --- a/Lib/encodings/bz2_codec.py Mon Jan 12 21:03:41 2015 +0100 +++ b/Lib/encodings/bz2_codec.py Thu Jan 15 04:15:18 2015 +0000 @@ -57,8 +57,8 @@ def reset(self): self.decompressobj = bz2.BZ2Decompressor() -class StreamWriter(Codec, codecs.StreamWriter): - charbuffertype = bytes +class StreamWriter(codecs._IncrementalBasedWriter): + _Encoder = IncrementalEncoder class StreamReader(Codec, codecs.StreamReader): charbuffertype = bytes diff -r 4a55b98314cd Lib/encodings/zlib_codec.py --- a/Lib/encodings/zlib_codec.py Mon Jan 12 21:03:41 2015 +0100 +++ b/Lib/encodings/zlib_codec.py Thu Jan 15 04:15:18 2015 +0000 @@ -56,8 +56,8 @@ def reset(self): self.decompressobj = zlib.decompressobj() -class StreamWriter(Codec, codecs.StreamWriter): - charbuffertype = bytes +class StreamWriter(codecs._IncrementalBasedWriter): + _Encoder = IncrementalEncoder class StreamReader(Codec, codecs.StreamReader): charbuffertype = bytes diff -r 4a55b98314cd Lib/test/test_codecs.py --- a/Lib/test/test_codecs.py Mon Jan 12 21:03:41 2015 +0100 +++ b/Lib/test/test_codecs.py Thu Jan 15 04:15:18 2015 +0000 @@ -6,6 +6,7 @@ import unittest import warnings import encodings +import array from test import support @@ -982,7 +983,6 @@ class ReadBufferTest(unittest.TestCase): def test_array(self): - import array self.assertEqual( codecs.readbuffer_encode(array.array("b", b"spam")), (b"spam", 4) @@ -2520,6 +2520,41 @@ self.assertEqual(size, len(o)) self.assertEqual(i, binput) + def test_writelines(self): + data = b"12345678" + for encoding in bytes_transform_encodings: + Writer = codecs.getwriter(encoding) + if encoding == "uu_codec": + expected = 
b"" + else: + expected = codecs.encode(b"", encoding) + with self.subTest(encoding=encoding): + writer = Writer(io.BytesIO()) + writer.writelines([]) + writer.reset() + self.assertEqual(expected, writer.getvalue()) + expected = codecs.encode(data * 3, encoding) + for byteslike in (data, bytearray(data), array.array("h", data)): + with self.subTest(encoding=encoding, type=type(byteslike)): + writer = Writer(io.BytesIO()) + writer.writelines([byteslike] * 3) + writer.reset() + self.assertEqual(expected, writer.getvalue()) + + def test_multi_write(self): + data = bytes(200) # Long enough to span a base64/quopri/uu line + broken = {"base64_codec", "quopri_codec", "uu_codec"} + for encoding in bytes_transform_encodings: + if encoding in broken: # See Issue 20132 + continue + with self.subTest(encoding=encoding): + expected = codecs.encode(data, encoding) + writer = codecs.getwriter(encoding)(io.BytesIO()) + for b in data: + writer.write(bytes((b,))) + writer.reset() + self.assertEqual(expected, writer.getvalue()) + def test_read(self): for encoding in bytes_transform_encodings: with self.subTest(encoding=encoding):