diff -r 540a9c69c2ea Lib/telnetlib.py --- a/Lib/telnetlib.py Fri Sep 13 19:53:08 2013 +0200 +++ b/Lib/telnetlib.py Sat Oct 05 15:15:53 2013 +0200 @@ -17,13 +17,12 @@ Note that read_all() won't read until eof -- it just reads some data -- but it guarantees to read at least one byte unless EOF is hit. -It is possible to pass a Telnet object to select.select() in order to -wait until more data is available. Note that in this case, -read_eager() may return b'' even if there was data on the socket, -because the protocol negotiation may have eaten the data. This is why -EOFError is needed in some cases to distinguish between "no data" and -"connection closed" (since the socket also appears ready for reading -when it is closed). +It is possible to pass a Telnet object to a selector in order to wait until +more data is available. Note that in this case, read_eager() may return b'' +even if there was data on the socket, because the protocol negotiation may have +eaten the data. This is why EOFError is needed in some cases to distinguish +between "no data" and "connection closed" (since the socket also appears ready +for reading when it is closed). To do: - option negotiation @@ -34,10 +33,9 @@ # Imported modules -import errno import sys import socket -import select +import selectors __all__ = ["Telnet"] @@ -130,6 +128,15 @@ EXOPL = bytes([255]) # Extended-Options-List NOOPT = bytes([0]) + +# poll/select have the advantage of not requiring any extra file descriptor, +# contrarily to epoll/kqueue (also, they require a single syscall). +if hasattr(selectors, 'PollSelector'): + _TelnetSelector = selectors.PollSelector +else: + _TelnetSelector = selectors.SelectSelector + + class Telnet: """Telnet interface class. @@ -206,7 +213,6 @@ self.sb = 0 # flag for SB and SE sequence. self.sbdataq = b'' self.option_callback = None - self._has_poll = hasattr(select, 'poll') if host is not None: self.open(host, port, timeout) @@ -289,61 +295,6 @@ is closed and no cooked data is available. """ - if self._has_poll: - return self._read_until_with_poll(match, timeout) - else: - return self._read_until_with_select(match, timeout) - - def _read_until_with_poll(self, match, timeout): - """Read until a given string is encountered or until timeout. - - This method uses select.poll() to implement the timeout. - """ - n = len(match) - call_timeout = timeout - if timeout is not None: - from time import time - time_start = time() - self.process_rawq() - i = self.cookedq.find(match) - if i < 0: - poller = select.poll() - poll_in_or_priority_flags = select.POLLIN | select.POLLPRI - poller.register(self, poll_in_or_priority_flags) - while i < 0 and not self.eof: - try: - ready = poller.poll(call_timeout) - except OSError as e: - if e.errno == errno.EINTR: - if timeout is not None: - elapsed = time() - time_start - call_timeout = timeout-elapsed - continue - raise - for fd, mode in ready: - if mode & poll_in_or_priority_flags: - i = max(0, len(self.cookedq)-n) - self.fill_rawq() - self.process_rawq() - i = self.cookedq.find(match, i) - if timeout is not None: - elapsed = time() - time_start - if elapsed >= timeout: - break - call_timeout = timeout-elapsed - poller.unregister(self) - if i >= 0: - i = i + n - buf = self.cookedq[:i] - self.cookedq = self.cookedq[i:] - return buf - return self.read_very_lazy() - - def _read_until_with_select(self, match, timeout=None): - """Read until a given string is encountered or until timeout. - - The timeout is implemented using select.select(). - """ n = len(match) self.process_rawq() i = self.cookedq.find(match) @@ -352,27 +303,26 @@ buf = self.cookedq[:i] self.cookedq = self.cookedq[i:] return buf - s_reply = ([self], [], []) - s_args = s_reply if timeout is not None: - s_args = s_args + (timeout,) from time import time - time_start = time() - while not self.eof and select.select(*s_args) == s_reply: - i = max(0, len(self.cookedq)-n) - self.fill_rawq() - self.process_rawq() - i = self.cookedq.find(match, i) - if i >= 0: - i = i+n - buf = self.cookedq[:i] - self.cookedq = self.cookedq[i:] - return buf - if timeout is not None: - elapsed = time() - time_start - if elapsed >= timeout: - break - s_args = s_reply + (timeout-elapsed,) + deadline = time() + timeout + with _TelnetSelector() as selector: + selector.register(self, selectors.EVENT_READ) + while not self.eof: + if selector.select(timeout): + i = max(0, len(self.cookedq)-n) + self.fill_rawq() + self.process_rawq() + i = self.cookedq.find(match, i) + if i >= 0: + i = i+n + buf = self.cookedq[:i] + self.cookedq = self.cookedq[i:] + return buf + if timeout is not None: + timeout = deadline - time() + if timeout < 0: + break return self.read_very_lazy() def read_all(self): @@ -577,29 +527,35 @@ def sock_avail(self): """Test whether data is available on the socket.""" - return select.select([self], [], [], 0) == ([self], [], []) + with _TelnetSelector() as selector: + selector.register(self, selectors.EVENT_READ) + return bool(selector.select(0)) def interact(self): """Interaction function, emulates a very dumb telnet client.""" if sys.platform == "win32": self.mt_interact() return - while 1: - rfd, wfd, xfd = select.select([self, sys.stdin], [], []) - if self in rfd: - try: - text = self.read_eager() - except EOFError: - print('*** Connection closed by remote host ***') - break - if text: - sys.stdout.write(text.decode('ascii')) - sys.stdout.flush() - if sys.stdin in rfd: - line = sys.stdin.readline().encode('ascii') - if not line: - break - self.write(line) + with _TelnetSelector() as selector: + selector.register(self, selectors.EVENT_READ) + selector.register(sys.stdin, selectors.EVENT_READ) + + while True: + for key, events in selector.select(): + if key.fileobj is self: + try: + text = self.read_eager() + except EOFError: + print('*** Connection closed by remote host ***') + return + if text: + sys.stdout.write(text.decode('ascii')) + sys.stdout.flush() + elif key.fileobj is sys.stdin: + line = sys.stdin.readline().encode('ascii') + if not line: + return + self.write(line) def mt_interact(self): """Multithreaded version of interact().""" @@ -646,79 +602,6 @@ results are undeterministic, and may depend on the I/O timing. """ - if self._has_poll: - return self._expect_with_poll(list, timeout) - else: - return self._expect_with_select(list, timeout) - - def _expect_with_poll(self, expect_list, timeout=None): - """Read until one from a list of a regular expressions matches. - - This method uses select.poll() to implement the timeout. - """ - re = None - expect_list = expect_list[:] - indices = range(len(expect_list)) - for i in indices: - if not hasattr(expect_list[i], "search"): - if not re: import re - expect_list[i] = re.compile(expect_list[i]) - call_timeout = timeout - if timeout is not None: - from time import time - time_start = time() - self.process_rawq() - m = None - for i in indices: - m = expect_list[i].search(self.cookedq) - if m: - e = m.end() - text = self.cookedq[:e] - self.cookedq = self.cookedq[e:] - break - if not m: - poller = select.poll() - poll_in_or_priority_flags = select.POLLIN | select.POLLPRI - poller.register(self, poll_in_or_priority_flags) - while not m and not self.eof: - try: - ready = poller.poll(call_timeout) - except OSError as e: - if e.errno == errno.EINTR: - if timeout is not None: - elapsed = time() - time_start - call_timeout = timeout-elapsed - continue - raise - for fd, mode in ready: - if mode & poll_in_or_priority_flags: - self.fill_rawq() - self.process_rawq() - for i in indices: - m = expect_list[i].search(self.cookedq) - if m: - e = m.end() - text = self.cookedq[:e] - self.cookedq = self.cookedq[e:] - break - if timeout is not None: - elapsed = time() - time_start - if elapsed >= timeout: - break - call_timeout = timeout-elapsed - poller.unregister(self) - if m: - return (i, m, text) - text = self.read_very_lazy() - if not text and self.eof: - raise EOFError - return (-1, None, text) - - def _expect_with_select(self, list, timeout=None): - """Read until one from a list of a regular expressions matches. - - The timeout is implemented using select.select(). - """ re = None list = list[:] indices = range(len(list)) @@ -728,27 +611,27 @@ list[i] = re.compile(list[i]) if timeout is not None: from time import time - time_start = time() - while 1: - self.process_rawq() - for i in indices: - m = list[i].search(self.cookedq) - if m: - e = m.end() - text = self.cookedq[:e] - self.cookedq = self.cookedq[e:] - return (i, m, text) - if self.eof: - break - if timeout is not None: - elapsed = time() - time_start - if elapsed >= timeout: - break - s_args = ([self.fileno()], [], [], timeout-elapsed) - r, w, x = select.select(*s_args) - if not r: - break - self.fill_rawq() + deadline = time() + timeout + with _TelnetSelector() as selector: + selector.register(self, selectors.EVENT_READ) + while not self.eof: + self.process_rawq() + for i in indices: + m = list[i].search(self.cookedq) + if m: + e = m.end() + text = self.cookedq[:e] + self.cookedq = self.cookedq[e:] + return (i, m, text) + if timeout is not None: + ready = selector.select(timeout) + timeout = deadline - time() + if not ready: + if timeout < 0: + break + else: + continue + self.fill_rawq() text = self.read_very_lazy() if not text and self.eof: raise EOFError diff -r 540a9c69c2ea Lib/test/test_telnetlib.py --- a/Lib/test/test_telnetlib.py Fri Sep 13 19:53:08 2013 +0200 +++ b/Lib/test/test_telnetlib.py Sat Oct 05 15:15:53 2013 +0200 @@ -1,10 +1,9 @@ import socket -import select +import selectors import telnetlib import time import contextlib -import unittest from unittest import TestCase from test import support threading = support.import_module('threading') @@ -112,40 +111,32 @@ self._messages += out.getvalue() return -def mock_select(*s_args): - block = False - for l in s_args: - for fob in l: - if isinstance(fob, TelnetAlike): - block = fob.sock.block - if block: - return [[], [], []] - else: - return s_args - -class MockPoller(object): - test_case = None # Set during TestCase setUp. +class MockSelector(selectors.BaseSelector): def __init__(self): - self._file_objs = [] + super().__init__() + self.keys = {} - def register(self, fd, eventmask): - self.test_case.assertTrue(hasattr(fd, 'fileno'), fd) - self.test_case.assertEqual(eventmask, select.POLLIN|select.POLLPRI) - self._file_objs.append(fd) + def register(self, fileobj, events, data=None): + key = selectors.SelectorKey(fileobj, 0, events, data) + self.keys[fileobj] = key + return key - def poll(self, timeout=None): + def unregister(self, fileobj): + key = self.keys.pop(fileobj) + return key + + def select(self, timeout=None): block = False - for fob in self._file_objs: - if isinstance(fob, TelnetAlike): - block = fob.sock.block + for fileobj in self.keys: + if isinstance(fileobj, TelnetAlike): + block = fileobj.sock.block + break if block: return [] else: - return zip(self._file_objs, [select.POLLIN]*len(self._file_objs)) + return [(key, key.events) for key in self.keys.values()] - def unregister(self, fd): - self._file_objs.remove(fd) @contextlib.contextmanager def test_socket(reads): @@ -159,7 +150,7 @@ socket.create_connection = old_conn return -def test_telnet(reads=(), cls=TelnetAlike, use_poll=None): +def test_telnet(reads=[], cls=TelnetAlike): ''' return a telnetlib.Telnet object that uses a SocketStub with reads queued up to be read ''' for x in reads: @@ -167,31 +158,15 @@ with test_socket(reads): telnet = cls('dummy', 0) telnet._messages = '' # debuglevel output - if use_poll is not None: - if use_poll and not telnet._has_poll: - raise unittest.SkipTest('select.poll() required.') - telnet._has_poll = use_poll return telnet +class ReadTests(TestCase): + def setUp(self): + self.old_selector = telnetlib._TelnetSelector + telnetlib._TelnetSelector = MockSelector + def tearDown(self): + telnetlib._TelnetSelector = self.old_selector -class ExpectAndReadTestCase(TestCase): - def setUp(self): - self.old_select = select.select - select.select = mock_select - self.old_poll = False - if hasattr(select, 'poll'): - self.old_poll = select.poll - select.poll = MockPoller - MockPoller.test_case = self - - def tearDown(self): - if self.old_poll: - MockPoller.test_case = None - select.poll = self.old_poll - select.select = self.old_select - - -class ReadTests(ExpectAndReadTestCase): def test_read_until(self): """ read_until(expected, timeout=None) @@ -208,22 +183,6 @@ data = telnet.read_until(b'match') self.assertEqual(data, expect) - def test_read_until_with_poll(self): - """Use select.poll() to implement telnet.read_until().""" - want = [b'x' * 10, b'match', b'y' * 10] - telnet = test_telnet(want, use_poll=True) - select.select = lambda *_: self.fail('unexpected select() call.') - data = telnet.read_until(b'match') - self.assertEqual(data, b''.join(want[:-1])) - - def test_read_until_with_select(self): - """Use select.select() to implement telnet.read_until().""" - want = [b'x' * 10, b'match', b'y' * 10] - telnet = test_telnet(want, use_poll=False) - if self.old_poll: - select.poll = lambda *_: self.fail('unexpected poll() call.') - data = telnet.read_until(b'match') - self.assertEqual(data, b''.join(want[:-1])) def test_read_all(self): """ @@ -415,39 +374,8 @@ self.assertRegex(telnet._messages, r'0.*test') -class ExpectTests(ExpectAndReadTestCase): - def test_expect(self): - """ - expect(expected, [timeout]) - Read until the expected string has been seen, or a timeout is - hit (default is no timeout); may block. - """ - want = [b'x' * 10, b'match', b'y' * 10] - telnet = test_telnet(want) - (_,_,data) = telnet.expect([b'match']) - self.assertEqual(data, b''.join(want[:-1])) - - def test_expect_with_poll(self): - """Use select.poll() to implement telnet.expect().""" - want = [b'x' * 10, b'match', b'y' * 10] - telnet = test_telnet(want, use_poll=True) - select.select = lambda *_: self.fail('unexpected select() call.') - (_,_,data) = telnet.expect([b'match']) - self.assertEqual(data, b''.join(want[:-1])) - - def test_expect_with_select(self): - """Use select.select() to implement telnet.expect().""" - want = [b'x' * 10, b'match', b'y' * 10] - telnet = test_telnet(want, use_poll=False) - if self.old_poll: - select.poll = lambda *_: self.fail('unexpected poll() call.') - (_,_,data) = telnet.expect([b'match']) - self.assertEqual(data, b''.join(want[:-1])) - - def test_main(verbose=None): - support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests, - ExpectTests) + support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests) if __name__ == '__main__': test_main()