diff --git a/Lib/http/client.py b/Lib/http/client.py
--- a/Lib/http/client.py
+++ b/Lib/http/client.py
@@ -679,6 +679,99 @@
             total_bytes += n
         return total_bytes
 
+    def read1(self, n=-1):
+        """Read with at most one underlying system call.  If at least one
+        byte is buffered, return that instead.
+
+        Returns b"" when the response is closed or answers a HEAD request.
+        On the non-chunked path an empty read for a non-zero n closes the
+        connection.
+        """
+        if self.fp is None or self._method == "HEAD":
+            return b""
+        if self.chunked:
+            return self._read1_or_peek_chunked(n, True)
+        # NOTE(review): unlike the chunked path, nothing here clips n to, or
+        # accounts for, a declared body length -- confirm whether the class
+        # tracks one for read() and whether read1() should honour it.
+        result = self.fp.read1(n)
+        if not result and n:
+            self._close_conn()
+        return result
+
+    def peek(self, n=-1):
+        """Return upcoming body bytes via the underlying peek().
+
+        Having this enables IOBase.readline() to read more than one
+        byte at a time.  For a chunked body, an IncompleteRead raised
+        while peeking is reported as b"".
+        """
+        if self.fp is None or self._method == "HEAD":
+            return b""
+        if self.chunked:
+            try:
+                return self._read1_or_peek_chunked(n, False)
+            except IncompleteRead:
+                return b""
+        return self.fp.peek(n)
+
+    def readline(self, limit=-1):
+        """Read and return one line of the response body.
+
+        limit is handed unchanged to the underlying readline().
+        """
+        if self.fp is None or self._method == "HEAD":
+            return b""
+        if self.chunked:
+            # Fallback to IOBase readline which uses peek() and read()
+            return super().readline(limit)
+        result = self.fp.readline(limit)
+        if not result and limit:
+            self._close_conn()
+        return result
+
+    def _read1_or_peek_chunked(self, n, read1):
+        """Serve read1() (read1 true) or peek() (read1 false) for a
+        chunked body.
+
+        Strictly speaking, there may be more than one read call because
+        the chunk header and tail must be consumed to preserve the
+        protocol, but chunk data is read using read1() or peek().
+        Raises IncompleteRead if reading the next chunk size raises
+        ValueError or if no chunk data comes back.
+        """
+        if self.chunk_left is None:
+            try:
+                # Strictly speaking, this may use more than 1 read
+                chunk_left = self._read_next_chunk_size()
+            except ValueError:
+                raise IncompleteRead(b'')
+            if chunk_left == 0:
+                self._read_and_discard_trailer()
+                # we read everything; close the "file"
+                self._close_conn()
+                return b""
+            self.chunk_left = chunk_left
+        if n < 0:
+            n = self.chunk_left
+        elif n > 0:
+            n = min(n, self.chunk_left)
+        else:
+            return b""
+        if read1:
+            read = self.fp.read1(n)
+            self.chunk_left -= len(read)
+            if self.chunk_left == 0:
+                self._safe_read(2)      # toss the CRLF at the end of the chunk
+                self.chunk_left = None
+        else:
+            read = self.fp.peek(n)
+            # peek can return more than requested.  Truncate at end of chunk
+            read = read[:self.chunk_left]
+        if not read:
+            raise IncompleteRead(b"")
+        return read
+
     def fileno(self):
         return self.fp.fileno()
 
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py
--- a/Lib/test/test_httplib.py
+++ b/Lib/test/test_httplib.py
@@ -658,6 +658,181 @@
         conn.request('POST', '/', body)
         self.assertGreater(sock.sendall_calls, 1)
 
+class ExtendedReadTest(TestCase):
+    """
+    Test peek(), read1(), readline()
+    """
+    lines = (
+        'HTTP/1.1 200 OK\r\n'
+        '\r\n'
+        'hello world!\n'
+        'and now \n'
+        'for something completely different\n'
+        'foo'
+    )
+    lines_expected = lines[lines.find('hello'):].encode("ascii")
+    lines_chunked = (
+        'HTTP/1.1 200 OK\r\n'
+        'Transfer-Encoding: chunked\r\n\r\n'
+        'a\r\n'
+        'hello worl\r\n'
+        '3\r\n'
+        'd!\n\r\n'
+        '9\r\n'
+        'and now \n\r\n'
+        '23\r\n'
+        'for something completely different\n\r\n'
+        '3\r\n'
+        'foo\r\n'
+        '0\r\n\r\n' # terminating chunk
+    )
+
+    def test_peek(self):
+        sock = FakeSocket(self.lines)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        resp.fp = io.BufferedReader(resp.fp)
+        # patch up the buffered peek so that it returns not too much stuff
+        oldpeek = resp.fp.peek
+        def mypeek(n=-1):
+            p = oldpeek(n)
+            if n >= 0:
+                return p[:n]
+            return p[:10]
+        # Install the patched peek; without this the wrapper is never used.
+        resp.fp.peek = mypeek
+        self._verify_peek(resp, self.lines_expected)
+
+    def test_peek_chunked(self):
+        sock = FakeSocket(self.lines_chunked)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        resp.fp = io.BufferedReader(resp.fp)
+        self._verify_peek(resp, self.lines_expected)
+
+    def _verify_peek(self, resp, expected):
+        """Drain resp via peek()/read() pairs and compare with expected."""
+        chunks = []
+        while True:
+            # try a short peek
+            p = resp.peek(3)
+            if p:
+                self.assertGreater(len(p), 0)
+                # then unbounded peek
+                p2 = resp.peek()
+                self.assertGreaterEqual(len(p2), len(p))
+                self.assertTrue(p2.startswith(p))
+                data = resp.read(len(p2))
+                self.assertEqual(data, p2)
+            else:
+                data = resp.read()
+                self.assertFalse(data)
+            chunks.append(data)
+            if not data:
+                break
+        self.assertEqual(b"".join(chunks), expected)
+        resp.close()
+
+    def test_readline(self):
+        sock = FakeSocket(self.lines)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        resp.fp = io.BufferedReader(resp.fp)
+        self._verify_readline(resp.readline, self.lines_expected)
+        resp.close()
+
+    def test_readline_chunked(self):
+        sock = FakeSocket(self.lines_chunked)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        resp.fp = io.BufferedReader(resp.fp)
+        self._verify_readline(resp.readline, self.lines_expected)
+        resp.close()
+
+    def _verify_readline(self, readline, expected):
+        """Read everything through readline(5) and compare with expected."""
+        pieces = []
+        while True:
+            # short readlines
+            line = readline(5)
+            self.assertLessEqual(len(line), 5)
+            if line and line != b"foo":
+                if len(line) < 5:
+                    self.assertTrue(line.endswith(b"\n"))
+            pieces.append(line)
+            if not line:
+                break
+        self.assertEqual(b"".join(pieces), expected)
+
+    def test_read1(self):
+        sock = FakeSocket(self.lines)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        resp.fp = io.BufferedReader(resp.fp)
+        def r():
+            res = resp.read1(4)
+            self.assertLessEqual(len(res), 4)
+            return res
+        readliner = Readliner(r)
+        self._verify_readline(readliner.readline, self.lines_expected)
+        resp.close()
+
+    def test_read1_chunked(self):
+        sock = FakeSocket(self.lines_chunked)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        resp.fp = io.BufferedReader(resp.fp)
+        def r():
+            res = resp.read1(4)
+            self.assertLessEqual(len(res), 4)
+            return res
+        readliner = Readliner(r)
+        self._verify_readline(readliner.readline, self.lines_expected)
+        resp.close()
+
+class Readliner:
+    """
+    a simple readline class that uses an arbitrary read function and buffering
+    """
+    def __init__(self, readfunc):
+        self.readfunc = readfunc
+        self.remainder = b""
+
+    def readline(self, limit):
+        """Return the next line, truncated to limit bytes if limit >= 0.
+
+        Bytes read past the returned line are kept for the next call;
+        b"" is returned once readfunc yields nothing and nothing is kept.
+        """
+        data = []
+        datalen = 0
+        read = self.remainder
+        try:
+            while True:
+                # bytes that may still be returned; None means unlimited
+                room = None if limit < 0 else limit - datalen
+                idx = read.find(b'\n', 0, room)
+                if idx != -1:
+                    idx += 1    # keep the newline
+                    break
+                if room is not None and len(read) >= room:
+                    idx = room  # limit reached before any newline
+                    break
+                # read more data
+                data.append(read)
+                datalen += len(read)
+                read = self.readfunc()
+                if not read:
+                    idx = 0 #eof condition
+                    break
+            data.append(read[:idx])
+            self.remainder = read[idx:]
+            return b"".join(data)
+        except BaseException:
+            # keep what was consumed so that a retry does not lose data
+            self.remainder = b"".join(data)
+            raise
+
 class OfflineTest(TestCase):
     def test_responses(self):
         self.assertEqual(client.responses[client.NOT_FOUND], "Not Found")
@@ -961,7 +1136,7 @@
 def test_main(verbose=None):
     support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
                          HTTPSTest, RequestBodyTest, SourceAddressTest,
-                         HTTPResponseTest)
+                         HTTPResponseTest, ExtendedReadTest)
 
 if __name__ == '__main__':
     test_main()