Skip to content

Commit 775912a

Browse files
authored
gh-81322: support multiple separators in StreamReader.readuntil (#16429)
1 parent 24a2bd0 commit 775912a

File tree

4 files changed

+103
-21
lines changed

4 files changed

+103
-21
lines changed

Doc/library/asyncio-stream.rst

+11
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,19 @@ StreamReader
260260
buffer is reset. The :attr:`IncompleteReadError.partial` attribute
261261
may contain a portion of the separator.
262262

263+
The *separator* may also be an :term:`iterable` of separators. In this
264+
case the return value will be the shortest possible that has any
265+
separator as the suffix. For the purposes of :exc:`LimitOverrunError`,
266+
the shortest possible separator is considered to be the one that
267+
matched.
268+
263269
.. versionadded:: 3.5.2
264270

271+
.. versionchanged:: 3.13
272+
273+
The *separator* parameter may now be an :term:`iterable` of
274+
separators.
275+
265276
.. method:: at_eof()
266277

267278
Return ``True`` if the buffer is empty and :meth:`feed_eof`

Lib/asyncio/streams.py

+44-21
Original file line numberDiff line numberDiff line change
@@ -590,20 +590,34 @@ async def readuntil(self, separator=b'\n'):
590590
If the data cannot be read because of over limit, a
591591
LimitOverrunError exception will be raised, and the data
592592
will be left in the internal buffer, so it can be read again.
593+
594+
The ``separator`` may also be an iterable of separators. In this
595+
case the return value will be the shortest possible that has any
596+
separator as the suffix. For the purposes of LimitOverrunError,
597+
the shortest possible separator is considered to be the one that
598+
matched.
593599
"""
594-
seplen = len(separator)
595-
if seplen == 0:
600+
if isinstance(separator, bytes):
601+
separator = [separator]
602+
else:
603+
# Makes sure shortest matches wins, and supports arbitrary iterables
604+
separator = sorted(separator, key=len)
605+
if not separator:
606+
raise ValueError('Separator should contain at least one element')
607+
min_seplen = len(separator[0])
608+
max_seplen = len(separator[-1])
609+
if min_seplen == 0:
596610
raise ValueError('Separator should be at least one-byte string')
597611

598612
if self._exception is not None:
599613
raise self._exception
600614

601615
# Consume whole buffer except last bytes, which length is
602-
# one less than seplen. Let's check corner cases with
603-
# separator='SEPARATOR':
616+
# one less than max_seplen. Let's check corner cases with
617+
# separator[-1]='SEPARATOR':
604618
# * we have received almost complete separator (without last
605619
# byte). i.e buffer='some textSEPARATO'. In this case we
606-
# can safely consume len(separator) - 1 bytes.
620+
# can safely consume max_seplen - 1 bytes.
607621
# * last byte of buffer is first byte of separator, i.e.
608622
# buffer='abcdefghijklmnopqrS'. We may safely consume
609623
# everything except that last byte, but this require to
@@ -616,26 +630,35 @@ async def readuntil(self, separator=b'\n'):
616630
# messages :)
617631

618632
# `offset` is the number of bytes from the beginning of the buffer
619-
# where there is no occurrence of `separator`.
633+
# where there is no occurrence of any `separator`.
620634
offset = 0
621635

622-
# Loop until we find `separator` in the buffer, exceed the buffer size,
636+
# Loop until we find a `separator` in the buffer, exceed the buffer size,
623637
# or an EOF has happened.
624638
while True:
625639
buflen = len(self._buffer)
626640

627-
# Check if we now have enough data in the buffer for `separator` to
628-
# fit.
629-
if buflen - offset >= seplen:
630-
isep = self._buffer.find(separator, offset)
631-
632-
if isep != -1:
633-
# `separator` is in the buffer. `isep` will be used later
634-
# to retrieve the data.
641+
# Check if we now have enough data in the buffer for shortest
642+
# separator to fit.
643+
if buflen - offset >= min_seplen:
644+
match_start = None
645+
match_end = None
646+
for sep in separator:
647+
isep = self._buffer.find(sep, offset)
648+
649+
if isep != -1:
650+
# `separator` is in the buffer. `match_start` and
651+
# `match_end` will be used later to retrieve the
652+
# data.
653+
end = isep + len(sep)
654+
if match_end is None or end < match_end:
655+
match_end = end
656+
match_start = isep
657+
if match_end is not None:
635658
break
636659

637660
# see upper comment for explanation.
638-
offset = buflen + 1 - seplen
661+
offset = max(0, buflen + 1 - max_seplen)
639662
if offset > self._limit:
640663
raise exceptions.LimitOverrunError(
641664
'Separator is not found, and chunk exceed the limit',
@@ -644,7 +667,7 @@ async def readuntil(self, separator=b'\n'):
644667
# Complete message (with full separator) may be present in buffer
645668
# even when EOF flag is set. This may happen when the last chunk
646669
# adds data which makes separator be found. That's why we check for
647-
# EOF *ater* inspecting the buffer.
670+
# EOF *after* inspecting the buffer.
648671
if self._eof:
649672
chunk = bytes(self._buffer)
650673
self._buffer.clear()
@@ -653,12 +676,12 @@ async def readuntil(self, separator=b'\n'):
653676
# _wait_for_data() will resume reading if stream was paused.
654677
await self._wait_for_data('readuntil')
655678

656-
if isep > self._limit:
679+
if match_start > self._limit:
657680
raise exceptions.LimitOverrunError(
658-
'Separator is found, but chunk is longer than limit', isep)
681+
'Separator is found, but chunk is longer than limit', match_start)
659682

660-
chunk = self._buffer[:isep + seplen]
661-
del self._buffer[:isep + seplen]
683+
chunk = self._buffer[:match_end]
684+
del self._buffer[:match_end]
662685
self._maybe_resume_transport()
663686
return bytes(chunk)
664687

Lib/test/test_asyncio/test_streams.py

+46
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,10 @@ def test_readuntil_separator(self):
383383
stream = asyncio.StreamReader(loop=self.loop)
384384
with self.assertRaisesRegex(ValueError, 'Separator should be'):
385385
self.loop.run_until_complete(stream.readuntil(separator=b''))
386+
with self.assertRaisesRegex(ValueError, 'Separator should be'):
387+
self.loop.run_until_complete(stream.readuntil(separator=[b'']))
388+
with self.assertRaisesRegex(ValueError, 'Separator should contain'):
389+
self.loop.run_until_complete(stream.readuntil(separator=[]))
386390

387391
def test_readuntil_multi_chunks(self):
388392
stream = asyncio.StreamReader(loop=self.loop)
@@ -466,6 +470,48 @@ def test_readuntil_limit_found_sep(self):
466470

467471
self.assertEqual(b'some dataAAA', stream._buffer)
468472

473+
def test_readuntil_multi_separator(self):
474+
stream = asyncio.StreamReader(loop=self.loop)
475+
476+
# Simple case
477+
stream.feed_data(b'line 1\nline 2\r')
478+
data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
479+
self.assertEqual(b'line 1\n', data)
480+
data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
481+
self.assertEqual(b'line 2\r', data)
482+
self.assertEqual(b'', stream._buffer)
483+
484+
# First end position matches, even if that's a longer match
485+
stream.feed_data(b'ABCDEFG')
486+
data = self.loop.run_until_complete(stream.readuntil([b'DEF', b'BCDE']))
487+
self.assertEqual(b'ABCDE', data)
488+
self.assertEqual(b'FG', stream._buffer)
489+
490+
def test_readuntil_multi_separator_limit(self):
491+
stream = asyncio.StreamReader(loop=self.loop, limit=3)
492+
stream.feed_data(b'some dataA')
493+
494+
with self.assertRaisesRegex(asyncio.LimitOverrunError,
495+
'is found') as cm:
496+
self.loop.run_until_complete(stream.readuntil([b'A', b'ome dataA']))
497+
498+
self.assertEqual(b'some dataA', stream._buffer)
499+
500+
def test_readuntil_multi_separator_negative_offset(self):
501+
# If the buffer is big enough for the smallest separator (but does
502+
# not contain it) but too small for the largest, `offset` must not
503+
# become negative.
504+
stream = asyncio.StreamReader(loop=self.loop)
505+
stream.feed_data(b'data')
506+
507+
readuntil_task = self.loop.create_task(stream.readuntil([b'A', b'long sep']))
508+
self.loop.call_soon(stream.feed_data, b'Z')
509+
self.loop.call_soon(stream.feed_data, b'Aaaa')
510+
511+
data = self.loop.run_until_complete(readuntil_task)
512+
self.assertEqual(b'dataZA', data)
513+
self.assertEqual(b'aaa', stream._buffer)
514+
469515
def test_readexactly_zero_or_less(self):
470516
# Read exact number of bytes (zero or less).
471517
stream = asyncio.StreamReader(loop=self.loop)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Accept an iterable of separators in :meth:`asyncio.StreamReader.readuntil`, stopping
2+
when one of them is encountered.

0 commit comments

Comments
 (0)