gh-81322: support multiple separators in StreamReader.readuntil (#16429)

This commit is contained in:
Bruce Merry 2024-04-08 18:58:02 +02:00 committed by GitHub
parent 24a2bd0481
commit 775912a51d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 102 additions and 20 deletions

View File

@ -260,8 +260,19 @@ StreamReader
buffer is reset. The :attr:`IncompleteReadError.partial` attribute
may contain a portion of the separator.
The *separator* may also be an :term:`iterable` of separators. In this
case the return value will be the shortest possible that has any
separator as the suffix. For the purposes of :exc:`LimitOverrunError`,
the shortest possible separator is considered to be the one that
matched.
.. versionadded:: 3.5.2
.. versionchanged:: 3.13
The *separator* parameter may now be an :term:`iterable` of
separators.
.. method:: at_eof()
Return ``True`` if the buffer is empty and :meth:`feed_eof`

View File

@ -590,20 +590,34 @@ class StreamReader:
If the data cannot be read because of over limit, a
LimitOverrunError exception will be raised, and the data
will be left in the internal buffer, so it can be read again.
The ``separator`` may also be an iterable of separators. In this
case the return value will be the shortest possible that has any
separator as the suffix. For the purposes of LimitOverrunError,
the shortest possible separator is considered to be the one that
matched.
"""
seplen = len(separator)
if seplen == 0:
if isinstance(separator, bytes):
separator = [separator]
else:
# Makes sure shortest matches wins, and supports arbitrary iterables
separator = sorted(separator, key=len)
if not separator:
raise ValueError('Separator should contain at least one element')
min_seplen = len(separator[0])
max_seplen = len(separator[-1])
if min_seplen == 0:
raise ValueError('Separator should be at least one-byte string')
if self._exception is not None:
raise self._exception
# Consume whole buffer except last bytes, which length is
# one less than seplen. Let's check corner cases with
# separator='SEPARATOR':
# one less than max_seplen. Let's check corner cases with
# separator[-1]='SEPARATOR':
# * we have received almost complete separator (without last
# byte). i.e buffer='some textSEPARATO'. In this case we
# can safely consume len(separator) - 1 bytes.
# can safely consume max_seplen - 1 bytes.
# * last byte of buffer is first byte of separator, i.e.
# buffer='abcdefghijklmnopqrS'. We may safely consume
# everything except that last byte, but this require to
@ -616,26 +630,35 @@ class StreamReader:
# messages :)
# `offset` is the number of bytes from the beginning of the buffer
# where there is no occurrence of `separator`.
# where there is no occurrence of any `separator`.
offset = 0
# Loop until we find `separator` in the buffer, exceed the buffer size,
# Loop until we find a `separator` in the buffer, exceed the buffer size,
# or an EOF has happened.
while True:
buflen = len(self._buffer)
# Check if we now have enough data in the buffer for `separator` to
# fit.
if buflen - offset >= seplen:
isep = self._buffer.find(separator, offset)
# Check if we now have enough data in the buffer for shortest
# separator to fit.
if buflen - offset >= min_seplen:
match_start = None
match_end = None
for sep in separator:
isep = self._buffer.find(sep, offset)
if isep != -1:
# `separator` is in the buffer. `isep` will be used later
# to retrieve the data.
if isep != -1:
# `separator` is in the buffer. `match_start` and
# `match_end` will be used later to retrieve the
# data.
end = isep + len(sep)
if match_end is None or end < match_end:
match_end = end
match_start = isep
if match_end is not None:
break
# see upper comment for explanation.
offset = buflen + 1 - seplen
offset = max(0, buflen + 1 - max_seplen)
if offset > self._limit:
raise exceptions.LimitOverrunError(
'Separator is not found, and chunk exceed the limit',
@ -644,7 +667,7 @@ class StreamReader:
# Complete message (with full separator) may be present in buffer
# even when EOF flag is set. This may happen when the last chunk
# adds data which makes separator be found. That's why we check for
# EOF *ater* inspecting the buffer.
# EOF *after* inspecting the buffer.
if self._eof:
chunk = bytes(self._buffer)
self._buffer.clear()
@ -653,12 +676,12 @@ class StreamReader:
# _wait_for_data() will resume reading if stream was paused.
await self._wait_for_data('readuntil')
if isep > self._limit:
if match_start > self._limit:
raise exceptions.LimitOverrunError(
'Separator is found, but chunk is longer than limit', isep)
'Separator is found, but chunk is longer than limit', match_start)
chunk = self._buffer[:isep + seplen]
del self._buffer[:isep + seplen]
chunk = self._buffer[:match_end]
del self._buffer[:match_end]
self._maybe_resume_transport()
return bytes(chunk)

View File

@ -383,6 +383,10 @@ class StreamTests(test_utils.TestCase):
stream = asyncio.StreamReader(loop=self.loop)
with self.assertRaisesRegex(ValueError, 'Separator should be'):
self.loop.run_until_complete(stream.readuntil(separator=b''))
with self.assertRaisesRegex(ValueError, 'Separator should be'):
self.loop.run_until_complete(stream.readuntil(separator=[b'']))
with self.assertRaisesRegex(ValueError, 'Separator should contain'):
self.loop.run_until_complete(stream.readuntil(separator=[]))
def test_readuntil_multi_chunks(self):
stream = asyncio.StreamReader(loop=self.loop)
@ -466,6 +470,48 @@ class StreamTests(test_utils.TestCase):
self.assertEqual(b'some dataAAA', stream._buffer)
def test_readuntil_multi_separator(self):
stream = asyncio.StreamReader(loop=self.loop)
# Simple case
stream.feed_data(b'line 1\nline 2\r')
data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
self.assertEqual(b'line 1\n', data)
data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
self.assertEqual(b'line 2\r', data)
self.assertEqual(b'', stream._buffer)
# First end position matches, even if that's a longer match
stream.feed_data(b'ABCDEFG')
data = self.loop.run_until_complete(stream.readuntil([b'DEF', b'BCDE']))
self.assertEqual(b'ABCDE', data)
self.assertEqual(b'FG', stream._buffer)
def test_readuntil_multi_separator_limit(self):
stream = asyncio.StreamReader(loop=self.loop, limit=3)
stream.feed_data(b'some dataA')
with self.assertRaisesRegex(asyncio.LimitOverrunError,
'is found') as cm:
self.loop.run_until_complete(stream.readuntil([b'A', b'ome dataA']))
self.assertEqual(b'some dataA', stream._buffer)
def test_readuntil_multi_separator_negative_offset(self):
# If the buffer is big enough for the smallest separator (but does
# not contain it) but too small for the largest, `offset` must not
# become negative.
stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(b'data')
readuntil_task = self.loop.create_task(stream.readuntil([b'A', b'long sep']))
self.loop.call_soon(stream.feed_data, b'Z')
self.loop.call_soon(stream.feed_data, b'Aaaa')
data = self.loop.run_until_complete(readuntil_task)
self.assertEqual(b'dataZA', data)
self.assertEqual(b'aaa', stream._buffer)
def test_readexactly_zero_or_less(self):
# Read exact number of bytes (zero or less).
stream = asyncio.StreamReader(loop=self.loop)

View File

@ -0,0 +1,2 @@
Accept an iterable of separators in :meth:`asyncio.StreamReader.readuntil`, stopping
when one of them is encountered.