diff -r d73a32dd9bbc Lib/io.py --- a/Lib/io.py Fri Aug 01 21:33:00 2008 +0200 +++ b/Lib/io.py Fri Aug 01 22:03:27 2008 +0200 @@ -1018,39 +1018,79 @@ class BufferedWriter(_BufferedIOMixin): raw._checkWritable() _BufferedIOMixin.__init__(self, raw) self.buffer_size = buffer_size - self.max_buffer_size = (2*buffer_size - if max_buffer_size is None - else max_buffer_size) - self._write_buf = bytearray() + self._write_buf = bytearray(b" " * self.buffer_size) + self._reset_write_buf() + + def _reset_write_buf(self): + # Just after the last byte actually written + self._write_pos = 0 + # Just after the last buffered byte + self._write_end = 0 def write(self, b): if self.closed: raise ValueError("write to closed file") if isinstance(b, unicode): raise TypeError("can't write unicode to binary stream") - # XXX we can implement some more tricks to try and avoid partial writes - if len(self._write_buf) > self.buffer_size: - # We're full, so let's pre-flush the buffer - try: - self.flush() - except BlockingIOError as e: - # We can't accept anything else. - # XXX Why not just let the exception pass through? - raise BlockingIOError(e.errno, e.strerror, 0) - before = len(self._write_buf) - self._write_buf.extend(b) - written = len(self._write_buf) - before - if len(self._write_buf) > self.buffer_size: - try: - self.flush() - except BlockingIOError as e: - if (len(self._write_buf) > self.max_buffer_size): - # We've hit max_buffer_size. We have to accept a partial - # write and cut back our buffer. - overage = len(self._write_buf) - self.max_buffer_size - self._write_buf = self._write_buf[:self.max_buffer_size] - raise BlockingIOError(e.errno, e.strerror, overage) - return written + end = self._write_end + free = self.buffer_size - end + # If b is a bytearray, ensure it won't be mutated + # (we must keep the internal buffer size constant) + b = bytes(b) + if len(b) <= free: + # Fast path: the data to write can be fully buffered + self._write_buf[end:end+len(b)] = b + self._write_end = end + len(b) + return len(b) + + # First write the current buffer + try: + self.flush() + except BlockingIOError as e: + # Make some place by rotating the buffer + pos = self._write_pos + if pos > 0: + self._write_buf[:-pos] = self._write_buf[pos:] + self._write_pos = 0 + end = self._write_end - pos + avail = self.buffer_size - end + # Buffer as much as possible + if avail < len(b): + self._write_buf[end:] = b[:avail] + self._write_end = self.buffer_size + # The remaining bytes don't fit inside the buffer + raise BlockingIOError(e.errno, e.strerror, avail) + self._write_buf[end:end+len(b)] = b + self._write_end = end + len(b) + return len(b) + + # Then write b itself + remaining = len(b) + written = 0 + try: + while remaining >= self.buffer_size: + # XXX if self.raw.write() doesn't want to write everything + # at once, this can be quadratic. + n = self.raw.write(b[written:]) + written += n + remaining -= n + except BlockingIOError as e: + n = e.characters_written + written += n + remaining -= n + if remaining > self.buffer_size: + # Buffer as much as possible + self._write_buf[:] = b[written:written+self.buffer_size] + self._write_end = self.buffer_size + # The remaining bytes don't fit inside the buffer + raise BlockingIOError(e.errno, e.strerror, + written + self.buffer_size) + + if remaining > 0: + self._write_buf[:remaining] = b[-remaining:] + self._write_pos = 0 + self._write_end = remaining + return written + remaining def truncate(self, pos=None): self.flush() @@ -1063,18 +1103,20 @@ class BufferedWriter(_BufferedIOMixin): raise ValueError("flush of closed file") written = 0 try: - while self._write_buf: - n = self.raw.write(self._write_buf) - del self._write_buf[:n] + end = self._write_end + while self._write_pos < end: + n = self.raw.write(self._write_buf[self._write_pos:end]) + self._write_pos += n written += n except BlockingIOError as e: n = e.characters_written - del self._write_buf[:n] + self._write_pos += n written += n raise BlockingIOError(e.errno, e.strerror, written) + self._reset_write_buf() def tell(self): - return self.raw.tell() + len(self._write_buf) + return self.raw.tell() + self._write_end - self._write_pos def seek(self, pos, whence=0): self.flush() @@ -1107,7 +1149,7 @@ class BufferedRWPair(BufferedIOBase): reader._checkReadable() writer._checkWritable() self.reader = BufferedReader(reader, buffer_size) - self.writer = BufferedWriter(writer, buffer_size, max_buffer_size) + self.writer = BufferedWriter(writer, buffer_size) def read(self, n=None): if n is None: @@ -1157,11 +1199,14 @@ class BufferedRandom(BufferedWriter, Buf writer) defaults to twice the buffer size. """ + # XXX reusing Buffered{Reader, Writer} makes the implementation + # simple but inefficient (see flush() calls in read() etc.) + def __init__(self, raw, buffer_size=DEFAULT_BUFFER_SIZE, max_buffer_size=None): raw._checkSeekable() BufferedReader.__init__(self, raw, buffer_size) - BufferedWriter.__init__(self, raw, buffer_size, max_buffer_size) + BufferedWriter.__init__(self, raw, buffer_size) def seek(self, pos, whence=0): self.flush() @@ -1169,11 +1214,12 @@ class BufferedRandom(BufferedWriter, Buf # if the raw seek fails, we don't lose buffered data forever. pos = self.raw.seek(pos, whence) self._reset_read_buf() + self._reset_write_buf() return pos def tell(self): - if self._write_buf: - return self.raw.tell() + len(self._write_buf) + if self._write_end > self._write_pos: + return BufferedWriter.tell(self) else: return BufferedReader.tell(self) @@ -1207,6 +1253,7 @@ class BufferedRandom(BufferedWriter, Buf # Undo readahead self.raw.seek(self._read_pos - len(self._read_buf), 1) self._reset_read_buf() + self._reset_write_buf() return BufferedWriter.write(self, b) diff -r d73a32dd9bbc Lib/test/test_io.py --- a/Lib/test/test_io.py Fri Aug 01 21:33:00 2008 +0200 +++ b/Lib/test/test_io.py Fri Aug 01 22:03:27 2008 +0200 @@ -7,6 +7,7 @@ import time import time import array import unittest +import threading from itertools import chain from test import test_support @@ -19,6 +20,9 @@ class MockRawIO(io.RawIOBase): def __init__(self, read_stack=()): self._read_stack = list(read_stack) self._write_stack = [] + + def get_written(self): + return b"".join(self._write_stack) def read(self, n=None): try: @@ -27,7 +31,7 @@ class MockRawIO(io.RawIOBase): return b"" def write(self, b): - self._write_stack.append(b[:]) + self._write_stack.append(bytes(b)) return len(b) def writable(self): @@ -63,17 +67,33 @@ class MockFileIO(io.BytesIO): class MockNonBlockWriterIO(io.RawIOBase): - def __init__(self, blocking_script): - self._blocking_script = list(blocking_script) + def __init__(self): self._write_stack = [] + self._blocker_char = None + + def pop_written(self): + s = b"".join(self._write_stack) + self._write_stack[:] = [] + return s + + def block_on(self, char): + """Block when a given char is encountered.""" + self._blocker_char = char def write(self, b): - self._write_stack.append(b[:]) - n = self._blocking_script.pop(0) - if (n < 0): - raise io.BlockingIOError(0, "test blocking", -n) - else: - return n + b = bytes(b) + n = -1 + if self._blocker_char: + try: + n = b.index(self._blocker_char) + except ValueError: + pass + else: + self._blocker_char = None + self._write_stack.append(b[:n]) + raise io.BlockingIOError(0, "test blocking", n) + self._write_stack.append(b) + return len(b) def writable(self): return True @@ -398,7 +418,7 @@ class BufferedWriterTest(unittest.TestCa writer = MockRawIO() bufio = io.BufferedWriter(writer, 8) - bufio.write(b"abc") + self.assertEquals(bufio.write(b"abc"), 3) self.assertFalse(writer._write_stack) @@ -406,30 +426,39 @@ class BufferedWriterTest(unittest.TestCa writer = MockRawIO() bufio = io.BufferedWriter(writer, 8) - bufio.write(b"abc") - bufio.write(b"defghijkl") + self.assertEquals(bufio.write(b"abc"), 3) + self.assertEquals(bufio.write(b"defghijkl"), 9) - self.assertEquals(b"abcdefghijkl", writer._write_stack[0]) + # The first 8 bytes were written + self.assertTrue(writer.get_written().startswith(b"abcdefgh"), + writer.get_written()) def testWriteNonBlocking(self): - raw = MockNonBlockWriterIO((9, 2, 22, -6, 10, 12, 12)) - bufio = io.BufferedWriter(raw, 8, 16) + raw = MockNonBlockWriterIO() + bufio = io.BufferedWriter(raw, 8) - bufio.write(b"asdf") - bufio.write(b"asdfa") - self.assertEquals(b"asdfasdfa", raw._write_stack[0]) + self.assertEquals(bufio.write(b"abcd"), 4) + self.assertEquals(bufio.write(b"efghi"), 5) + # 1 byte will be written, the rest will be buffered + raw.block_on(b"k") + self.assertEquals(bufio.write(b"jklmnopqr"), 9) - bufio.write(b"asdfasdfasdf") - self.assertEquals(b"asdfasdfasdf", raw._write_stack[1]) - bufio.write(b"asdfasdfasdf") - self.assertEquals(b"dfasdfasdf", raw._write_stack[2]) - self.assertEquals(b"asdfasdfasdf", raw._write_stack[3]) + # 4 bytes will be written, 8 will be buffered and the rest will be lost + raw.block_on(b"0") + try: + bufio.write(b"wxyz0123456789") + except io.BlockingIOError as e: + written = e.characters_written + else: + self.fail("BlockingIOError should have been raised") + self.assertEquals(written, 12) + self.assertEquals(raw.pop_written(), + b"abcdefghijklmnopqrwxyz") - bufio.write(b"asdfasdfasdf") - - # XXX I don't like this test. It relies too heavily on how the - # algorithm actually works, which we might change. Refactor - # later. + self.assertEquals(bufio.write(b"ABCDEFGHI"), 9) + s = raw.pop_written() + # Previously buffered bytes were flushed + self.assertTrue(s.startswith(b"01234567A"), s) def testFileno(self): rawio = MockRawIO((b"abc", b"d", b"efg")) @@ -445,6 +474,36 @@ class BufferedWriterTest(unittest.TestCa bufio.flush() self.assertEquals(b"abc", writer._write_stack[0]) + + def testThreads(self): + # BufferedWriter should not raise exceptions or crash + # when called from multiple threads. + try: + # We use a real file object because it allows us to + # exercise situations where the GIL is released before + # writing the buffer. + with io.open(test_support.TESTFN, "wb", buffering=0) as raw: + bufio = io.BufferedWriter(raw, 8) + errors = [] + def f(): + try: + # Write enough bytes to flush the buffer + s = b"a" * 19 + for i in range(50): + bufio.write(s) + except Exception as e: + errors.append(e) + raise + threads = [threading.Thread(target=f) for x in range(20)] + for t in threads: + t.start() + time.sleep(0.02) # yield + for t in threads: + t.join() + self.assertFalse(errors, + "the following exceptions were caught: %r" % errors) + finally: + test_support.unlink(test_support.TESTFN) class BufferedRWPairTest(unittest.TestCase):