Fixes Issue #14635: telnetlib will use poll() rather than select() when possible
to avoid failing due to the select() file descriptor limit.
This commit is contained in:
parent
4774946c3b
commit
dad5711677
130
Lib/telnetlib.py
130
Lib/telnetlib.py
|
@ -34,6 +34,7 @@ To do:
|
|||
|
||||
|
||||
# Imported modules
|
||||
import errno
|
||||
import sys
|
||||
import socket
|
||||
import select
|
||||
|
@ -205,6 +206,7 @@ class Telnet:
|
|||
self.sb = 0 # flag for SB and SE sequence.
|
||||
self.sbdataq = b''
|
||||
self.option_callback = None
|
||||
self._has_poll = hasattr(select, 'poll')
|
||||
if host is not None:
|
||||
self.open(host, port, timeout)
|
||||
|
||||
|
@ -286,6 +288,61 @@ class Telnet:
|
|||
possibly the empty string. Raise EOFError if the connection
|
||||
is closed and no cooked data is available.
|
||||
|
||||
"""
|
||||
if self._has_poll:
|
||||
return self._read_until_with_poll(match, timeout)
|
||||
else:
|
||||
return self._read_until_with_select(match, timeout)
|
||||
|
||||
def _read_until_with_poll(self, match, timeout):
|
||||
"""Read until a given string is encountered or until timeout.
|
||||
|
||||
This method uses select.poll() to implement the timeout.
|
||||
"""
|
||||
n = len(match)
|
||||
call_timeout = timeout
|
||||
if timeout is not None:
|
||||
from time import time
|
||||
time_start = time()
|
||||
self.process_rawq()
|
||||
i = self.cookedq.find(match)
|
||||
if i < 0:
|
||||
poller = select.poll()
|
||||
poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
|
||||
poller.register(self, poll_in_or_priority_flags)
|
||||
while i < 0 and not self.eof:
|
||||
try:
|
||||
ready = poller.poll(call_timeout)
|
||||
except select.error as e:
|
||||
if e.errno == errno.EINTR:
|
||||
if timeout is not None:
|
||||
elapsed = time() - time_start
|
||||
call_timeout = timeout-elapsed
|
||||
continue
|
||||
raise
|
||||
for fd, mode in ready:
|
||||
if mode & poll_in_or_priority_flags:
|
||||
i = max(0, len(self.cookedq)-n)
|
||||
self.fill_rawq()
|
||||
self.process_rawq()
|
||||
i = self.cookedq.find(match, i)
|
||||
if timeout is not None:
|
||||
elapsed = time() - time_start
|
||||
if elapsed >= timeout:
|
||||
break
|
||||
call_timeout = timeout-elapsed
|
||||
poller.unregister(self)
|
||||
if i >= 0:
|
||||
i = i + n
|
||||
buf = self.cookedq[:i]
|
||||
self.cookedq = self.cookedq[i:]
|
||||
return buf
|
||||
return self.read_very_lazy()
|
||||
|
||||
def _read_until_with_select(self, match, timeout=None):
|
||||
"""Read until a given string is encountered or until timeout.
|
||||
|
||||
The timeout is implemented using select.select().
|
||||
"""
|
||||
n = len(match)
|
||||
self.process_rawq()
|
||||
|
@ -588,6 +645,79 @@ class Telnet:
|
|||
or if more than one expression can match the same input, the
|
||||
results are undeterministic, and may depend on the I/O timing.
|
||||
|
||||
"""
|
||||
if self._has_poll:
|
||||
return self._expect_with_poll(list, timeout)
|
||||
else:
|
||||
return self._expect_with_select(list, timeout)
|
||||
|
||||
def _expect_with_poll(self, expect_list, timeout=None):
|
||||
"""Read until one from a list of a regular expressions matches.
|
||||
|
||||
This method uses select.poll() to implement the timeout.
|
||||
"""
|
||||
re = None
|
||||
expect_list = expect_list[:]
|
||||
indices = range(len(expect_list))
|
||||
for i in indices:
|
||||
if not hasattr(expect_list[i], "search"):
|
||||
if not re: import re
|
||||
expect_list[i] = re.compile(expect_list[i])
|
||||
call_timeout = timeout
|
||||
if timeout is not None:
|
||||
from time import time
|
||||
time_start = time()
|
||||
self.process_rawq()
|
||||
m = None
|
||||
for i in indices:
|
||||
m = expect_list[i].search(self.cookedq)
|
||||
if m:
|
||||
e = m.end()
|
||||
text = self.cookedq[:e]
|
||||
self.cookedq = self.cookedq[e:]
|
||||
break
|
||||
if not m:
|
||||
poller = select.poll()
|
||||
poll_in_or_priority_flags = select.POLLIN | select.POLLPRI
|
||||
poller.register(self, poll_in_or_priority_flags)
|
||||
while not m and not self.eof:
|
||||
try:
|
||||
ready = poller.poll(call_timeout)
|
||||
except select.error as e:
|
||||
if e.errno == errno.EINTR:
|
||||
if timeout is not None:
|
||||
elapsed = time() - time_start
|
||||
call_timeout = timeout-elapsed
|
||||
continue
|
||||
raise
|
||||
for fd, mode in ready:
|
||||
if mode & poll_in_or_priority_flags:
|
||||
self.fill_rawq()
|
||||
self.process_rawq()
|
||||
for i in indices:
|
||||
m = expect_list[i].search(self.cookedq)
|
||||
if m:
|
||||
e = m.end()
|
||||
text = self.cookedq[:e]
|
||||
self.cookedq = self.cookedq[e:]
|
||||
break
|
||||
if timeout is not None:
|
||||
elapsed = time() - time_start
|
||||
if elapsed >= timeout:
|
||||
break
|
||||
call_timeout = timeout-elapsed
|
||||
poller.unregister(self)
|
||||
if m:
|
||||
return (i, m, text)
|
||||
text = self.read_very_lazy()
|
||||
if not text and self.eof:
|
||||
raise EOFError
|
||||
return (-1, None, text)
|
||||
|
||||
def _expect_with_select(self, list, timeout=None):
|
||||
"""Read until one from a list of a regular expressions matches.
|
||||
|
||||
The timeout is implemented using select.select().
|
||||
"""
|
||||
re = None
|
||||
list = list[:]
|
||||
|
|
|
@ -75,8 +75,8 @@ class GeneralTests(TestCase):
|
|||
|
||||
class SocketStub(object):
|
||||
''' a socket proxy that re-defines sendall() '''
|
||||
def __init__(self, reads=[]):
|
||||
self.reads = reads
|
||||
def __init__(self, reads=()):
|
||||
self.reads = list(reads) # Intentionally make a copy.
|
||||
self.writes = []
|
||||
self.block = False
|
||||
def sendall(self, data):
|
||||
|
@ -102,7 +102,7 @@ class TelnetAlike(telnetlib.Telnet):
|
|||
self._messages += out.getvalue()
|
||||
return
|
||||
|
||||
def new_select(*s_args):
|
||||
def mock_select(*s_args):
|
||||
block = False
|
||||
for l in s_args:
|
||||
for fob in l:
|
||||
|
@ -113,6 +113,30 @@ def new_select(*s_args):
|
|||
else:
|
||||
return s_args
|
||||
|
||||
class MockPoller(object):
|
||||
test_case = None # Set during TestCase setUp.
|
||||
|
||||
def __init__(self):
|
||||
self._file_objs = []
|
||||
|
||||
def register(self, fd, eventmask):
|
||||
self.test_case.assertTrue(hasattr(fd, 'fileno'), fd)
|
||||
self.test_case.assertEqual(eventmask, select.POLLIN|select.POLLPRI)
|
||||
self._file_objs.append(fd)
|
||||
|
||||
def poll(self, timeout=None):
|
||||
block = False
|
||||
for fob in self._file_objs:
|
||||
if isinstance(fob, TelnetAlike):
|
||||
block = fob.sock.block
|
||||
if block:
|
||||
return []
|
||||
else:
|
||||
return zip(self._file_objs, [select.POLLIN]*len(self._file_objs))
|
||||
|
||||
def unregister(self, fd):
|
||||
self._file_objs.remove(fd)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def test_socket(reads):
|
||||
def new_conn(*ignored):
|
||||
|
@ -125,7 +149,7 @@ def test_socket(reads):
|
|||
socket.create_connection = old_conn
|
||||
return
|
||||
|
||||
def test_telnet(reads=[], cls=TelnetAlike):
|
||||
def test_telnet(reads=(), cls=TelnetAlike, use_poll=None):
|
||||
''' return a telnetlib.Telnet object that uses a SocketStub with
|
||||
reads queued up to be read '''
|
||||
for x in reads:
|
||||
|
@ -133,15 +157,28 @@ def test_telnet(reads=[], cls=TelnetAlike):
|
|||
with test_socket(reads):
|
||||
telnet = cls('dummy', 0)
|
||||
telnet._messages = '' # debuglevel output
|
||||
if use_poll is not None:
|
||||
if use_poll and not telnet._has_poll:
|
||||
raise unittest.SkipTest('select.poll() required.')
|
||||
telnet._has_poll = use_poll
|
||||
return telnet
|
||||
|
||||
class ReadTests(TestCase):
|
||||
|
||||
class ExpectAndReadTestCase(TestCase):
|
||||
def setUp(self):
|
||||
self.old_select = select.select
|
||||
select.select = new_select
|
||||
self.old_poll = select.poll
|
||||
select.select = mock_select
|
||||
select.poll = MockPoller
|
||||
MockPoller.test_case = self
|
||||
|
||||
def tearDown(self):
|
||||
MockPoller.test_case = None
|
||||
select.poll = self.old_poll
|
||||
select.select = self.old_select
|
||||
|
||||
|
||||
class ReadTests(ExpectAndReadTestCase):
|
||||
def test_read_until(self):
|
||||
"""
|
||||
read_until(expected, timeout=None)
|
||||
|
@ -158,6 +195,21 @@ class ReadTests(TestCase):
|
|||
data = telnet.read_until(b'match')
|
||||
self.assertEqual(data, expect)
|
||||
|
||||
def test_read_until_with_poll(self):
|
||||
"""Use select.poll() to implement telnet.read_until()."""
|
||||
want = [b'x' * 10, b'match', b'y' * 10]
|
||||
telnet = test_telnet(want, use_poll=True)
|
||||
select.select = lambda *_: self.fail('unexpected select() call.')
|
||||
data = telnet.read_until(b'match')
|
||||
self.assertEqual(data, b''.join(want[:-1]))
|
||||
|
||||
def test_read_until_with_select(self):
|
||||
"""Use select.select() to implement telnet.read_until()."""
|
||||
want = [b'x' * 10, b'match', b'y' * 10]
|
||||
telnet = test_telnet(want, use_poll=False)
|
||||
select.poll = lambda *_: self.fail('unexpected poll() call.')
|
||||
data = telnet.read_until(b'match')
|
||||
self.assertEqual(data, b''.join(want[:-1]))
|
||||
|
||||
def test_read_all(self):
|
||||
"""
|
||||
|
@ -349,8 +401,38 @@ class OptionTests(TestCase):
|
|||
self.assertRegex(telnet._messages, r'0.*test')
|
||||
|
||||
|
||||
class ExpectTests(ExpectAndReadTestCase):
|
||||
def test_expect(self):
|
||||
"""
|
||||
expect(expected, [timeout])
|
||||
Read until the expected string has been seen, or a timeout is
|
||||
hit (default is no timeout); may block.
|
||||
"""
|
||||
want = [b'x' * 10, b'match', b'y' * 10]
|
||||
telnet = test_telnet(want)
|
||||
(_,_,data) = telnet.expect([b'match'])
|
||||
self.assertEqual(data, b''.join(want[:-1]))
|
||||
|
||||
def test_expect_with_poll(self):
|
||||
"""Use select.poll() to implement telnet.expect()."""
|
||||
want = [b'x' * 10, b'match', b'y' * 10]
|
||||
telnet = test_telnet(want, use_poll=True)
|
||||
select.select = lambda *_: self.fail('unexpected select() call.')
|
||||
(_,_,data) = telnet.expect([b'match'])
|
||||
self.assertEqual(data, b''.join(want[:-1]))
|
||||
|
||||
def test_expect_with_select(self):
|
||||
"""Use select.select() to implement telnet.expect()."""
|
||||
want = [b'x' * 10, b'match', b'y' * 10]
|
||||
telnet = test_telnet(want, use_poll=False)
|
||||
select.poll = lambda *_: self.fail('unexpected poll() call.')
|
||||
(_,_,data) = telnet.expect([b'match'])
|
||||
self.assertEqual(data, b''.join(want[:-1]))
|
||||
|
||||
|
||||
def test_main(verbose=None):
|
||||
support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests)
|
||||
support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests,
|
||||
ExpectTests)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_main()
|
||||
|
|
|
@ -410,6 +410,7 @@ Chris Hoffman
|
|||
Albert Hofkamp
|
||||
Tomas Hoger
|
||||
Jonathan Hogg
|
||||
Akintayo Holder
|
||||
Gerrit Holl
|
||||
Shane Holloway
|
||||
Rune Holm
|
||||
|
|
|
@ -87,6 +87,9 @@ Core and Builtins
|
|||
Library
|
||||
-------
|
||||
|
||||
- Issue #14635: telnetlib will use poll() rather than select() when possible
|
||||
to avoid failing due to the select() file descriptor limit.
|
||||
|
||||
- Issue #15180: Clarify posixpath.join() error message when mixing str & bytes
|
||||
|
||||
- Issue #15230: runpy.run_path now correctly sets __package__ as described
|
||||
|
|
Loading…
Reference in New Issue