# HG changeset patch # Parent 4a55b98314cd690ee7526cc5fd5d9cc1fec40972 Issue 13881: zlib and bz2 StreamWriter from incremental encoder diff -r 4a55b98314cd Lib/codecs.py --- a/Lib/codecs.py Mon Jan 12 21:03:41 2015 +0100 +++ b/Lib/codecs.py Wed Jan 14 05:57:53 2015 +0000 @@ -411,6 +411,22 @@ def __exit__(self, type, value, tb): self.stream.close() +class _IncrementalBasedWriter(StreamWriter): + """Generic StreamWriter implementation. + + The _encoder attribute must be set to an IncrementalEncoder to use. + """ + + def __init__(self, stream, errors='strict'): + super().__init__(stream, errors) + self._encoder = self._EncoderClass(errors) + + def write(self, object): + self.stream.write(self._encoder.encode(object)) + + def reset(self): + self.stream.write(self._encoder.encode(final=True)) + ### class StreamReader(Codec): diff -r 4a55b98314cd Lib/encodings/bz2_codec.py --- a/Lib/encodings/bz2_codec.py Mon Jan 12 21:03:41 2015 +0100 +++ b/Lib/encodings/bz2_codec.py Wed Jan 14 05:57:53 2015 +0000 @@ -57,8 +57,8 @@ def reset(self): self.decompressobj = bz2.BZ2Decompressor() -class StreamWriter(Codec, codecs.StreamWriter): - charbuffertype = bytes +class StreamWriter(codecs._IncrementalBasedWriter): + _EncoderClass = IncrementalEncoder class StreamReader(Codec, codecs.StreamReader): charbuffertype = bytes diff -r 4a55b98314cd Lib/encodings/zlib_codec.py --- a/Lib/encodings/zlib_codec.py Mon Jan 12 21:03:41 2015 +0100 +++ b/Lib/encodings/zlib_codec.py Wed Jan 14 05:57:53 2015 +0000 @@ -56,8 +56,8 @@ def reset(self): self.decompressobj = zlib.decompressobj() -class StreamWriter(Codec, codecs.StreamWriter): - charbuffertype = bytes +class StreamWriter(codecs._IncrementalBasedWriter): + _EncoderClass = IncrementalEncoder class StreamReader(Codec, codecs.StreamReader): charbuffertype = bytes diff -r 4a55b98314cd Lib/test/test_codecs.py --- a/Lib/test/test_codecs.py Mon Jan 12 21:03:41 2015 +0100 +++ b/Lib/test/test_codecs.py Wed Jan 14 05:57:53 2015 +0000 @@ -2520,6 +2520,20 @@ self.assertEqual(size, len(o)) self.assertEqual(i, binput) + def test_writer(self): + data = bytes(200) # Long enough to span a base64/quopri/uu line + broken = {"base64_codec", "quopri_codec", "uu_codec"} + for encoding in bytes_transform_encodings: + if encoding in broken: # See Issue 20132 + continue + with self.subTest(encoding=encoding): + writer = codecs.getwriter(encoding)(io.BytesIO()) + for b in data: + writer.write(bytes((b,))) + writer.reset() + expected = codecs.encode(data, encoding) + self.assertEqual(expected, writer.getvalue()) + def test_read(self): for encoding in bytes_transform_encodings: with self.subTest(encoding=encoding):