Issue #19170: telnetlib: use selectors.

This commit is contained in:
Charles-François Natali 2013-10-21 14:02:12 +02:00
parent 6633c39af8
commit 6459025b24
2 changed files with 100 additions and 274 deletions

View File

@ -17,13 +17,12 @@ guido Guido van Rossum pts/2 <Dec 2 11:10> snag.cnri.reston..
Note that read_all() won't read until eof -- it just reads some data Note that read_all() won't read until eof -- it just reads some data
-- but it guarantees to read at least one byte unless EOF is hit. -- but it guarantees to read at least one byte unless EOF is hit.
It is possible to pass a Telnet object to select.select() in order to It is possible to pass a Telnet object to a selector in order to wait until
wait until more data is available. Note that in this case, more data is available. Note that in this case, read_eager() may return b''
read_eager() may return b'' even if there was data on the socket, even if there was data on the socket, because the protocol negotiation may have
because the protocol negotiation may have eaten the data. This is why eaten the data. This is why EOFError is needed in some cases to distinguish
EOFError is needed in some cases to distinguish between "no data" and between "no data" and "connection closed" (since the socket also appears ready
"connection closed" (since the socket also appears ready for reading for reading when it is closed).
when it is closed).
To do: To do:
- option negotiation - option negotiation
@ -34,10 +33,9 @@ To do:
# Imported modules # Imported modules
import errno
import sys import sys
import socket import socket
import select import selectors
__all__ = ["Telnet"] __all__ = ["Telnet"]
@ -130,6 +128,15 @@ PRAGMA_HEARTBEAT = bytes([140]) # TELOPT PRAGMA HEARTBEAT
EXOPL = bytes([255]) # Extended-Options-List EXOPL = bytes([255]) # Extended-Options-List
NOOPT = bytes([0]) NOOPT = bytes([0])
# poll/select have the advantage of not requiring any extra file descriptor,
# contrarily to epoll/kqueue (also, they require a single syscall).
if hasattr(selectors, 'PollSelector'):
_TelnetSelector = selectors.PollSelector
else:
_TelnetSelector = selectors.SelectSelector
class Telnet: class Telnet:
"""Telnet interface class. """Telnet interface class.
@ -206,7 +213,6 @@ class Telnet:
self.sb = 0 # flag for SB and SE sequence. self.sb = 0 # flag for SB and SE sequence.
self.sbdataq = b'' self.sbdataq = b''
self.option_callback = None self.option_callback = None
self._has_poll = hasattr(select, 'poll')
if host is not None: if host is not None:
self.open(host, port, timeout) self.open(host, port, timeout)
@ -288,61 +294,6 @@ class Telnet:
possibly the empty string. Raise EOFError if the connection possibly the empty string. Raise EOFError if the connection
is closed and no cooked data is available. is closed and no cooked data is available.
"""
if self._has_poll:
return self._read_until_with_poll(match, timeout)
else:
return self._read_until_with_select(match, timeout)
def _read_until_with_poll(self, match, timeout):
"""Read until a given string is encountered or until timeout.
This method uses select.poll() to implement the timeout.
"""
n = len(match)
call_timeout = timeout
if timeout is not None:
from time import time
time_start = time()
self.process_rawq()
i = self.cookedq.find(match)
if i < 0:
poller = select.poll()
poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
poller.register(self, poll_in_or_priority_flags)
while i < 0 and not self.eof:
try:
ready = poller.poll(call_timeout)
except OSError as e:
if e.errno == errno.EINTR:
if timeout is not None:
elapsed = time() - time_start
call_timeout = timeout-elapsed
continue
raise
for fd, mode in ready:
if mode & poll_in_or_priority_flags:
i = max(0, len(self.cookedq)-n)
self.fill_rawq()
self.process_rawq()
i = self.cookedq.find(match, i)
if timeout is not None:
elapsed = time() - time_start
if elapsed >= timeout:
break
call_timeout = timeout-elapsed
poller.unregister(self)
if i >= 0:
i = i + n
buf = self.cookedq[:i]
self.cookedq = self.cookedq[i:]
return buf
return self.read_very_lazy()
def _read_until_with_select(self, match, timeout=None):
"""Read until a given string is encountered or until timeout.
The timeout is implemented using select.select().
""" """
n = len(match) n = len(match)
self.process_rawq() self.process_rawq()
@ -352,27 +303,26 @@ class Telnet:
buf = self.cookedq[:i] buf = self.cookedq[:i]
self.cookedq = self.cookedq[i:] self.cookedq = self.cookedq[i:]
return buf return buf
s_reply = ([self], [], [])
s_args = s_reply
if timeout is not None: if timeout is not None:
s_args = s_args + (timeout,)
from time import time from time import time
time_start = time() deadline = time() + timeout
while not self.eof and select.select(*s_args) == s_reply: with _TelnetSelector() as selector:
i = max(0, len(self.cookedq)-n) selector.register(self, selectors.EVENT_READ)
self.fill_rawq() while not self.eof:
self.process_rawq() if selector.select(timeout):
i = self.cookedq.find(match, i) i = max(0, len(self.cookedq)-n)
if i >= 0: self.fill_rawq()
i = i+n self.process_rawq()
buf = self.cookedq[:i] i = self.cookedq.find(match, i)
self.cookedq = self.cookedq[i:] if i >= 0:
return buf i = i+n
if timeout is not None: buf = self.cookedq[:i]
elapsed = time() - time_start self.cookedq = self.cookedq[i:]
if elapsed >= timeout: return buf
break if timeout is not None:
s_args = s_reply + (timeout-elapsed,) timeout = deadline - time()
if timeout < 0:
break
return self.read_very_lazy() return self.read_very_lazy()
def read_all(self): def read_all(self):
@ -577,29 +527,35 @@ class Telnet:
def sock_avail(self): def sock_avail(self):
"""Test whether data is available on the socket.""" """Test whether data is available on the socket."""
return select.select([self], [], [], 0) == ([self], [], []) with _TelnetSelector() as selector:
selector.register(self, selectors.EVENT_READ)
return bool(selector.select(0))
def interact(self): def interact(self):
"""Interaction function, emulates a very dumb telnet client.""" """Interaction function, emulates a very dumb telnet client."""
if sys.platform == "win32": if sys.platform == "win32":
self.mt_interact() self.mt_interact()
return return
while 1: with _TelnetSelector() as selector:
rfd, wfd, xfd = select.select([self, sys.stdin], [], []) selector.register(self, selectors.EVENT_READ)
if self in rfd: selector.register(sys.stdin, selectors.EVENT_READ)
try:
text = self.read_eager() while True:
except EOFError: for key, events in selector.select():
print('*** Connection closed by remote host ***') if key.fileobj is self:
break try:
if text: text = self.read_eager()
sys.stdout.write(text.decode('ascii')) except EOFError:
sys.stdout.flush() print('*** Connection closed by remote host ***')
if sys.stdin in rfd: return
line = sys.stdin.readline().encode('ascii') if text:
if not line: sys.stdout.write(text.decode('ascii'))
break sys.stdout.flush()
self.write(line) elif key.fileobj is sys.stdin:
line = sys.stdin.readline().encode('ascii')
if not line:
return
self.write(line)
def mt_interact(self): def mt_interact(self):
"""Multithreaded version of interact().""" """Multithreaded version of interact()."""
@ -645,79 +601,6 @@ class Telnet:
or if more than one expression can match the same input, the or if more than one expression can match the same input, the
results are undeterministic, and may depend on the I/O timing. results are undeterministic, and may depend on the I/O timing.
"""
if self._has_poll:
return self._expect_with_poll(list, timeout)
else:
return self._expect_with_select(list, timeout)
def _expect_with_poll(self, expect_list, timeout=None):
"""Read until one from a list of a regular expressions matches.
This method uses select.poll() to implement the timeout.
"""
re = None
expect_list = expect_list[:]
indices = range(len(expect_list))
for i in indices:
if not hasattr(expect_list[i], "search"):
if not re: import re
expect_list[i] = re.compile(expect_list[i])
call_timeout = timeout
if timeout is not None:
from time import time
time_start = time()
self.process_rawq()
m = None
for i in indices:
m = expect_list[i].search(self.cookedq)
if m:
e = m.end()
text = self.cookedq[:e]
self.cookedq = self.cookedq[e:]
break
if not m:
poller = select.poll()
poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
poller.register(self, poll_in_or_priority_flags)
while not m and not self.eof:
try:
ready = poller.poll(call_timeout)
except OSError as e:
if e.errno == errno.EINTR:
if timeout is not None:
elapsed = time() - time_start
call_timeout = timeout-elapsed
continue
raise
for fd, mode in ready:
if mode & poll_in_or_priority_flags:
self.fill_rawq()
self.process_rawq()
for i in indices:
m = expect_list[i].search(self.cookedq)
if m:
e = m.end()
text = self.cookedq[:e]
self.cookedq = self.cookedq[e:]
break
if timeout is not None:
elapsed = time() - time_start
if elapsed >= timeout:
break
call_timeout = timeout-elapsed
poller.unregister(self)
if m:
return (i, m, text)
text = self.read_very_lazy()
if not text and self.eof:
raise EOFError
return (-1, None, text)
def _expect_with_select(self, list, timeout=None):
"""Read until one from a list of a regular expressions matches.
The timeout is implemented using select.select().
""" """
re = None re = None
list = list[:] list = list[:]
@ -728,27 +611,27 @@ class Telnet:
list[i] = re.compile(list[i]) list[i] = re.compile(list[i])
if timeout is not None: if timeout is not None:
from time import time from time import time
time_start = time() deadline = time() + timeout
while 1: with _TelnetSelector() as selector:
self.process_rawq() selector.register(self, selectors.EVENT_READ)
for i in indices: while not self.eof:
m = list[i].search(self.cookedq) self.process_rawq()
if m: for i in indices:
e = m.end() m = list[i].search(self.cookedq)
text = self.cookedq[:e] if m:
self.cookedq = self.cookedq[e:] e = m.end()
return (i, m, text) text = self.cookedq[:e]
if self.eof: self.cookedq = self.cookedq[e:]
break return (i, m, text)
if timeout is not None: if timeout is not None:
elapsed = time() - time_start ready = selector.select(timeout)
if elapsed >= timeout: timeout = deadline - time()
break if not ready:
s_args = ([self.fileno()], [], [], timeout-elapsed) if timeout < 0:
r, w, x = select.select(*s_args) break
if not r: else:
break continue
self.fill_rawq() self.fill_rawq()
text = self.read_very_lazy() text = self.read_very_lazy()
if not text and self.eof: if not text and self.eof:
raise EOFError raise EOFError

View File

@ -1,10 +1,9 @@
import socket import socket
import select import selectors
import telnetlib import telnetlib
import time import time
import contextlib import contextlib
import unittest
from unittest import TestCase from unittest import TestCase
from test import support from test import support
threading = support.import_module('threading') threading = support.import_module('threading')
@ -112,40 +111,32 @@ class TelnetAlike(telnetlib.Telnet):
self._messages += out.getvalue() self._messages += out.getvalue()
return return
def mock_select(*s_args): class MockSelector(selectors.BaseSelector):
block = False
for l in s_args:
for fob in l:
if isinstance(fob, TelnetAlike):
block = fob.sock.block
if block:
return [[], [], []]
else:
return s_args
class MockPoller(object):
test_case = None # Set during TestCase setUp.
def __init__(self): def __init__(self):
self._file_objs = [] super().__init__()
self.keys = {}
def register(self, fd, eventmask): def register(self, fileobj, events, data=None):
self.test_case.assertTrue(hasattr(fd, 'fileno'), fd) key = selectors.SelectorKey(fileobj, 0, events, data)
self.test_case.assertEqual(eventmask, select.POLLIN|select.POLLPRI) self.keys[fileobj] = key
self._file_objs.append(fd) return key
def poll(self, timeout=None): def unregister(self, fileobj):
key = self.keys.pop(fileobj)
return key
def select(self, timeout=None):
block = False block = False
for fob in self._file_objs: for fileobj in self.keys:
if isinstance(fob, TelnetAlike): if isinstance(fileobj, TelnetAlike):
block = fob.sock.block block = fileobj.sock.block
break
if block: if block:
return [] return []
else: else:
return zip(self._file_objs, [select.POLLIN]*len(self._file_objs)) return [(key, key.events) for key in self.keys.values()]
def unregister(self, fd):
self._file_objs.remove(fd)
@contextlib.contextmanager @contextlib.contextmanager
def test_socket(reads): def test_socket(reads):
@ -159,7 +150,7 @@ def test_socket(reads):
socket.create_connection = old_conn socket.create_connection = old_conn
return return
def test_telnet(reads=(), cls=TelnetAlike, use_poll=None): def test_telnet(reads=(), cls=TelnetAlike):
''' return a telnetlib.Telnet object that uses a SocketStub with ''' return a telnetlib.Telnet object that uses a SocketStub with
reads queued up to be read ''' reads queued up to be read '''
for x in reads: for x in reads:
@ -167,29 +158,14 @@ def test_telnet(reads=(), cls=TelnetAlike, use_poll=None):
with test_socket(reads): with test_socket(reads):
telnet = cls('dummy', 0) telnet = cls('dummy', 0)
telnet._messages = '' # debuglevel output telnet._messages = '' # debuglevel output
if use_poll is not None:
if use_poll and not telnet._has_poll:
raise unittest.SkipTest('select.poll() required.')
telnet._has_poll = use_poll
return telnet return telnet
class ExpectAndReadTestCase(TestCase): class ExpectAndReadTestCase(TestCase):
def setUp(self): def setUp(self):
self.old_select = select.select self.old_selector = telnetlib._TelnetSelector
select.select = mock_select telnetlib._TelnetSelector = MockSelector
self.old_poll = False
if hasattr(select, 'poll'):
self.old_poll = select.poll
select.poll = MockPoller
MockPoller.test_case = self
def tearDown(self): def tearDown(self):
if self.old_poll: telnetlib._TelnetSelector = self.old_selector
MockPoller.test_case = None
select.poll = self.old_poll
select.select = self.old_select
class ReadTests(ExpectAndReadTestCase): class ReadTests(ExpectAndReadTestCase):
def test_read_until(self): def test_read_until(self):
@ -208,22 +184,6 @@ class ReadTests(ExpectAndReadTestCase):
data = telnet.read_until(b'match') data = telnet.read_until(b'match')
self.assertEqual(data, expect) self.assertEqual(data, expect)
def test_read_until_with_poll(self):
"""Use select.poll() to implement telnet.read_until()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=True)
select.select = lambda *_: self.fail('unexpected select() call.')
data = telnet.read_until(b'match')
self.assertEqual(data, b''.join(want[:-1]))
def test_read_until_with_select(self):
"""Use select.select() to implement telnet.read_until()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=False)
if self.old_poll:
select.poll = lambda *_: self.fail('unexpected poll() call.')
data = telnet.read_until(b'match')
self.assertEqual(data, b''.join(want[:-1]))
def test_read_all(self): def test_read_all(self):
""" """
@ -427,23 +387,6 @@ class ExpectTests(ExpectAndReadTestCase):
(_,_,data) = telnet.expect([b'match']) (_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1])) self.assertEqual(data, b''.join(want[:-1]))
def test_expect_with_poll(self):
"""Use select.poll() to implement telnet.expect()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=True)
select.select = lambda *_: self.fail('unexpected select() call.')
(_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1]))
def test_expect_with_select(self):
"""Use select.select() to implement telnet.expect()."""
want = [b'x' * 10, b'match', b'y' * 10]
telnet = test_telnet(want, use_poll=False)
if self.old_poll:
select.poll = lambda *_: self.fail('unexpected poll() call.')
(_,_,data) = telnet.expect([b'match'])
self.assertEqual(data, b''.join(want[:-1]))
def test_main(verbose=None): def test_main(verbose=None):
support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests, support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests,