diff -r 9e5d35ee0903 Lib/asyncio/streams.py --- a/Lib/asyncio/streams.py Wed Jan 14 00:54:00 2015 +0100 +++ b/Lib/asyncio/streams.py Wed Jan 14 01:47:18 2015 +0100 @@ -14,6 +14,7 @@ from . import coroutines from . import events from . import futures from . import protocols +from . import tasks from .coroutines import coroutine from .log import logger @@ -305,6 +306,8 @@ class StreamWriter: class StreamReader: def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + if limit < 1: + raise ValueError("limit greater or equal than 1") # The line length limit is a security feature; # it also doubles as half the buffer limit. self._limit = limit @@ -378,8 +381,11 @@ class StreamReader: else: self._paused = True - def _wait_for_data(self, func_name): - """Wait until feed_data() or feed_eof() is called.""" + def _wait_for_data(self, func_name, timeout, blocks): + """Wait until feed_data() or feed_eof() is called. + + blocks is a list of bytes strings which are pushed back to the buffer + is the wait raises an exception.""" # StreamReader uses a future to link the protocol feed_data() method # to a read coroutine. Running two read coroutines at the same time # would have an unexpected behaviour. It would not possible to know @@ -390,12 +396,18 @@ class StreamReader: self._waiter = futures.Future(loop=self._loop) try: - yield from self._waiter + assert not self._buffer + try: + yield from tasks.wait_for(self._waiter, timeout, + loop=self._loop) + except: + self._buffer.extend(b''.join(blocks)) + raise finally: self._waiter = None @coroutine - def readline(self): + def readline(self, timeout=None): if self._exception is not None: raise self._exception @@ -422,36 +434,24 @@ class StreamReader: break if not_enough: - yield from self._wait_for_data('readline') + yield from self._wait_for_data('readline', timeout, + (line,)) self._maybe_resume_transport() return bytes(line) @coroutine - def read(self, n=-1): + def _read(self, n, timeout, blocks): if self._exception is not None: raise self._exception if not n: return b'' - if n < 0: - # This used to just loop creating a new waiter hoping to - # collect everything in self._buffer, but that would - # deadlock if the subprocess sends more than self.limit - # bytes. So just call self.read(self._limit) until EOF. - blocks = [] - while True: - block = yield from self.read(self._limit) - if not block: - break - blocks.append(block) - return b''.join(blocks) - else: - if not self._buffer and not self._eof: - yield from self._wait_for_data('read') + if not self._buffer and not self._eof: + yield from self._wait_for_data('read', timeout, blocks) - if n < 0 or len(self._buffer) <= n: + if len(self._buffer) <= n: data = bytes(self._buffer) self._buffer.clear() else: @@ -463,7 +463,32 @@ class StreamReader: return data @coroutine - def readexactly(self, n): + def _read_until_eof(self, timeout): + if self._exception is not None: + raise self._exception + + # This used to just loop creating a new waiter hoping to + # collect everything in self._buffer, but that would + # deadlock if the subprocess sends more than self.limit + # bytes. So just call self._read(self._limit) until EOF. + blocks = [] + while True: + block = yield from self._read(self._limit, timeout, blocks) + if not block: + break + blocks.append(block) + return b''.join(blocks) + + @coroutine + def read(self, n=-1, timeout=None): + if n < 0: + data = yield from self._read_until_eof(timeout) + else: + data = yield from self._read(n, timeout, ()) + return data + + @coroutine + def readexactly(self, n, timeout=None): if self._exception is not None: raise self._exception @@ -476,7 +501,7 @@ class StreamReader: blocks = [] while n > 0: - block = yield from self.read(n) + block = yield from self._read(n, timeout, blocks) if not block: partial = b''.join(blocks) raise IncompleteReadError(partial, len(partial) + n) diff -r 9e5d35ee0903 Lib/test/test_asyncio/test_streams.py --- a/Lib/test/test_asyncio/test_streams.py Wed Jan 14 00:54:00 2015 +0100 +++ b/Lib/test/test_asyncio/test_streams.py Wed Jan 14 01:47:18 2015 +0100 @@ -31,6 +31,11 @@ class StreamReaderTests(test_utils.TestC gc.collect() super().tearDown() + def test_ctor_limit(self): + # limit must by >= 1 + self.assertRaises(ValueError, asyncio.StreamReader, 0, loop=self.loop) + self.assertRaises(ValueError, asyncio.StreamReader, -1, loop=self.loop) + @mock.patch('asyncio.streams.events') def test_ctor_global_loop(self, m_events): stream = asyncio.StreamReader() @@ -174,6 +179,21 @@ class StreamReaderTests(test_utils.TestC self.assertEqual(b'', data) self.assertEqual(b'', stream._buffer) + def test_read_timeout(self): + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.read(30, timeout=0.010), + loop=self.loop) + self.assertRaises(asyncio.TimeoutError, + self.loop.run_until_complete, read_task) + + def test_read_timeout_ok(self): + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.read(30, timeout=0.050), + loop=self.loop) + self.loop.call_later(0.010, stream.feed_data, self.DATA) + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + def test_read_until_eof(self): # Read all bytes until eof. stream = asyncio.StreamReader(loop=self.loop) @@ -201,6 +221,23 @@ class StreamReaderTests(test_utils.TestC self.assertRaises( ValueError, self.loop.run_until_complete, stream.read(2)) + def test_read_until_eof_timeout(self): + stream = asyncio.StreamReader(loop=self.loop) + + read_task = asyncio.Task(stream.read(timeout=0.010), + loop=self.loop) + self.assertRaises(asyncio.TimeoutError, + self.loop.run_until_complete, read_task) + + def test_read_until_eof_timeout_partial(self): + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.read(timeout=0.050), + loop=self.loop) + self.loop.call_later(0.010, stream.feed_data, self.DATA) + self.assertRaises(asyncio.TimeoutError, + self.loop.run_until_complete, read_task) + self.assertEqual(self.DATA, stream._buffer) + def test_readline(self): # Read one line. 'readline' will need to wait for the data # to come from 'cb' @@ -340,6 +377,30 @@ class StreamReaderTests(test_utils.TestC ValueError, self.loop.run_until_complete, stream.readline()) self.assertEqual(b'', stream._buffer) + def test_readline_timeout_ok(self): + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.readline(timeout=0.010), + loop=self.loop) + self.assertRaises(asyncio.TimeoutError, + self.loop.run_until_complete, read_task) + + def test_readline_timeout_ok(self): + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.readline(timeout=0.050), + loop=self.loop) + self.loop.call_soon(stream.feed_data, b'line\n') + line = self.loop.run_until_complete(read_task) + self.assertEqual(b'line\n', line) + + def test_readline_timeout_partial(self): + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.readline(timeout=0.050), + loop=self.loop) + self.loop.call_soon(stream.feed_data, b'line') + self.assertRaises(asyncio.TimeoutError, + self.loop.run_until_complete, read_task) + self.assertEqual(b'line', stream._buffer) + def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). stream = asyncio.StreamReader(loop=self.loop) @@ -400,6 +461,35 @@ class StreamReaderTests(test_utils.TestC self.assertRaises( ValueError, self.loop.run_until_complete, stream.readexactly(2)) + def test_readexactly_timeout(self): + stream = asyncio.StreamReader(loop=self.loop) + n = 2 * len(self.DATA) + read_task2 = asyncio.Task(stream.readexactly(n, timeout=0.010), + loop=self.loop) + self.assertRaises(asyncio.TimeoutError, + self.loop.run_until_complete, read_task2) + + def test_readexactly_timeout_ok(self): + stream = asyncio.StreamReader(loop=self.loop) + n = 2 * len(self.DATA) + read_task = asyncio.Task(stream.readexactly(n, timeout=0.050), + loop=self.loop) + self.loop.call_later(0.010, stream.feed_data, self.DATA) + self.loop.call_later(0.020, stream.feed_data, self.DATA) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + + def test_readexactly_timeout_partial(self): + stream = asyncio.StreamReader(loop=self.loop) + n = 2 * len(self.DATA) + read_task = asyncio.Task(stream.readexactly(n, timeout=0.050), + loop=self.loop) + self.loop.call_later(0.010, stream.feed_data, self.DATA) + self.assertRaises(asyncio.TimeoutError, + self.loop.run_until_complete, read_task) + self.assertEqual(self.DATA, stream._buffer) + def test_exception(self): stream = asyncio.StreamReader(loop=self.loop) self.assertIsNone(stream.exception())