diff --git a/Doc/library/asyncio-stream.rst b/Doc/library/asyncio-stream.rst
index 68b1dff2021..6231b49b1e2 100644
--- a/Doc/library/asyncio-stream.rst
+++ b/Doc/library/asyncio-stream.rst
@@ -260,8 +260,19 @@ StreamReader
       buffer is reset. The :attr:`IncompleteReadError.partial` attribute
       may contain a portion of the separator.
 
+      The *separator* may also be an :term:`iterable` of separators. In this
+      case the return value will be the shortest possible that has any
+      separator as the suffix. For the purposes of :exc:`LimitOverrunError`,
+      the shortest possible separator is considered to be the one that
+      matched.
+
       .. versionadded:: 3.5.2
 
+      .. versionchanged:: 3.13
+
+         The *separator* parameter may now be an :term:`iterable` of
+         separators.
+
    .. method:: at_eof()
 
       Return ``True`` if the buffer is empty and :meth:`feed_eof`
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 3fe52dbac25..4517ca22d74 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -590,20 +590,35 @@ class StreamReader:
         If the data cannot be read because of over limit, a
         LimitOverrunError exception will be raised, and the data
         will be left in the internal buffer, so it can be read again.
+
+        The ``separator`` may also be an iterable of separators. In this
+        case the return value will be the shortest possible that has any
+        separator as the suffix. For the purposes of LimitOverrunError,
+        the shortest possible separator is considered to be the one that
+        matched.
         """
-        seplen = len(separator)
-        if seplen == 0:
+        if isinstance(separator, (bytes, bytearray, memoryview)):
+            # A lone bytes-like separator, as accepted before this change.
+            separator = [separator]
+        else:
+            # Makes sure shortest matches wins, and supports arbitrary iterables
+            separator = sorted(separator, key=len)
+        if not separator:
+            raise ValueError('Separator should contain at least one element')
+        min_seplen = len(separator[0])
+        max_seplen = len(separator[-1])
+        if min_seplen == 0:
             raise ValueError('Separator should be at least one-byte string')
 
         if self._exception is not None:
             raise self._exception
 
         # Consume whole buffer except last bytes, which length is
-        # one less than seplen. Let's check corner cases with
-        # separator='SEPARATOR':
+        # one less than max_seplen. Let's check corner cases with
+        # separator[-1]='SEPARATOR':
         # * we have received almost complete separator (without last
         #   byte). i.e buffer='some textSEPARATO'. In this case we
-        #   can safely consume len(separator) - 1 bytes.
+        #   can safely consume max_seplen - 1 bytes.
         # * last byte of buffer is first byte of separator, i.e.
         #   buffer='abcdefghijklmnopqrS'. We may safely consume
         #   everything except that last byte, but this require to
@@ -616,26 +631,35 @@ class StreamReader:
         # messages :)
 
         # `offset` is the number of bytes from the beginning of the buffer
-        # where there is no occurrence of `separator`.
+        # where there is no occurrence of any `separator`.
         offset = 0
 
-        # Loop until we find `separator` in the buffer, exceed the buffer size,
+        # Loop until we find a `separator` in the buffer, exceed the buffer size,
         # or an EOF has happened.
         while True:
             buflen = len(self._buffer)
 
-            # Check if we now have enough data in the buffer for `separator` to
-            # fit.
-            if buflen - offset >= seplen:
-                isep = self._buffer.find(separator, offset)
+            # Check if we now have enough data in the buffer for shortest
+            # separator to fit.
+            if buflen - offset >= min_seplen:
+                match_start = None
+                match_end = None
+                for sep in separator:
+                    isep = self._buffer.find(sep, offset)
 
-                if isep != -1:
-                    # `separator` is in the buffer. `isep` will be used later
-                    # to retrieve the data.
+                    if isep != -1:
+                        # `separator` is in the buffer. `match_start` and
+                        # `match_end` will be used later to retrieve the
+                        # data.
+                        end = isep + len(sep)
+                        if match_end is None or end < match_end:
+                            match_end = end
+                            match_start = isep
+                if match_end is not None:
                     break
 
             # see upper comment for explanation.
-            offset = buflen + 1 - seplen
+            offset = max(0, buflen + 1 - max_seplen)
             if offset > self._limit:
                 raise exceptions.LimitOverrunError(
                     'Separator is not found, and chunk exceed the limit',
@@ -644,7 +668,7 @@ class StreamReader:
             # Complete message (with full separator) may be present in buffer
             # even when EOF flag is set. This may happen when the last chunk
             # adds data which makes separator be found. That's why we check for
-            # EOF *ater* inspecting the buffer.
+            # EOF *after* inspecting the buffer.
             if self._eof:
                 chunk = bytes(self._buffer)
                 self._buffer.clear()
@@ -653,12 +677,12 @@ class StreamReader:
             # _wait_for_data() will resume reading if stream was paused.
             await self._wait_for_data('readuntil')
 
-        if isep > self._limit:
+        if match_start > self._limit:
             raise exceptions.LimitOverrunError(
-                'Separator is found, but chunk is longer than limit', isep)
+                'Separator is found, but chunk is longer than limit', match_start)
 
-        chunk = self._buffer[:isep + seplen]
-        del self._buffer[:isep + seplen]
+        chunk = self._buffer[:match_end]
+        del self._buffer[:match_end]
         self._maybe_resume_transport()
         return bytes(chunk)
 
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
index 2cf48538d5d..792e88761ac 100644
--- a/Lib/test/test_asyncio/test_streams.py
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -383,6 +383,10 @@ class StreamTests(test_utils.TestCase):
         stream = asyncio.StreamReader(loop=self.loop)
         with self.assertRaisesRegex(ValueError, 'Separator should be'):
             self.loop.run_until_complete(stream.readuntil(separator=b''))
+        with self.assertRaisesRegex(ValueError, 'Separator should be'):
+            self.loop.run_until_complete(stream.readuntil(separator=[b'']))
+        with self.assertRaisesRegex(ValueError, 'Separator should contain'):
+            self.loop.run_until_complete(stream.readuntil(separator=[]))
 
     def test_readuntil_multi_chunks(self):
         stream = asyncio.StreamReader(loop=self.loop)
@@ -466,6 +470,48 @@ class StreamTests(test_utils.TestCase):
 
         self.assertEqual(b'some dataAAA', stream._buffer)
 
+    def test_readuntil_multi_separator(self):
+        stream = asyncio.StreamReader(loop=self.loop)
+
+        # Simple case
+        stream.feed_data(b'line 1\nline 2\r')
+        data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
+        self.assertEqual(b'line 1\n', data)
+        data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
+        self.assertEqual(b'line 2\r', data)
+        self.assertEqual(b'', stream._buffer)
+
+        # First end position matches, even if that's a longer match
+        stream.feed_data(b'ABCDEFG')
+        data = self.loop.run_until_complete(stream.readuntil([b'DEF', b'BCDE']))
+        self.assertEqual(b'ABCDE', data)
+        self.assertEqual(b'FG', stream._buffer)
+
+    def test_readuntil_multi_separator_limit(self):
+        stream = asyncio.StreamReader(loop=self.loop, limit=3)
+        stream.feed_data(b'some dataA')
+
+        with self.assertRaisesRegex(asyncio.LimitOverrunError,
+                                    'is found') as cm:
+            self.loop.run_until_complete(stream.readuntil([b'A', b'ome dataA']))
+
+        self.assertEqual(b'some dataA', stream._buffer)
+
+    def test_readuntil_multi_separator_negative_offset(self):
+        # If the buffer is big enough for the smallest separator (but does
+        # not contain it) but too small for the largest, `offset` must not
+        # become negative.
+        stream = asyncio.StreamReader(loop=self.loop)
+        stream.feed_data(b'data')
+
+        readuntil_task = self.loop.create_task(stream.readuntil([b'A', b'long sep']))
+        self.loop.call_soon(stream.feed_data, b'Z')
+        self.loop.call_soon(stream.feed_data, b'Aaaa')
+
+        data = self.loop.run_until_complete(readuntil_task)
+        self.assertEqual(b'dataZA', data)
+        self.assertEqual(b'aaa', stream._buffer)
+
     def test_readexactly_zero_or_less(self):
         # Read exact number of bytes (zero or less).
         stream = asyncio.StreamReader(loop=self.loop)
diff --git a/Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst b/Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst
new file mode 100644
index 00000000000..d916f319947
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst
@@ -0,0 +1,2 @@
+Accept an iterable of separators in :meth:`asyncio.StreamReader.readuntil`, stopping
+when one of them is encountered.