Write flow control for asyncio (includes asyncio.streams overhaul).

This commit is contained in:
Guido van Rossum 2013-10-18 15:17:11 -07:00
parent 051a331488
commit 355491dc47
5 changed files with 288 additions and 93 deletions

View File

@ -29,6 +29,34 @@ class BaseProtocol:
aborted or closed).
"""
def pause_writing(self):
"""Called when the transport's buffer goes over the high-water mark.
Pause and resume calls are paired -- pause_writing() is called
once when the buffer goes strictly over the high-water mark
(even if subsequent writes increases the buffer size even
more), and eventually resume_writing() is called once when the
buffer size reaches the low-water mark.
Note that if the buffer size equals the high-water mark,
pause_writing() is not called -- it must go strictly over.
Conversely, resume_writing() is called when the buffer size is
equal or lower than the low-water mark. These end conditions
are important to ensure that things go as expected when either
mark is zero.
NOTE: This is the only Protocol callback that is not called
through EventLoop.call_soon() -- if it were, it would have no
effect when it's most needed (when the app keeps writing
without yielding until pause_writing() is called).
"""
def resume_writing(self):
"""Called when the transport's buffer drains below the low-water mark.
See pause_writing() for details.
"""
class Protocol(BaseProtocol):
"""ABC representing a protocol.

View File

@ -346,8 +346,10 @@ class _SelectorTransport(transports.Transport):
self._buffer = collections.deque()
self._conn_lost = 0 # Set when call to connection_lost scheduled.
self._closing = False # Set when close() called.
if server is not None:
server.attach(self)
self._protocol_paused = False
self.set_write_buffer_limits()
if self._server is not None:
self._server.attach(self)
def abort(self):
self._force_close(None)
@ -392,6 +394,40 @@ class _SelectorTransport(transports.Transport):
server.detach(self)
self._server = None
def _maybe_pause_protocol(self):
size = self.get_write_buffer_size()
if size <= self._high_water:
return
if not self._protocol_paused:
self._protocol_paused = True
try:
self._protocol.pause_writing()
except Exception:
tulip_log.exception('pause_writing() failed')
def _maybe_resume_protocol(self):
if self._protocol_paused and self.get_write_buffer_size() <= self._low_water:
self._protocol_paused = False
try:
self._protocol.resume_writing()
except Exception:
tulip_log.exception('resume_writing() failed')
def set_write_buffer_limits(self, high=None, low=None):
if high is None:
if low is None:
high = 64*1024
else:
high = 4*low
if low is None:
low = high // 4
assert 0 <= low <= high, repr((low, high))
self._high_water = high
self._low_water = low
def get_write_buffer_size(self):
return sum(len(data) for data in self._buffer)
class _SelectorSocketTransport(_SelectorTransport):
@ -447,7 +483,7 @@ class _SelectorSocketTransport(_SelectorTransport):
return
if not self._buffer:
# Attempt to send it right away first.
# Optimization: try to send now.
try:
n = self._sock.send(data)
except (BlockingIOError, InterruptedError):
@ -459,34 +495,36 @@ class _SelectorSocketTransport(_SelectorTransport):
data = data[n:]
if not data:
return
# Start async I/O.
# Not all was written; register write handler.
self._loop.add_writer(self._sock_fd, self._write_ready)
# Add it to the buffer.
self._buffer.append(data)
self._maybe_pause_protocol()
def _write_ready(self):
data = b''.join(self._buffer)
assert data, 'Data should not be empty'
self._buffer.clear()
self._buffer.clear() # Optimistically; may have to put it back later.
try:
n = self._sock.send(data)
except (BlockingIOError, InterruptedError):
self._buffer.append(data)
self._buffer.append(data) # Still need to write this.
except Exception as exc:
self._loop.remove_writer(self._sock_fd)
self._fatal_error(exc)
else:
data = data[n:]
if not data:
if data:
self._buffer.append(data) # Still need to write this.
self._maybe_resume_protocol() # May append to buffer.
if not self._buffer:
self._loop.remove_writer(self._sock_fd)
if self._closing:
self._call_connection_lost(None)
elif self._eof:
self._sock.shutdown(socket.SHUT_WR)
return
self._buffer.append(data) # Try again later.
def write_eof(self):
if self._eof:
@ -546,16 +584,23 @@ class _SelectorSslTransport(_SelectorTransport):
self._loop.add_writer(self._sock_fd, self._on_handshake)
return
except Exception as exc:
self._loop.remove_reader(self._sock_fd)
self._loop.remove_writer(self._sock_fd)
self._sock.close()
if self._waiter is not None:
self._waiter.set_exception(exc)
return
except BaseException as exc:
self._loop.remove_reader(self._sock_fd)
self._loop.remove_writer(self._sock_fd)
self._sock.close()
if self._waiter is not None:
self._waiter.set_exception(exc)
raise
self._loop.remove_reader(self._sock_fd)
self._loop.remove_writer(self._sock_fd)
# Verify hostname if requested.
peercert = self._sock.getpeercert()
if (self._server_hostname is not None and
@ -574,8 +619,6 @@ class _SelectorSslTransport(_SelectorTransport):
compression=self._sock.compression(),
)
self._loop.remove_reader(self._sock_fd)
self._loop.remove_writer(self._sock_fd)
self._loop.add_reader(self._sock_fd, self._on_ready)
self._loop.add_writer(self._sock_fd, self._on_ready)
self._loop.call_soon(self._protocol.connection_made, self)
@ -642,6 +685,8 @@ class _SelectorSslTransport(_SelectorTransport):
if n < len(data):
self._buffer.append(data[n:])
self._maybe_resume_protocol() # May append to buffer.
if self._closing and not self._buffer:
self._loop.remove_writer(self._sock_fd)
self._call_connection_lost(None)
@ -657,8 +702,9 @@ class _SelectorSslTransport(_SelectorTransport):
self._conn_lost += 1
return
self._buffer.append(data)
# We could optimize, but the callback can do this for now.
self._buffer.append(data)
self._maybe_pause_protocol()
def can_write_eof(self):
return False
@ -675,11 +721,13 @@ class _SelectorDatagramTransport(_SelectorTransport):
def __init__(self, loop, sock, protocol, address=None, extra=None):
super().__init__(loop, sock, protocol, extra)
self._address = address
self._loop.add_reader(self._sock_fd, self._read_ready)
self._loop.call_soon(self._protocol.connection_made, self)
def get_write_buffer_size(self):
return sum(len(data) for data, _ in self._buffer)
def _read_ready(self):
try:
data, addr = self._sock.recvfrom(self.max_size)
@ -723,6 +771,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
return
self._buffer.append((data, addr))
self._maybe_pause_protocol()
def _sendto_ready(self):
while self._buffer:
@ -743,6 +792,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
self._fatal_error(exc)
return
self._maybe_resume_protocol() # May append to buffer.
if not self._buffer:
self._loop.remove_writer(self._sock_fd)
if self._closing:

View File

@ -39,7 +39,8 @@ def open_connection(host=None, port=None, *,
protocol = StreamReaderProtocol(reader)
transport, _ = yield from loop.create_connection(
lambda: protocol, host, port, **kwds)
return reader, transport # (reader, writer)
writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
class StreamReaderProtocol(protocols.Protocol):
@ -52,22 +53,113 @@ class StreamReaderProtocol(protocols.Protocol):
"""
def __init__(self, stream_reader):
self.stream_reader = stream_reader
self._stream_reader = stream_reader
self._drain_waiter = None
self._paused = False
def connection_made(self, transport):
self.stream_reader.set_transport(transport)
self._stream_reader.set_transport(transport)
def connection_lost(self, exc):
if exc is None:
self.stream_reader.feed_eof()
self._stream_reader.feed_eof()
else:
self.stream_reader.set_exception(exc)
self._stream_reader.set_exception(exc)
# Also wake up the writing side.
if self._paused:
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
def data_received(self, data):
self.stream_reader.feed_data(data)
self._stream_reader.feed_data(data)
def eof_received(self):
self.stream_reader.feed_eof()
self._stream_reader.feed_eof()
def pause_writing(self):
assert not self._paused
self._paused = True
def resume_writing(self):
assert self._paused
self._paused = False
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)
class StreamWriter:
"""Wraps a Transport.
This exposes write(), writelines(), [can_]write_eof(),
get_extra_info() and close(). It adds drain() which returns an
optional Future on which you can wait for flow control. It also
adds a transport attribute which references the Transport
directly.
"""
def __init__(self, transport, protocol, reader, loop):
self._transport = transport
self._protocol = protocol
self._reader = reader
self._loop = loop
@property
def transport(self):
return self._transport
def write(self, data):
self._transport.write(data)
def writelines(self, data):
self._transport.writelines(data)
def write_eof(self):
return self._transport.write_eof()
def can_write_eof(self):
return self._transport.can_write_eof()
def close(self):
return self._transport.close()
def get_extra_info(self, name, default=None):
return self._transport.get_extra_info(name, default)
def drain(self):
"""This method has an unusual return value.
The intended use is to write
w.write(data)
yield from w.drain()
When there's nothing to wait for, drain() returns (), and the
yield-from continues immediately. When the transport buffer
is full (the protocol is paused), drain() creates and returns
a Future and the yield-from will block until that Future is
completed, which will happen when the buffer is (partially)
drained and the protocol is resumed.
"""
if self._reader._exception is not None:
raise self._writer._exception
if self._transport._conn_lost: # Uses private variable.
raise ConnectionResetError('Connection lost')
if not self._protocol._paused:
return ()
waiter = self._protocol._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = futures.Future(loop=self._loop)
self._protocol._drain_waiter = waiter
return waiter
class StreamReader:
@ -75,14 +167,14 @@ class StreamReader:
def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
# The line length limit is a security feature;
# it also doubles as half the buffer limit.
self.limit = limit
self._limit = limit
if loop is None:
loop = events.get_event_loop()
self.loop = loop
self.buffer = collections.deque() # Deque of bytes objects.
self.byte_count = 0 # Bytes in buffer.
self.eof = False # Whether we're done.
self.waiter = None # A future.
self._loop = loop
self._buffer = collections.deque() # Deque of bytes objects.
self._byte_count = 0 # Bytes in buffer.
self._eof = False # Whether we're done.
self._waiter = None # A future.
self._exception = None
self._transport = None
self._paused = False
@ -93,9 +185,9 @@ class StreamReader:
def set_exception(self, exc):
self._exception = exc
waiter = self.waiter
waiter = self._waiter
if waiter is not None:
self.waiter = None
self._waiter = None
if not waiter.cancelled():
waiter.set_exception(exc)
@ -104,15 +196,15 @@ class StreamReader:
self._transport = transport
def _maybe_resume_transport(self):
if self._paused and self.byte_count <= self.limit:
if self._paused and self._byte_count <= self._limit:
self._paused = False
self._transport.resume_reading()
def feed_eof(self):
self.eof = True
waiter = self.waiter
self._eof = True
waiter = self._waiter
if waiter is not None:
self.waiter = None
self._waiter = None
if not waiter.cancelled():
waiter.set_result(True)
@ -120,18 +212,18 @@ class StreamReader:
if not data:
return
self.buffer.append(data)
self.byte_count += len(data)
self._buffer.append(data)
self._byte_count += len(data)
waiter = self.waiter
waiter = self._waiter
if waiter is not None:
self.waiter = None
self._waiter = None
if not waiter.cancelled():
waiter.set_result(False)
if (self._transport is not None and
not self._paused and
self.byte_count > 2*self.limit):
self._byte_count > 2*self._limit):
try:
self._transport.pause_reading()
except NotImplementedError:
@ -152,8 +244,8 @@ class StreamReader:
not_enough = True
while not_enough:
while self.buffer and not_enough:
data = self.buffer.popleft()
while self._buffer and not_enough:
data = self._buffer.popleft()
ichar = data.find(b'\n')
if ichar < 0:
parts.append(data)
@ -162,29 +254,29 @@ class StreamReader:
ichar += 1
head, tail = data[:ichar], data[ichar:]
if tail:
self.buffer.appendleft(tail)
self._buffer.appendleft(tail)
not_enough = False
parts.append(head)
parts_size += len(head)
if parts_size > self.limit:
self.byte_count -= parts_size
if parts_size > self._limit:
self._byte_count -= parts_size
self._maybe_resume_transport()
raise ValueError('Line is too long')
if self.eof:
if self._eof:
break
if not_enough:
assert self.waiter is None
self.waiter = futures.Future(loop=self.loop)
assert self._waiter is None
self._waiter = futures.Future(loop=self._loop)
try:
yield from self.waiter
yield from self._waiter
finally:
self.waiter = None
self._waiter = None
line = b''.join(parts)
self.byte_count -= parts_size
self._byte_count -= parts_size
self._maybe_resume_transport()
return line
@ -198,42 +290,42 @@ class StreamReader:
return b''
if n < 0:
while not self.eof:
assert not self.waiter
self.waiter = futures.Future(loop=self.loop)
while not self._eof:
assert not self._waiter
self._waiter = futures.Future(loop=self._loop)
try:
yield from self.waiter
yield from self._waiter
finally:
self.waiter = None
self._waiter = None
else:
if not self.byte_count and not self.eof:
assert not self.waiter
self.waiter = futures.Future(loop=self.loop)
if not self._byte_count and not self._eof:
assert not self._waiter
self._waiter = futures.Future(loop=self._loop)
try:
yield from self.waiter
yield from self._waiter
finally:
self.waiter = None
self._waiter = None
if n < 0 or self.byte_count <= n:
data = b''.join(self.buffer)
self.buffer.clear()
self.byte_count = 0
if n < 0 or self._byte_count <= n:
data = b''.join(self._buffer)
self._buffer.clear()
self._byte_count = 0
self._maybe_resume_transport()
return data
parts = []
parts_bytes = 0
while self.buffer and parts_bytes < n:
data = self.buffer.popleft()
while self._buffer and parts_bytes < n:
data = self._buffer.popleft()
data_bytes = len(data)
if n < parts_bytes + data_bytes:
data_bytes = n - parts_bytes
data, rest = data[:data_bytes], data[data_bytes:]
self.buffer.appendleft(rest)
self._buffer.appendleft(rest)
parts.append(data)
parts_bytes += data_bytes
self.byte_count -= data_bytes
self._byte_count -= data_bytes
self._maybe_resume_transport()
return b''.join(parts)
@ -246,12 +338,12 @@ class StreamReader:
if n <= 0:
return b''
while self.byte_count < n and not self.eof:
assert not self.waiter
self.waiter = futures.Future(loop=self.loop)
while self._byte_count < n and not self._eof:
assert not self._waiter
self._waiter = futures.Future(loop=self._loop)
try:
yield from self.waiter
yield from self._waiter
finally:
self.waiter = None
self._waiter = None
return (yield from self.read(n))

View File

@ -49,6 +49,31 @@ class ReadTransport(BaseTransport):
class WriteTransport(BaseTransport):
"""ABC for write-only transports."""
def set_write_buffer_limits(self, high=None, low=None):
"""Set the high- and low-water limits for write flow control.
These two values control when to call the protocol's
pause_writing() and resume_writing() methods. If specified,
the low-water limit must be less than or equal to the
high-water limit. Neither value can be negative.
The defaults are implementation-specific. If only the
high-water limit is given, the low-water limit defaults to a
implementation-specific value less than or equal to the
high-water limit. Setting high to zero forces low to zero as
well, and causes pause_writing() to be called whenever the
buffer becomes non-empty. Setting low to zero causes
resume_writing() to be called only once the buffer is empty.
Use of zero for either limit is generally sub-optimal as it
reduces opportunities for doing I/O and computation
concurrently.
"""
raise NotImplementedError
def get_write_buffer_size(self):
"""Return the current size of the write buffer."""
raise NotImplementedError
def write(self, data):
"""Write some data bytes to the transport.

View File

@ -32,7 +32,7 @@ class StreamReaderTests(unittest.TestCase):
@unittest.mock.patch('asyncio.streams.events')
def test_ctor_global_loop(self, m_events):
stream = streams.StreamReader()
self.assertIs(stream.loop, m_events.get_event_loop.return_value)
self.assertIs(stream._loop, m_events.get_event_loop.return_value)
def test_open_connection(self):
with test_utils.run_test_server() as httpd:
@ -81,13 +81,13 @@ class StreamReaderTests(unittest.TestCase):
stream = streams.StreamReader(loop=self.loop)
stream.feed_data(b'')
self.assertEqual(0, stream.byte_count)
self.assertEqual(0, stream._byte_count)
def test_feed_data_byte_count(self):
stream = streams.StreamReader(loop=self.loop)
stream.feed_data(self.DATA)
self.assertEqual(len(self.DATA), stream.byte_count)
self.assertEqual(len(self.DATA), stream._byte_count)
def test_read_zero(self):
# Read zero bytes.
@ -96,7 +96,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(stream.read(0))
self.assertEqual(b'', data)
self.assertEqual(len(self.DATA), stream.byte_count)
self.assertEqual(len(self.DATA), stream._byte_count)
def test_read(self):
# Read bytes.
@ -109,7 +109,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(read_task)
self.assertEqual(self.DATA, data)
self.assertFalse(stream.byte_count)
self.assertFalse(stream._byte_count)
def test_read_line_breaks(self):
# Read bytes without line breaks.
@ -120,7 +120,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(stream.read(5))
self.assertEqual(b'line1', data)
self.assertEqual(5, stream.byte_count)
self.assertEqual(5, stream._byte_count)
def test_read_eof(self):
# Read bytes, stop at eof.
@ -133,7 +133,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(read_task)
self.assertEqual(b'', data)
self.assertFalse(stream.byte_count)
self.assertFalse(stream._byte_count)
def test_read_until_eof(self):
# Read all bytes until eof.
@ -149,7 +149,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(read_task)
self.assertEqual(b'chunk1\nchunk2', data)
self.assertFalse(stream.byte_count)
self.assertFalse(stream._byte_count)
def test_read_exception(self):
stream = streams.StreamReader(loop=self.loop)
@ -176,7 +176,7 @@ class StreamReaderTests(unittest.TestCase):
line = self.loop.run_until_complete(read_task)
self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
self.assertEqual(len(b'\n chunk4')-1, stream.byte_count)
self.assertEqual(len(b'\n chunk4')-1, stream._byte_count)
def test_readline_limit_with_existing_data(self):
stream = streams.StreamReader(3, loop=self.loop)
@ -185,7 +185,7 @@ class StreamReaderTests(unittest.TestCase):
self.assertRaises(
ValueError, self.loop.run_until_complete, stream.readline())
self.assertEqual([b'line2\n'], list(stream.buffer))
self.assertEqual([b'line2\n'], list(stream._buffer))
stream = streams.StreamReader(3, loop=self.loop)
stream.feed_data(b'li')
@ -194,8 +194,8 @@ class StreamReaderTests(unittest.TestCase):
self.assertRaises(
ValueError, self.loop.run_until_complete, stream.readline())
self.assertEqual([b'li'], list(stream.buffer))
self.assertEqual(2, stream.byte_count)
self.assertEqual([b'li'], list(stream._buffer))
self.assertEqual(2, stream._byte_count)
def test_readline_limit(self):
stream = streams.StreamReader(7, loop=self.loop)
@ -209,8 +209,8 @@ class StreamReaderTests(unittest.TestCase):
self.assertRaises(
ValueError, self.loop.run_until_complete, stream.readline())
self.assertEqual([b'chunk3\n'], list(stream.buffer))
self.assertEqual(7, stream.byte_count)
self.assertEqual([b'chunk3\n'], list(stream._buffer))
self.assertEqual(7, stream._byte_count)
def test_readline_line_byte_count(self):
stream = streams.StreamReader(loop=self.loop)
@ -220,7 +220,7 @@ class StreamReaderTests(unittest.TestCase):
line = self.loop.run_until_complete(stream.readline())
self.assertEqual(b'line1\n', line)
self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count)
self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count)
def test_readline_eof(self):
stream = streams.StreamReader(loop=self.loop)
@ -248,7 +248,7 @@ class StreamReaderTests(unittest.TestCase):
self.assertEqual(b'line2\nl', data)
self.assertEqual(
len(self.DATA) - len(b'line1\n') - len(b'line2\nl'),
stream.byte_count)
stream._byte_count)
def test_readline_exception(self):
stream = streams.StreamReader(loop=self.loop)
@ -268,11 +268,11 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(stream.readexactly(0))
self.assertEqual(b'', data)
self.assertEqual(len(self.DATA), stream.byte_count)
self.assertEqual(len(self.DATA), stream._byte_count)
data = self.loop.run_until_complete(stream.readexactly(-1))
self.assertEqual(b'', data)
self.assertEqual(len(self.DATA), stream.byte_count)
self.assertEqual(len(self.DATA), stream._byte_count)
def test_readexactly(self):
# Read exact number of bytes.
@ -289,7 +289,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(read_task)
self.assertEqual(self.DATA + self.DATA, data)
self.assertEqual(len(self.DATA), stream.byte_count)
self.assertEqual(len(self.DATA), stream._byte_count)
def test_readexactly_eof(self):
# Read exact number of bytes (eof).
@ -304,7 +304,7 @@ class StreamReaderTests(unittest.TestCase):
data = self.loop.run_until_complete(read_task)
self.assertEqual(self.DATA, data)
self.assertFalse(stream.byte_count)
self.assertFalse(stream._byte_count)
def test_readexactly_exception(self):
stream = streams.StreamReader(loop=self.loop)
@ -357,7 +357,7 @@ class StreamReaderTests(unittest.TestCase):
# The following line fails if set_exception() isn't careful.
stream.set_exception(RuntimeError('message'))
test_utils.run_briefly(self.loop)
self.assertIs(stream.waiter, None)
self.assertIs(stream._waiter, None)
if __name__ == '__main__':