Write flow control for asyncio (includes asyncio.streams overhaul).
This commit is contained in:
parent
051a331488
commit
355491dc47
|
@ -29,6 +29,34 @@ class BaseProtocol:
|
|||
aborted or closed).
|
||||
"""
|
||||
|
||||
def pause_writing(self):
|
||||
"""Called when the transport's buffer goes over the high-water mark.
|
||||
|
||||
Pause and resume calls are paired -- pause_writing() is called
|
||||
once when the buffer goes strictly over the high-water mark
|
||||
(even if subsequent writes increases the buffer size even
|
||||
more), and eventually resume_writing() is called once when the
|
||||
buffer size reaches the low-water mark.
|
||||
|
||||
Note that if the buffer size equals the high-water mark,
|
||||
pause_writing() is not called -- it must go strictly over.
|
||||
Conversely, resume_writing() is called when the buffer size is
|
||||
equal or lower than the low-water mark. These end conditions
|
||||
are important to ensure that things go as expected when either
|
||||
mark is zero.
|
||||
|
||||
NOTE: This is the only Protocol callback that is not called
|
||||
through EventLoop.call_soon() -- if it were, it would have no
|
||||
effect when it's most needed (when the app keeps writing
|
||||
without yielding until pause_writing() is called).
|
||||
"""
|
||||
|
||||
def resume_writing(self):
|
||||
"""Called when the transport's buffer drains below the low-water mark.
|
||||
|
||||
See pause_writing() for details.
|
||||
"""
|
||||
|
||||
|
||||
class Protocol(BaseProtocol):
|
||||
"""ABC representing a protocol.
|
||||
|
|
|
@ -346,8 +346,10 @@ class _SelectorTransport(transports.Transport):
|
|||
self._buffer = collections.deque()
|
||||
self._conn_lost = 0 # Set when call to connection_lost scheduled.
|
||||
self._closing = False # Set when close() called.
|
||||
if server is not None:
|
||||
server.attach(self)
|
||||
self._protocol_paused = False
|
||||
self.set_write_buffer_limits()
|
||||
if self._server is not None:
|
||||
self._server.attach(self)
|
||||
|
||||
def abort(self):
|
||||
self._force_close(None)
|
||||
|
@ -392,6 +394,40 @@ class _SelectorTransport(transports.Transport):
|
|||
server.detach(self)
|
||||
self._server = None
|
||||
|
||||
def _maybe_pause_protocol(self):
|
||||
size = self.get_write_buffer_size()
|
||||
if size <= self._high_water:
|
||||
return
|
||||
if not self._protocol_paused:
|
||||
self._protocol_paused = True
|
||||
try:
|
||||
self._protocol.pause_writing()
|
||||
except Exception:
|
||||
tulip_log.exception('pause_writing() failed')
|
||||
|
||||
def _maybe_resume_protocol(self):
|
||||
if self._protocol_paused and self.get_write_buffer_size() <= self._low_water:
|
||||
self._protocol_paused = False
|
||||
try:
|
||||
self._protocol.resume_writing()
|
||||
except Exception:
|
||||
tulip_log.exception('resume_writing() failed')
|
||||
|
||||
def set_write_buffer_limits(self, high=None, low=None):
|
||||
if high is None:
|
||||
if low is None:
|
||||
high = 64*1024
|
||||
else:
|
||||
high = 4*low
|
||||
if low is None:
|
||||
low = high // 4
|
||||
assert 0 <= low <= high, repr((low, high))
|
||||
self._high_water = high
|
||||
self._low_water = low
|
||||
|
||||
def get_write_buffer_size(self):
|
||||
return sum(len(data) for data in self._buffer)
|
||||
|
||||
|
||||
class _SelectorSocketTransport(_SelectorTransport):
|
||||
|
||||
|
@ -447,7 +483,7 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
return
|
||||
|
||||
if not self._buffer:
|
||||
# Attempt to send it right away first.
|
||||
# Optimization: try to send now.
|
||||
try:
|
||||
n = self._sock.send(data)
|
||||
except (BlockingIOError, InterruptedError):
|
||||
|
@ -459,34 +495,36 @@ class _SelectorSocketTransport(_SelectorTransport):
|
|||
data = data[n:]
|
||||
if not data:
|
||||
return
|
||||
|
||||
# Start async I/O.
|
||||
# Not all was written; register write handler.
|
||||
self._loop.add_writer(self._sock_fd, self._write_ready)
|
||||
|
||||
# Add it to the buffer.
|
||||
self._buffer.append(data)
|
||||
self._maybe_pause_protocol()
|
||||
|
||||
def _write_ready(self):
|
||||
data = b''.join(self._buffer)
|
||||
assert data, 'Data should not be empty'
|
||||
|
||||
self._buffer.clear()
|
||||
self._buffer.clear() # Optimistically; may have to put it back later.
|
||||
try:
|
||||
n = self._sock.send(data)
|
||||
except (BlockingIOError, InterruptedError):
|
||||
self._buffer.append(data)
|
||||
self._buffer.append(data) # Still need to write this.
|
||||
except Exception as exc:
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
self._fatal_error(exc)
|
||||
else:
|
||||
data = data[n:]
|
||||
if not data:
|
||||
if data:
|
||||
self._buffer.append(data) # Still need to write this.
|
||||
self._maybe_resume_protocol() # May append to buffer.
|
||||
if not self._buffer:
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
if self._closing:
|
||||
self._call_connection_lost(None)
|
||||
elif self._eof:
|
||||
self._sock.shutdown(socket.SHUT_WR)
|
||||
return
|
||||
self._buffer.append(data) # Try again later.
|
||||
|
||||
def write_eof(self):
|
||||
if self._eof:
|
||||
|
@ -546,16 +584,23 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
self._loop.add_writer(self._sock_fd, self._on_handshake)
|
||||
return
|
||||
except Exception as exc:
|
||||
self._loop.remove_reader(self._sock_fd)
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
self._sock.close()
|
||||
if self._waiter is not None:
|
||||
self._waiter.set_exception(exc)
|
||||
return
|
||||
except BaseException as exc:
|
||||
self._loop.remove_reader(self._sock_fd)
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
self._sock.close()
|
||||
if self._waiter is not None:
|
||||
self._waiter.set_exception(exc)
|
||||
raise
|
||||
|
||||
self._loop.remove_reader(self._sock_fd)
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
|
||||
# Verify hostname if requested.
|
||||
peercert = self._sock.getpeercert()
|
||||
if (self._server_hostname is not None and
|
||||
|
@ -574,8 +619,6 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
compression=self._sock.compression(),
|
||||
)
|
||||
|
||||
self._loop.remove_reader(self._sock_fd)
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
self._loop.add_reader(self._sock_fd, self._on_ready)
|
||||
self._loop.add_writer(self._sock_fd, self._on_ready)
|
||||
self._loop.call_soon(self._protocol.connection_made, self)
|
||||
|
@ -642,6 +685,8 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
if n < len(data):
|
||||
self._buffer.append(data[n:])
|
||||
|
||||
self._maybe_resume_protocol() # May append to buffer.
|
||||
|
||||
if self._closing and not self._buffer:
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
self._call_connection_lost(None)
|
||||
|
@ -657,8 +702,9 @@ class _SelectorSslTransport(_SelectorTransport):
|
|||
self._conn_lost += 1
|
||||
return
|
||||
|
||||
self._buffer.append(data)
|
||||
# We could optimize, but the callback can do this for now.
|
||||
self._buffer.append(data)
|
||||
self._maybe_pause_protocol()
|
||||
|
||||
def can_write_eof(self):
|
||||
return False
|
||||
|
@ -675,11 +721,13 @@ class _SelectorDatagramTransport(_SelectorTransport):
|
|||
|
||||
def __init__(self, loop, sock, protocol, address=None, extra=None):
|
||||
super().__init__(loop, sock, protocol, extra)
|
||||
|
||||
self._address = address
|
||||
self._loop.add_reader(self._sock_fd, self._read_ready)
|
||||
self._loop.call_soon(self._protocol.connection_made, self)
|
||||
|
||||
def get_write_buffer_size(self):
|
||||
return sum(len(data) for data, _ in self._buffer)
|
||||
|
||||
def _read_ready(self):
|
||||
try:
|
||||
data, addr = self._sock.recvfrom(self.max_size)
|
||||
|
@ -723,6 +771,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
|
|||
return
|
||||
|
||||
self._buffer.append((data, addr))
|
||||
self._maybe_pause_protocol()
|
||||
|
||||
def _sendto_ready(self):
|
||||
while self._buffer:
|
||||
|
@ -743,6 +792,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
|
|||
self._fatal_error(exc)
|
||||
return
|
||||
|
||||
self._maybe_resume_protocol() # May append to buffer.
|
||||
if not self._buffer:
|
||||
self._loop.remove_writer(self._sock_fd)
|
||||
if self._closing:
|
||||
|
|
|
@ -39,7 +39,8 @@ def open_connection(host=None, port=None, *,
|
|||
protocol = StreamReaderProtocol(reader)
|
||||
transport, _ = yield from loop.create_connection(
|
||||
lambda: protocol, host, port, **kwds)
|
||||
return reader, transport # (reader, writer)
|
||||
writer = StreamWriter(transport, protocol, reader, loop)
|
||||
return reader, writer
|
||||
|
||||
|
||||
class StreamReaderProtocol(protocols.Protocol):
|
||||
|
@ -52,22 +53,113 @@ class StreamReaderProtocol(protocols.Protocol):
|
|||
"""
|
||||
|
||||
def __init__(self, stream_reader):
|
||||
self.stream_reader = stream_reader
|
||||
self._stream_reader = stream_reader
|
||||
self._drain_waiter = None
|
||||
self._paused = False
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.stream_reader.set_transport(transport)
|
||||
self._stream_reader.set_transport(transport)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if exc is None:
|
||||
self.stream_reader.feed_eof()
|
||||
self._stream_reader.feed_eof()
|
||||
else:
|
||||
self.stream_reader.set_exception(exc)
|
||||
self._stream_reader.set_exception(exc)
|
||||
# Also wake up the writing side.
|
||||
if self._paused:
|
||||
waiter = self._drain_waiter
|
||||
if waiter is not None:
|
||||
self._drain_waiter = None
|
||||
if not waiter.done():
|
||||
if exc is None:
|
||||
waiter.set_result(None)
|
||||
else:
|
||||
waiter.set_exception(exc)
|
||||
|
||||
def data_received(self, data):
|
||||
self.stream_reader.feed_data(data)
|
||||
self._stream_reader.feed_data(data)
|
||||
|
||||
def eof_received(self):
|
||||
self.stream_reader.feed_eof()
|
||||
self._stream_reader.feed_eof()
|
||||
|
||||
def pause_writing(self):
|
||||
assert not self._paused
|
||||
self._paused = True
|
||||
|
||||
def resume_writing(self):
|
||||
assert self._paused
|
||||
self._paused = False
|
||||
waiter = self._drain_waiter
|
||||
if waiter is not None:
|
||||
self._drain_waiter = None
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
|
||||
|
||||
class StreamWriter:
|
||||
"""Wraps a Transport.
|
||||
|
||||
This exposes write(), writelines(), [can_]write_eof(),
|
||||
get_extra_info() and close(). It adds drain() which returns an
|
||||
optional Future on which you can wait for flow control. It also
|
||||
adds a transport attribute which references the Transport
|
||||
directly.
|
||||
"""
|
||||
|
||||
def __init__(self, transport, protocol, reader, loop):
|
||||
self._transport = transport
|
||||
self._protocol = protocol
|
||||
self._reader = reader
|
||||
self._loop = loop
|
||||
|
||||
@property
|
||||
def transport(self):
|
||||
return self._transport
|
||||
|
||||
def write(self, data):
|
||||
self._transport.write(data)
|
||||
|
||||
def writelines(self, data):
|
||||
self._transport.writelines(data)
|
||||
|
||||
def write_eof(self):
|
||||
return self._transport.write_eof()
|
||||
|
||||
def can_write_eof(self):
|
||||
return self._transport.can_write_eof()
|
||||
|
||||
def close(self):
|
||||
return self._transport.close()
|
||||
|
||||
def get_extra_info(self, name, default=None):
|
||||
return self._transport.get_extra_info(name, default)
|
||||
|
||||
def drain(self):
|
||||
"""This method has an unusual return value.
|
||||
|
||||
The intended use is to write
|
||||
|
||||
w.write(data)
|
||||
yield from w.drain()
|
||||
|
||||
When there's nothing to wait for, drain() returns (), and the
|
||||
yield-from continues immediately. When the transport buffer
|
||||
is full (the protocol is paused), drain() creates and returns
|
||||
a Future and the yield-from will block until that Future is
|
||||
completed, which will happen when the buffer is (partially)
|
||||
drained and the protocol is resumed.
|
||||
"""
|
||||
if self._reader._exception is not None:
|
||||
raise self._writer._exception
|
||||
if self._transport._conn_lost: # Uses private variable.
|
||||
raise ConnectionResetError('Connection lost')
|
||||
if not self._protocol._paused:
|
||||
return ()
|
||||
waiter = self._protocol._drain_waiter
|
||||
assert waiter is None or waiter.cancelled()
|
||||
waiter = futures.Future(loop=self._loop)
|
||||
self._protocol._drain_waiter = waiter
|
||||
return waiter
|
||||
|
||||
|
||||
class StreamReader:
|
||||
|
@ -75,14 +167,14 @@ class StreamReader:
|
|||
def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
|
||||
# The line length limit is a security feature;
|
||||
# it also doubles as half the buffer limit.
|
||||
self.limit = limit
|
||||
self._limit = limit
|
||||
if loop is None:
|
||||
loop = events.get_event_loop()
|
||||
self.loop = loop
|
||||
self.buffer = collections.deque() # Deque of bytes objects.
|
||||
self.byte_count = 0 # Bytes in buffer.
|
||||
self.eof = False # Whether we're done.
|
||||
self.waiter = None # A future.
|
||||
self._loop = loop
|
||||
self._buffer = collections.deque() # Deque of bytes objects.
|
||||
self._byte_count = 0 # Bytes in buffer.
|
||||
self._eof = False # Whether we're done.
|
||||
self._waiter = None # A future.
|
||||
self._exception = None
|
||||
self._transport = None
|
||||
self._paused = False
|
||||
|
@ -93,9 +185,9 @@ class StreamReader:
|
|||
def set_exception(self, exc):
|
||||
self._exception = exc
|
||||
|
||||
waiter = self.waiter
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_exception(exc)
|
||||
|
||||
|
@ -104,15 +196,15 @@ class StreamReader:
|
|||
self._transport = transport
|
||||
|
||||
def _maybe_resume_transport(self):
|
||||
if self._paused and self.byte_count <= self.limit:
|
||||
if self._paused and self._byte_count <= self._limit:
|
||||
self._paused = False
|
||||
self._transport.resume_reading()
|
||||
|
||||
def feed_eof(self):
|
||||
self.eof = True
|
||||
waiter = self.waiter
|
||||
self._eof = True
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_result(True)
|
||||
|
||||
|
@ -120,18 +212,18 @@ class StreamReader:
|
|||
if not data:
|
||||
return
|
||||
|
||||
self.buffer.append(data)
|
||||
self.byte_count += len(data)
|
||||
self._buffer.append(data)
|
||||
self._byte_count += len(data)
|
||||
|
||||
waiter = self.waiter
|
||||
waiter = self._waiter
|
||||
if waiter is not None:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
if not waiter.cancelled():
|
||||
waiter.set_result(False)
|
||||
|
||||
if (self._transport is not None and
|
||||
not self._paused and
|
||||
self.byte_count > 2*self.limit):
|
||||
self._byte_count > 2*self._limit):
|
||||
try:
|
||||
self._transport.pause_reading()
|
||||
except NotImplementedError:
|
||||
|
@ -152,8 +244,8 @@ class StreamReader:
|
|||
not_enough = True
|
||||
|
||||
while not_enough:
|
||||
while self.buffer and not_enough:
|
||||
data = self.buffer.popleft()
|
||||
while self._buffer and not_enough:
|
||||
data = self._buffer.popleft()
|
||||
ichar = data.find(b'\n')
|
||||
if ichar < 0:
|
||||
parts.append(data)
|
||||
|
@ -162,29 +254,29 @@ class StreamReader:
|
|||
ichar += 1
|
||||
head, tail = data[:ichar], data[ichar:]
|
||||
if tail:
|
||||
self.buffer.appendleft(tail)
|
||||
self._buffer.appendleft(tail)
|
||||
not_enough = False
|
||||
parts.append(head)
|
||||
parts_size += len(head)
|
||||
|
||||
if parts_size > self.limit:
|
||||
self.byte_count -= parts_size
|
||||
if parts_size > self._limit:
|
||||
self._byte_count -= parts_size
|
||||
self._maybe_resume_transport()
|
||||
raise ValueError('Line is too long')
|
||||
|
||||
if self.eof:
|
||||
if self._eof:
|
||||
break
|
||||
|
||||
if not_enough:
|
||||
assert self.waiter is None
|
||||
self.waiter = futures.Future(loop=self.loop)
|
||||
assert self._waiter is None
|
||||
self._waiter = futures.Future(loop=self._loop)
|
||||
try:
|
||||
yield from self.waiter
|
||||
yield from self._waiter
|
||||
finally:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
|
||||
line = b''.join(parts)
|
||||
self.byte_count -= parts_size
|
||||
self._byte_count -= parts_size
|
||||
self._maybe_resume_transport()
|
||||
|
||||
return line
|
||||
|
@ -198,42 +290,42 @@ class StreamReader:
|
|||
return b''
|
||||
|
||||
if n < 0:
|
||||
while not self.eof:
|
||||
assert not self.waiter
|
||||
self.waiter = futures.Future(loop=self.loop)
|
||||
while not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = futures.Future(loop=self._loop)
|
||||
try:
|
||||
yield from self.waiter
|
||||
yield from self._waiter
|
||||
finally:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
else:
|
||||
if not self.byte_count and not self.eof:
|
||||
assert not self.waiter
|
||||
self.waiter = futures.Future(loop=self.loop)
|
||||
if not self._byte_count and not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = futures.Future(loop=self._loop)
|
||||
try:
|
||||
yield from self.waiter
|
||||
yield from self._waiter
|
||||
finally:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
|
||||
if n < 0 or self.byte_count <= n:
|
||||
data = b''.join(self.buffer)
|
||||
self.buffer.clear()
|
||||
self.byte_count = 0
|
||||
if n < 0 or self._byte_count <= n:
|
||||
data = b''.join(self._buffer)
|
||||
self._buffer.clear()
|
||||
self._byte_count = 0
|
||||
self._maybe_resume_transport()
|
||||
return data
|
||||
|
||||
parts = []
|
||||
parts_bytes = 0
|
||||
while self.buffer and parts_bytes < n:
|
||||
data = self.buffer.popleft()
|
||||
while self._buffer and parts_bytes < n:
|
||||
data = self._buffer.popleft()
|
||||
data_bytes = len(data)
|
||||
if n < parts_bytes + data_bytes:
|
||||
data_bytes = n - parts_bytes
|
||||
data, rest = data[:data_bytes], data[data_bytes:]
|
||||
self.buffer.appendleft(rest)
|
||||
self._buffer.appendleft(rest)
|
||||
|
||||
parts.append(data)
|
||||
parts_bytes += data_bytes
|
||||
self.byte_count -= data_bytes
|
||||
self._byte_count -= data_bytes
|
||||
self._maybe_resume_transport()
|
||||
|
||||
return b''.join(parts)
|
||||
|
@ -246,12 +338,12 @@ class StreamReader:
|
|||
if n <= 0:
|
||||
return b''
|
||||
|
||||
while self.byte_count < n and not self.eof:
|
||||
assert not self.waiter
|
||||
self.waiter = futures.Future(loop=self.loop)
|
||||
while self._byte_count < n and not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = futures.Future(loop=self._loop)
|
||||
try:
|
||||
yield from self.waiter
|
||||
yield from self._waiter
|
||||
finally:
|
||||
self.waiter = None
|
||||
self._waiter = None
|
||||
|
||||
return (yield from self.read(n))
|
||||
|
|
|
@ -49,6 +49,31 @@ class ReadTransport(BaseTransport):
|
|||
class WriteTransport(BaseTransport):
|
||||
"""ABC for write-only transports."""
|
||||
|
||||
def set_write_buffer_limits(self, high=None, low=None):
|
||||
"""Set the high- and low-water limits for write flow control.
|
||||
|
||||
These two values control when to call the protocol's
|
||||
pause_writing() and resume_writing() methods. If specified,
|
||||
the low-water limit must be less than or equal to the
|
||||
high-water limit. Neither value can be negative.
|
||||
|
||||
The defaults are implementation-specific. If only the
|
||||
high-water limit is given, the low-water limit defaults to a
|
||||
implementation-specific value less than or equal to the
|
||||
high-water limit. Setting high to zero forces low to zero as
|
||||
well, and causes pause_writing() to be called whenever the
|
||||
buffer becomes non-empty. Setting low to zero causes
|
||||
resume_writing() to be called only once the buffer is empty.
|
||||
Use of zero for either limit is generally sub-optimal as it
|
||||
reduces opportunities for doing I/O and computation
|
||||
concurrently.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_write_buffer_size(self):
|
||||
"""Return the current size of the write buffer."""
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, data):
|
||||
"""Write some data bytes to the transport.
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
@unittest.mock.patch('asyncio.streams.events')
|
||||
def test_ctor_global_loop(self, m_events):
|
||||
stream = streams.StreamReader()
|
||||
self.assertIs(stream.loop, m_events.get_event_loop.return_value)
|
||||
self.assertIs(stream._loop, m_events.get_event_loop.return_value)
|
||||
|
||||
def test_open_connection(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
|
@ -81,13 +81,13 @@ class StreamReaderTests(unittest.TestCase):
|
|||
stream = streams.StreamReader(loop=self.loop)
|
||||
|
||||
stream.feed_data(b'')
|
||||
self.assertEqual(0, stream.byte_count)
|
||||
self.assertEqual(0, stream._byte_count)
|
||||
|
||||
def test_feed_data_byte_count(self):
|
||||
stream = streams.StreamReader(loop=self.loop)
|
||||
|
||||
stream.feed_data(self.DATA)
|
||||
self.assertEqual(len(self.DATA), stream.byte_count)
|
||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
||||
|
||||
def test_read_zero(self):
|
||||
# Read zero bytes.
|
||||
|
@ -96,7 +96,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
data = self.loop.run_until_complete(stream.read(0))
|
||||
self.assertEqual(b'', data)
|
||||
self.assertEqual(len(self.DATA), stream.byte_count)
|
||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
||||
|
||||
def test_read(self):
|
||||
# Read bytes.
|
||||
|
@ -109,7 +109,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
data = self.loop.run_until_complete(read_task)
|
||||
self.assertEqual(self.DATA, data)
|
||||
self.assertFalse(stream.byte_count)
|
||||
self.assertFalse(stream._byte_count)
|
||||
|
||||
def test_read_line_breaks(self):
|
||||
# Read bytes without line breaks.
|
||||
|
@ -120,7 +120,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
data = self.loop.run_until_complete(stream.read(5))
|
||||
|
||||
self.assertEqual(b'line1', data)
|
||||
self.assertEqual(5, stream.byte_count)
|
||||
self.assertEqual(5, stream._byte_count)
|
||||
|
||||
def test_read_eof(self):
|
||||
# Read bytes, stop at eof.
|
||||
|
@ -133,7 +133,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
data = self.loop.run_until_complete(read_task)
|
||||
self.assertEqual(b'', data)
|
||||
self.assertFalse(stream.byte_count)
|
||||
self.assertFalse(stream._byte_count)
|
||||
|
||||
def test_read_until_eof(self):
|
||||
# Read all bytes until eof.
|
||||
|
@ -149,7 +149,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
data = self.loop.run_until_complete(read_task)
|
||||
|
||||
self.assertEqual(b'chunk1\nchunk2', data)
|
||||
self.assertFalse(stream.byte_count)
|
||||
self.assertFalse(stream._byte_count)
|
||||
|
||||
def test_read_exception(self):
|
||||
stream = streams.StreamReader(loop=self.loop)
|
||||
|
@ -176,7 +176,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
line = self.loop.run_until_complete(read_task)
|
||||
self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
|
||||
self.assertEqual(len(b'\n chunk4')-1, stream.byte_count)
|
||||
self.assertEqual(len(b'\n chunk4')-1, stream._byte_count)
|
||||
|
||||
def test_readline_limit_with_existing_data(self):
|
||||
stream = streams.StreamReader(3, loop=self.loop)
|
||||
|
@ -185,7 +185,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
self.assertRaises(
|
||||
ValueError, self.loop.run_until_complete, stream.readline())
|
||||
self.assertEqual([b'line2\n'], list(stream.buffer))
|
||||
self.assertEqual([b'line2\n'], list(stream._buffer))
|
||||
|
||||
stream = streams.StreamReader(3, loop=self.loop)
|
||||
stream.feed_data(b'li')
|
||||
|
@ -194,8 +194,8 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
self.assertRaises(
|
||||
ValueError, self.loop.run_until_complete, stream.readline())
|
||||
self.assertEqual([b'li'], list(stream.buffer))
|
||||
self.assertEqual(2, stream.byte_count)
|
||||
self.assertEqual([b'li'], list(stream._buffer))
|
||||
self.assertEqual(2, stream._byte_count)
|
||||
|
||||
def test_readline_limit(self):
|
||||
stream = streams.StreamReader(7, loop=self.loop)
|
||||
|
@ -209,8 +209,8 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
self.assertRaises(
|
||||
ValueError, self.loop.run_until_complete, stream.readline())
|
||||
self.assertEqual([b'chunk3\n'], list(stream.buffer))
|
||||
self.assertEqual(7, stream.byte_count)
|
||||
self.assertEqual([b'chunk3\n'], list(stream._buffer))
|
||||
self.assertEqual(7, stream._byte_count)
|
||||
|
||||
def test_readline_line_byte_count(self):
|
||||
stream = streams.StreamReader(loop=self.loop)
|
||||
|
@ -220,7 +220,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
line = self.loop.run_until_complete(stream.readline())
|
||||
|
||||
self.assertEqual(b'line1\n', line)
|
||||
self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count)
|
||||
self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count)
|
||||
|
||||
def test_readline_eof(self):
|
||||
stream = streams.StreamReader(loop=self.loop)
|
||||
|
@ -248,7 +248,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
self.assertEqual(b'line2\nl', data)
|
||||
self.assertEqual(
|
||||
len(self.DATA) - len(b'line1\n') - len(b'line2\nl'),
|
||||
stream.byte_count)
|
||||
stream._byte_count)
|
||||
|
||||
def test_readline_exception(self):
|
||||
stream = streams.StreamReader(loop=self.loop)
|
||||
|
@ -268,11 +268,11 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
data = self.loop.run_until_complete(stream.readexactly(0))
|
||||
self.assertEqual(b'', data)
|
||||
self.assertEqual(len(self.DATA), stream.byte_count)
|
||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
||||
|
||||
data = self.loop.run_until_complete(stream.readexactly(-1))
|
||||
self.assertEqual(b'', data)
|
||||
self.assertEqual(len(self.DATA), stream.byte_count)
|
||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
||||
|
||||
def test_readexactly(self):
|
||||
# Read exact number of bytes.
|
||||
|
@ -289,7 +289,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
data = self.loop.run_until_complete(read_task)
|
||||
self.assertEqual(self.DATA + self.DATA, data)
|
||||
self.assertEqual(len(self.DATA), stream.byte_count)
|
||||
self.assertEqual(len(self.DATA), stream._byte_count)
|
||||
|
||||
def test_readexactly_eof(self):
|
||||
# Read exact number of bytes (eof).
|
||||
|
@ -304,7 +304,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
|
||||
data = self.loop.run_until_complete(read_task)
|
||||
self.assertEqual(self.DATA, data)
|
||||
self.assertFalse(stream.byte_count)
|
||||
self.assertFalse(stream._byte_count)
|
||||
|
||||
def test_readexactly_exception(self):
|
||||
stream = streams.StreamReader(loop=self.loop)
|
||||
|
@ -357,7 +357,7 @@ class StreamReaderTests(unittest.TestCase):
|
|||
# The following line fails if set_exception() isn't careful.
|
||||
stream.set_exception(RuntimeError('message'))
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertIs(stream.waiter, None)
|
||||
self.assertIs(stream._waiter, None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue