diff -r 3ce292093764 Lib/io.py --- a/Lib/io.py Thu Jul 31 11:45:39 2008 +0200 +++ b/Lib/io.py Thu Jul 31 16:17:06 2008 +0200 @@ -1018,39 +1018,84 @@ raw._checkWritable() _BufferedIOMixin.__init__(self, raw) self.buffer_size = buffer_size - self.max_buffer_size = (2*buffer_size - if max_buffer_size is None - else max_buffer_size) - self._write_buf = bytearray() + self._write_buf = bytearray(b" " * self.buffer_size) + self._reset_write_buf() + + def _reset_write_buf(self): + assert len(self._write_buf) == self.buffer_size + # Just after the last byte actually written + self._write_pos = 0 + # Just after the last buffered byte + self._write_end = 0 def write(self, b): if self.closed: raise ValueError("write to closed file") if isinstance(b, unicode): raise TypeError("can't write unicode to binary stream") - # XXX we can implement some more tricks to try and avoid partial writes - if len(self._write_buf) > self.buffer_size: - # We're full, so let's pre-flush the buffer - try: + end = self._write_end + free = self.buffer_size - end + # If b is a bytearray, ensure it won't be mutated + # (we must keep the internal buffer size constant) + b = bytes(b) + if len(b) <= free: + # Fast path: the data to write can be fully buffered + self._write_buf[end:end+len(b)] = b + self._write_end = end + len(b) + if len(b) == free: self.flush() - except BlockingIOError as e: - # We can't accept anything else. - # XXX Why not just let the exception pass through? - raise BlockingIOError(e.errno, e.strerror, 0) - before = len(self._write_buf) - self._write_buf.extend(b) - written = len(self._write_buf) - before - if len(self._write_buf) > self.buffer_size: - try: - self.flush() - except BlockingIOError as e: - if (len(self._write_buf) > self.max_buffer_size): - # We've hit max_buffer_size. We have to accept a partial - # write and cut back our buffer. - overage = len(self._write_buf) - self.max_buffer_size - self._write_buf = self._write_buf[:self.max_buffer_size] - raise BlockingIOError(e.errno, e.strerror, overage) - return written + assert len(self._write_buf) == self.buffer_size + return len(b) + + # First write the current buffer and the beginning of b + self._write_buf[end:] = b[:free] + self._write_end = self.buffer_size + assert len(self._write_buf) == self.buffer_size + # Bytes that don't fit in the current buffer + remaining = len(b) - free + try: + # Pending bytes from current buffer + pending = end - self._write_pos + self.flush() + except BlockingIOError as e: + # Compute the number of bytes actually written from b + written = max(e.characters_written - pending, 0) + # Number of buffer bytes available to buffer the rest of b + avail = self._write_pos + if remaining > avail: + # The remaining bytes don't fit inside the buffer + # XXX we could still buffer as much as we can + raise BlockingIOError(e.errno, e.strerror, free) + # Make place at the end of the buffer, and put the + # remaining bytes there. + pos = self._write_pos + if pos > 0: + self._write_buf[:-pos] = self._write_buf[pos:] + self._write_pos = 0 + end = self._write_end - pos + self._write_buf[end:end+remaining] = b[-remaining:] + self._write_end = end + remaining + return len(b) + written = free + try: + while remaining >= self.buffer_size: + # XXX This might be quadratic in certain circumstances + n = self.raw.write(b[-remaining:]) + written += n + remaining -= n + except BlockingIOError as e: + n = e.characters_written + written += n + remaining -= n + if remaining > self.buffer_size: + # The remaining bytes don't fit inside the buffer + raise BlockingIOError(e.errno, e.strerror, written) + + if remaining > 0: + self._write_buf[:remaining] = b[-remaining:] + self._write_pos = 0 + self._write_end = remaining + return written + remaining def truncate(self, pos=None): self.flush() @@ -1063,18 +1108,20 @@ raise ValueError("flush of closed file") written = 0 try: - while self._write_buf: - n = self.raw.write(self._write_buf) - del self._write_buf[:n] + end = self._write_end + while self._write_pos < end: + n = self.raw.write(self._write_buf[self._write_pos:end]) + self._write_pos += n written += n except BlockingIOError as e: n = e.characters_written - del self._write_buf[:n] + self._write_pos += n written += n raise BlockingIOError(e.errno, e.strerror, written) + self._reset_write_buf() def tell(self): - return self.raw.tell() + len(self._write_buf) + return self.raw.tell() + self._write_end - self._write_pos def seek(self, pos, whence=0): self.flush() @@ -1107,7 +1154,7 @@ reader._checkReadable() writer._checkWritable() self.reader = BufferedReader(reader, buffer_size) - self.writer = BufferedWriter(writer, buffer_size, max_buffer_size) + self.writer = BufferedWriter(writer, buffer_size) def read(self, n=None): if n is None: @@ -1157,11 +1204,14 @@ writer) defaults to twice the buffer size. """ + # XXX reusing Buffered{Reader, Writer} makes the implementation + # simple but inefficient (see flush() calls in read() etc.) + def __init__(self, raw, buffer_size=DEFAULT_BUFFER_SIZE, max_buffer_size=None): raw._checkSeekable() BufferedReader.__init__(self, raw, buffer_size) - BufferedWriter.__init__(self, raw, buffer_size, max_buffer_size) + BufferedWriter.__init__(self, raw, buffer_size) def seek(self, pos, whence=0): self.flush() @@ -1169,11 +1219,12 @@ # if the raw seek fails, we don't lose buffered data forever. pos = self.raw.seek(pos, whence) self._reset_read_buf() + self._reset_write_buf() return pos def tell(self): - if self._write_buf: - return self.raw.tell() + len(self._write_buf) + if self._write_end > self._write_pos: + return BufferedWriter.tell(self) else: return BufferedReader.tell(self) @@ -1207,6 +1258,7 @@ # Undo readahead self.raw.seek(self._read_pos - len(self._read_buf), 1) self._reset_read_buf() + self._reset_write_buf() return BufferedWriter.write(self, b) diff -r 3ce292093764 Lib/test/test_io.py --- a/Lib/test/test_io.py Thu Jul 31 11:45:39 2008 +0200 +++ b/Lib/test/test_io.py Thu Jul 31 16:17:06 2008 +0200 @@ -7,6 +7,7 @@ import time import array import unittest +import threading from itertools import chain from test import test_support @@ -68,12 +69,13 @@ self._write_stack = [] def write(self, b): - self._write_stack.append(b[:]) + b = bytes(b) n = self._blocking_script.pop(0) + self._write_stack.append(b[:abs(n)]) if (n < 0): raise io.BlockingIOError(0, "test blocking", -n) else: - return n + return min(n, len(b)) def writable(self): return True @@ -398,7 +400,7 @@ writer = MockRawIO() bufio = io.BufferedWriter(writer, 8) - bufio.write(b"abc") + self.assertEquals(bufio.write(b"abc"), 3) self.assertFalse(writer._write_stack) @@ -406,31 +408,58 @@ writer = MockRawIO() bufio = io.BufferedWriter(writer, 8) - bufio.write(b"abc") - bufio.write(b"defghijkl") + self.assertEquals(bufio.write(b"abc"), 3) + self.assertEquals(bufio.write(b"defghijkl"), 9) - self.assertEquals(b"abcdefghijkl", writer._write_stack[0]) + # The first 8 bytes were written + self.assertTrue(writer._write_stack[0].startswith(b"abcdefgh")) def testWriteNonBlocking(self): - raw = MockNonBlockWriterIO((9, 2, 22, -6, 10, 12, 12)) - bufio = io.BufferedWriter(raw, 8, 16) + # XXX This test relies too heavily on the underlying algorithm + raw = MockNonBlockWriterIO((8, 8, -4, 8, -4)) + bufio = io.BufferedWriter(raw, 8) - bufio.write(b"asdf") - bufio.write(b"asdfa") - self.assertEquals(b"asdfasdfa", raw._write_stack[0]) + self.assertEquals(bufio.write(b"abcd"), 4) + self.assertEquals(len(raw._write_stack), 0) + # 8 bytes will be written + self.assertEquals(bufio.write(b"efghi"), 5) + self.assertEquals(len(raw._write_stack), 1) + self.assertEquals(b"abcdefgh", raw._write_stack[0]) - bufio.write(b"asdfasdfasdf") - self.assertEquals(b"asdfasdfasdf", raw._write_stack[1]) - bufio.write(b"asdfasdfasdf") - self.assertEquals(b"dfasdfasdf", raw._write_stack[2]) - self.assertEquals(b"asdfasdfasdf", raw._write_stack[3]) + # 8 bytes will be written + self.assertEquals(bufio.write(b"jklmnopqrstu"), 12) + self.assertEquals(len(raw._write_stack), 2) + self.assertEquals(b"ijklmnop", raw._write_stack[1]) + bufio.write(b"vw") + self.assertEquals(len(raw._write_stack), 2) - bufio.write(b"asdfasdfasdf") + # 4 bytes will be written then block, the rest will be buffered + self.assertEquals(bufio.write(b"xyz"), 3) + self.assertEquals(len(raw._write_stack), 3) + self.assertEquals(b"qrst", raw._write_stack[2]) - # XXX I don't like this test. It relies too heavily on how the - # algorithm actually works, which we might change. Refactor - # later. + # 8 bytes will be written + self.assertEquals(bufio.write(b"123"), 3) + self.assertEquals(len(raw._write_stack), 4) + self.assertEquals(b"uvwxyz12", raw._write_stack[3]) + # 4 bytes will be written then block and raise + # (not enough room in buffer, which contained one remaining byte) + try: + bufio.write(b"456ABCDEFGHIJK") + except io.BlockingIOError as e: + written = e.characters_written + else: + self.fail("BlockingIOError should have been raised") + # At least 3 bytes from the string have been written or + # buffered, and less than the total string length. + self.assertTrue(written >= 3) + self.assertTrue(written < 11) + self.assertEquals(len(raw._write_stack), 5) + self.assertTrue( + b"3456ABCDEFGHIJK".startswith(raw._write_stack[4]), + raw._write_stack[4]) + def testFileno(self): rawio = MockRawIO((b"abc", b"d", b"efg")) bufio = io.BufferedWriter(rawio) @@ -445,6 +474,36 @@ bufio.flush() self.assertEquals(b"abc", writer._write_stack[0]) + + def testThreads(self): + # BufferedWriter should not raise exceptions or crash + # when called from multiple threads. + try: + # We use a real file object because it allows us to + # exercise situations where the GIL is released before + # writing the buffer. + with io.open(test_support.TESTFN, "wb", buffering=0) as raw: + bufio = io.BufferedWriter(raw, 8) + errors = [] + def f(): + try: + # Write enough bytes to flush the buffer + s = b"a" * 19 + for i in range(50): + bufio.write(s) + except Exception as e: + errors.append(e) + raise + threads = [threading.Thread(target=f) for x in range(20)] + for t in threads: + t.start() + time.sleep(0.02) # yield + for t in threads: + t.join() + self.assertFalse(errors, + "the following exceptions were caught: %r" % errors) + finally: + test_support.unlink(test_support.TESTFN) class BufferedRWPairTest(unittest.TestCase):