diff -r 3f4e15c3d089 asyncio/streams.py --- a/asyncio/streams.py Tue Feb 04 08:54:49 2014 +0100 +++ b/asyncio/streams.py Tue Feb 04 12:38:39 2014 -0500 @@ -4,8 +4,6 @@ 'open_connection', 'start_server', 'IncompleteReadError', ] -import collections - from . import events from . import futures from . import protocols @@ -252,6 +250,8 @@ class StreamReader: + _buffer_factory = bytearray # Constructs initial value for self._buffer. + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; # it also doubles as half the buffer limit. @@ -259,9 +259,7 @@ if loop is None: loop = events.get_event_loop() self._loop = loop - # TODO: Use a bytearray for a buffer, like the transport. - self._buffer = collections.deque() # Deque of bytes objects. - self._byte_count = 0 # Bytes in buffer. + self._buffer = self._buffer_factory() self._eof = False # Whether we're done. self._waiter = None # A future. self._exception = None @@ -285,7 +283,7 @@ self._transport = transport def _maybe_resume_transport(self): - if self._paused and self._byte_count <= self._limit: + if self._paused and len(self._buffer) <= self._limit: self._paused = False self._transport.resume_reading() @@ -301,8 +299,7 @@ if not data: return - self._buffer.append(data) - self._byte_count += len(data) + self._buffer.extend(data) waiter = self._waiter if waiter is not None: @@ -312,7 +309,7 @@ if (self._transport is not None and not self._paused and - self._byte_count > 2*self._limit): + len(self._buffer) > 2*self._limit): try: self._transport.pause_reading() except NotImplementedError: @@ -338,31 +335,31 @@ if self._exception is not None: raise self._exception - parts = [] + parts = self._buffer_factory() parts_size = 0 not_enough = True while not_enough: while self._buffer and not_enough: - data = self._buffer.popleft() - ichar = data.find(b'\n') + ichar = self._buffer.find(b'\n') if ichar < 0: - parts.append(data) - parts_size += len(data) + if len(self._buffer) > self._limit: + self._maybe_resume_transport() + raise ValueError('Line is too long') + parts.extend(self._buffer) + parts_size += len(self._buffer) + self._buffer.clear() else: ichar += 1 - head, tail = data[:ichar], data[ichar:] - if tail: - self._buffer.appendleft(tail) + if ichar + parts_size > self._limit: + self._maybe_resume_transport() + raise ValueError('Line is too long') + head = self._buffer[:ichar] + del self._buffer[:ichar] not_enough = False - parts.append(head) + parts.extend(head) parts_size += len(head) - if parts_size > self._limit: - self._byte_count -= parts_size - self._maybe_resume_transport() - raise ValueError('Line is too long') - if self._eof: break @@ -373,11 +370,9 @@ finally: self._waiter = None - line = b''.join(parts) - self._byte_count -= parts_size self._maybe_resume_transport() - return line + return bytes(parts) @tasks.coroutine def read(self, n=-1): @@ -395,42 +390,37 @@ finally: self._waiter = None else: - if not self._byte_count and not self._eof: + if not self._buffer and not self._eof: self._waiter = self._create_waiter('read') try: yield from self._waiter finally: self._waiter = None - if n < 0 or self._byte_count <= n: - data = b''.join(self._buffer) + if n < 0 or len(self._buffer) <= n: + data = bytes(self._buffer) self._buffer.clear() - self._byte_count = 0 - self._maybe_resume_transport() - return data + else: + # len > 0 and len(self._buffer) > n + data = bytes(self._buffer[:n]) + del self._buffer[:n] - parts = [] - parts_bytes = 0 - while self._buffer and parts_bytes < n: - data = self._buffer.popleft() - data_bytes = len(data) - if n < parts_bytes + data_bytes: - data_bytes = n - parts_bytes - data, rest = data[:data_bytes], data[data_bytes:] - self._buffer.appendleft(rest) - - parts.append(data) - parts_bytes += data_bytes - self._byte_count -= data_bytes - self._maybe_resume_transport() - - return b''.join(parts) + self._maybe_resume_transport() + return data @tasks.coroutine def readexactly(self, n): if self._exception is not None: raise self._exception + if n < 0: + return b'' + + if len(self._buffer) > n: + data = bytes(self._buffer[:n]) + del self._buffer[:n] + return data + # There used to be "optimized" code here. It created its own # Future and waited until self._buffer had at least the n # bytes, then called read(n). Unfortunately, this could pause diff -r 3f4e15c3d089 tests/test_streams.py --- a/tests/test_streams.py Tue Feb 04 08:54:49 2014 +0100 +++ b/tests/test_streams.py Tue Feb 04 12:38:39 2014 -0500 @@ -79,13 +79,13 @@ stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'') - self.assertEqual(0, stream._byte_count) + self.assertEqual(0, len(stream._buffer)) def test_feed_data_byte_count(self): stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(self.DATA) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(len(self.DATA), len(stream._buffer)) def test_read_zero(self): # Read zero bytes. @@ -94,7 +94,7 @@ data = self.loop.run_until_complete(stream.read(0)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(len(self.DATA), len(stream._buffer)) def test_read(self): # Read bytes. @@ -107,7 +107,7 @@ data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) - self.assertFalse(stream._byte_count) + self.assertEqual(0, len(stream._buffer)) def test_read_line_breaks(self): # Read bytes without line breaks. @@ -118,7 +118,7 @@ data = self.loop.run_until_complete(stream.read(5)) self.assertEqual(b'line1', data) - self.assertEqual(5, stream._byte_count) + self.assertEqual(5, len(stream._buffer)) def test_read_eof(self): # Read bytes, stop at eof. @@ -131,7 +131,7 @@ data = self.loop.run_until_complete(read_task) self.assertEqual(b'', data) - self.assertFalse(stream._byte_count) + self.assertEqual(0, len(stream._buffer)) def test_read_until_eof(self): # Read all bytes until eof. @@ -147,7 +147,7 @@ data = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1\nchunk2', data) - self.assertFalse(stream._byte_count) + self.assertEqual(0, len(stream._buffer)) def test_read_exception(self): stream = asyncio.StreamReader(loop=self.loop) @@ -174,7 +174,7 @@ line = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) - self.assertEqual(len(b'\n chunk4')-1, stream._byte_count) + self.assertEqual(len(b'\n chunk4')-1, len(stream._buffer)) def test_readline_limit_with_existing_data(self): stream = asyncio.StreamReader(3, loop=self.loop) @@ -183,7 +183,7 @@ self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'line2\n'], list(stream._buffer)) + self.assertEqual(b'line1\nline2\n', stream._buffer) stream = asyncio.StreamReader(3, loop=self.loop) stream.feed_data(b'li') @@ -192,8 +192,7 @@ self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'li'], list(stream._buffer)) - self.assertEqual(2, stream._byte_count) + self.assertEqual(b'line1li', stream._buffer) def test_readline_limit(self): stream = asyncio.StreamReader(7, loop=self.loop) @@ -207,8 +206,7 @@ self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'chunk3\n'], list(stream._buffer)) - self.assertEqual(7, stream._byte_count) + self.assertEqual(b'chunk1chunk2chunk3\n', stream._buffer) def test_readline_line_byte_count(self): stream = asyncio.StreamReader(loop=self.loop) @@ -218,7 +216,7 @@ line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'line1\n', line) - self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count) + self.assertEqual(len(self.DATA) - len(b'line1\n'), len(stream._buffer)) def test_readline_eof(self): stream = asyncio.StreamReader(loop=self.loop) @@ -246,7 +244,7 @@ self.assertEqual(b'line2\nl', data) self.assertEqual( len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), - stream._byte_count) + len(stream._buffer)) def test_readline_exception(self): stream = asyncio.StreamReader(loop=self.loop) @@ -266,11 +264,11 @@ data = self.loop.run_until_complete(stream.readexactly(0)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(len(self.DATA), len(stream._buffer)) data = self.loop.run_until_complete(stream.readexactly(-1)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(len(self.DATA), len(stream._buffer)) def test_readexactly(self): # Read exact number of bytes. @@ -287,7 +285,7 @@ data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA + self.DATA, data) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(len(self.DATA), len(stream._buffer)) def test_readexactly_eof(self): # Read exact number of bytes (eof). @@ -306,7 +304,7 @@ self.assertEqual(cm.exception.expected, n) self.assertEqual(str(cm.exception), '18 bytes read on a total of 36 expected bytes') - self.assertFalse(stream._byte_count) + self.assertEqual(0, len(stream._buffer)) def test_readexactly_exception(self): stream = asyncio.StreamReader(loop=self.loop)