From b7cb6e869fdb59b0538658e7bc3b88430a1f4292 Mon Sep 17 00:00:00 2001 From: Peter Bierma Date: Thu, 26 Sep 2024 01:11:17 -0400 Subject: [PATCH 1/3] gh-124309: Modernize the `staggered_race` implementation to support eager task factories (GH-124390) (cherry picked from commit de929f353c413459834a2a37b2d9b0240673d874) Co-authored-by: Peter Bierma Co-authored-by: Thomas Grainger Co-authored-by: Jelle Zijlstra Co-authored-by: Carol Willing Co-authored-by: Kumar Aditya --- Lib/asyncio/base_events.py | 2 +- Lib/asyncio/staggered.py | 79 +++++-------------- .../test_asyncio/test_eager_task_factory.py | 47 +++++++++++ Lib/test/test_asyncio/test_staggered.py | 37 ++++++++- ...-09-23-18-18-23.gh-issue-124309.iFcarA.rst | 1 + 5 files changed, 100 insertions(+), 66 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index e4a39f4d345c79..47da3e8186e783 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -1144,7 +1144,7 @@ async def create_connection( (functools.partial(self._connect_sock, exceptions, addrinfo, laddr_infos) for addrinfo in infos), - happy_eyeballs_delay, loop=self) + happy_eyeballs_delay) if sock is None: exceptions = [exc for sub in exceptions for exc in sub] diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index c3a7441a7b091d..4458d01dece0e6 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -4,13 +4,14 @@ import contextlib -from . import events -from . import exceptions as exceptions_mod from . import locks from . import tasks +from . import taskgroups +class _Done(Exception): + pass -async def staggered_race(coro_fns, delay, *, loop=None): +async def staggered_race(coro_fns, delay): """Run coroutines with staggered start times and take the first to finish. This method takes an iterable of coroutine functions. The first one is @@ -42,8 +43,6 @@ async def staggered_race(coro_fns, delay, *, loop=None): delay: amount of time, in seconds, between starting coroutines. If ``None``, the coroutines will run sequentially. - loop: the event loop to use. - Returns: tuple *(winner_result, winner_index, exceptions)* where @@ -62,36 +61,11 @@ async def staggered_race(coro_fns, delay, *, loop=None): """ # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. - loop = loop or events.get_running_loop() - enum_coro_fns = enumerate(coro_fns) winner_result = None winner_index = None exceptions = [] - running_tasks = [] - - async def run_one_coro(previous_failed) -> None: - # Wait for the previous task to finish, or for delay seconds - if previous_failed is not None: - with contextlib.suppress(exceptions_mod.TimeoutError): - # Use asyncio.wait_for() instead of asyncio.wait() here, so - # that if we get cancelled at this point, Event.wait() is also - # cancelled, otherwise there will be a "Task destroyed but it is - # pending" later. - await tasks.wait_for(previous_failed.wait(), delay) - # Get the next coroutine to run - try: - this_index, coro_fn = next(enum_coro_fns) - except StopIteration: - return - # Start task that will run the next coroutine - this_failed = locks.Event() - next_task = loop.create_task(run_one_coro(this_failed)) - running_tasks.append(next_task) - assert len(running_tasks) == this_index + 2 - # Prepare place to put this coroutine's exceptions if not won - exceptions.append(None) - assert len(exceptions) == this_index + 1 + async def run_one_coro(this_index, coro_fn, this_failed): try: result = await coro_fn() except (SystemExit, KeyboardInterrupt): @@ -105,34 +79,17 @@ async def run_one_coro(previous_failed) -> None: assert winner_index is None winner_index = this_index winner_result = result - # Cancel all other tasks. We take care to not cancel the current - # task as well. If we do so, then since there is no `await` after - # here and CancelledError are usually thrown at one, we will - # encounter a curious corner case where the current task will end - # up as done() == True, cancelled() == False, exception() == - # asyncio.CancelledError. This behavior is specified in - # https://bugs.python.org/issue30048 - for i, t in enumerate(running_tasks): - if i != this_index: - t.cancel() - - first_task = loop.create_task(run_one_coro(None)) - running_tasks.append(first_task) + raise _Done + try: - # Wait for a growing list of tasks to all finish: poor man's version of - # curio's TaskGroup or trio's nursery - done_count = 0 - while done_count != len(running_tasks): - done, _ = await tasks.wait(running_tasks) - done_count = len(done) - # If run_one_coro raises an unhandled exception, it's probably a - # programming error, and I want to see it. - if __debug__: - for d in done: - if d.done() and not d.cancelled() and d.exception(): - raise d.exception() - return winner_result, winner_index, exceptions - finally: - # Make sure no tasks are left running if we leave this function - for t in running_tasks: - t.cancel() + async with taskgroups.TaskGroup() as tg: + for this_index, coro_fn in enumerate(coro_fns): + this_failed = locks.Event() + exceptions.append(None) + tg.create_task(run_one_coro(this_index, coro_fn, this_failed)) + with contextlib.suppress(TimeoutError): + await tasks.wait_for(this_failed.wait(), delay) + except* _Done: + pass + + return winner_result, winner_index, exceptions diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py index 0777f39b572486..1579ad1188d725 100644 --- a/Lib/test/test_asyncio/test_eager_task_factory.py +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -213,6 +213,53 @@ async def run(): self.run_coro(run()) + def test_staggered_race_with_eager_tasks(self): + # See https://github.com/python/cpython/issues/124309 + + async def fail(): + await asyncio.sleep(0) + raise ValueError("no good") + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: asyncio.sleep(2, result="sleep2"), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: fail() + ], + delay=0.25 + ) + self.assertEqual(winner, 'sleep1') + self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIsInstance(excs[2], ValueError) + + self.run_coro(run()) + + def test_staggered_race_with_eager_tasks_no_delay(self): + # See https://github.com/python/cpython/issues/124309 + async def fail(): + raise ValueError("no good") + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: fail(), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: asyncio.sleep(0, result="sleep0"), + ], + delay=None + ) + self.assertEqual(winner, 'sleep1') + self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], ValueError) + self.assertEqual(len(excs), 2) + + self.run_coro(run()) + + class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): Task = tasks._PyTask diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index e6e32f7dbbbcba..21a39b3f911747 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -82,16 +82,45 @@ async def test_none_successful(self): async def coro(index): raise ValueError(index) + for delay in [None, 0, 0.1, 1]: + with self.subTest(delay=delay): + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=delay, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIsInstance(excs[1], ValueError) + + async def test_long_delay_early_failure(self): + async def coro(index): + await asyncio.sleep(0) # Dummy coroutine for the 1 case + if index == 0: + await asyncio.sleep(0.1) # Dummy coroutine + raise ValueError(index) + + return f'Res: {index}' + winner, index, excs = await staggered_race( [ lambda: coro(0), lambda: coro(1), ], - delay=None, + delay=10, ) - self.assertIs(winner, None) - self.assertIs(index, None) + self.assertEqual(winner, 'Res: 1') + self.assertEqual(index, 1) self.assertEqual(len(excs), 2) self.assertIsInstance(excs[0], ValueError) - self.assertIsInstance(excs[1], ValueError) + self.assertIsNone(excs[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst new file mode 100644 index 00000000000000..89610fa44bf743 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst @@ -0,0 +1 @@ +Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`. From 1d5c99d1c2fa74845fd965e3d0bd1c8c8ba1adc6 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sun, 29 Sep 2024 08:42:46 +0530 Subject: [PATCH 2/3] GH-124639: add back loop param to staggered_race (#124700) --- Lib/asyncio/staggered.py | 10 ++++++++-- Lib/test/test_asyncio/test_staggered.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index 4458d01dece0e6..6ccf5c3c269ff0 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -11,7 +11,7 @@ class _Done(Exception): pass -async def staggered_race(coro_fns, delay): +async def staggered_race(coro_fns, delay, *, loop=None): """Run coroutines with staggered start times and take the first to finish. This method takes an iterable of coroutine functions. The first one is @@ -82,7 +82,13 @@ async def run_one_coro(this_index, coro_fn, this_failed): raise _Done try: - async with taskgroups.TaskGroup() as tg: + tg = taskgroups.TaskGroup() + # Intentionally override the loop in the TaskGroup to avoid + # using the running loop, preserving backwards compatibility + # TaskGroup only starts using `_loop` after `__aenter__` + # so overriding it here is safe. + tg._loop = loop + async with tg: for this_index, coro_fn in enumerate(coro_fns): this_failed = locks.Event() exceptions.append(None) diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py index 21a39b3f911747..8cd98394aea8f8 100644 --- a/Lib/test/test_asyncio/test_staggered.py +++ b/Lib/test/test_asyncio/test_staggered.py @@ -121,6 +121,25 @@ async def coro(index): self.assertIsInstance(excs[0], ValueError) self.assertIsNone(excs[1]) + def test_loop_argument(self): + loop = asyncio.new_event_loop() + async def coro(): + self.assertEqual(loop, asyncio.get_running_loop()) + return 'coro' + + async def main(): + winner, index, excs = await staggered_race( + [coro], + delay=0.1, + loop=loop + ) + + self.assertEqual(winner, 'coro') + self.assertEqual(index, 0) + + loop.run_until_complete(main()) + loop.close() + if __name__ == "__main__": unittest.main() From e2e34daf8d5600fc4e0c2e1e739161eddee0bd80 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Sun, 29 Sep 2024 09:25:14 +0530 Subject: [PATCH 3/3] code review --- Lib/asyncio/base_events.py | 2 +- Lib/asyncio/staggered.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 47da3e8186e783..e4a39f4d345c79 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -1144,7 +1144,7 @@ async def create_connection( (functools.partial(self._connect_sock, exceptions, addrinfo, laddr_infos) for addrinfo in infos), - happy_eyeballs_delay) + happy_eyeballs_delay, loop=self) if sock is None: exceptions = [exc for sub in exceptions for exc in sub] diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index 6ccf5c3c269ff0..889be2cb628483 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -43,6 +43,9 @@ async def staggered_race(coro_fns, delay, *, loop=None): delay: amount of time, in seconds, between starting coroutines. If ``None``, the coroutines will run sequentially. + loop: the event loop to use. + + Returns: tuple *(winner_result, winner_index, exceptions)* where