From d5053a3d07998a3114390804d48a161c81f7d4f8 Mon Sep 17 00:00:00 2001
From: Thomas Grainger <tagrain@gmail.com>
Date: Thu, 23 Jan 2025 15:53:53 +0000
Subject: [PATCH] gh-128479: fix asyncio staggered race leaking tasks, and
 logging unhandled exception.append exception (GH-128475) (cherry picked from
 commit ec91e1c2762412f1408b0dfb5d281873b852affe)

Co-authored-by: Thomas Grainger <tagrain@gmail.com>
Co-authored-by: Peter Bierma <zintensitydev@gmail.com>
---
 Lib/asyncio/staggered.py                      | 72 ++++++++++++-------
 Lib/test/test_asyncio/test_staggered.py       | 27 +++++++
 ...-01-04-11-10-04.gh-issue-128479.jvOrF-.rst |  1 +
 3 files changed, 76 insertions(+), 24 deletions(-)
 create mode 100644 Misc/NEWS.d/next/Library/2025-01-04-11-10-04.gh-issue-128479.jvOrF-.rst

diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py
index 0f4df8855a80b9..0afed64fdf9c0f 100644
--- a/Lib/asyncio/staggered.py
+++ b/Lib/asyncio/staggered.py
@@ -66,8 +66,27 @@ async def staggered_race(coro_fns, delay, *, loop=None):
     enum_coro_fns = enumerate(coro_fns)
     winner_result = None
     winner_index = None
+    unhandled_exceptions = []
     exceptions = []
-    running_tasks = []
+    running_tasks = set()
+    on_completed_fut = None
+
+    def task_done(task):
+        running_tasks.discard(task)
+        if (
+            on_completed_fut is not None
+            and not on_completed_fut.done()
+            and not running_tasks
+        ):
+            on_completed_fut.set_result(None)
+
+        if task.cancelled():
+            return
+
+        exc = task.exception()
+        if exc is None:
+            return
+        unhandled_exceptions.append(exc)
 
     async def run_one_coro(ok_to_start, previous_failed) -> None:
         # in eager tasks this waits for the calling task to append this task
@@ -91,11 +110,11 @@ async def run_one_coro(ok_to_start, previous_failed) -> None:
         this_failed = locks.Event()
         next_ok_to_start = locks.Event()
         next_task = loop.create_task(run_one_coro(next_ok_to_start, this_failed))
-        running_tasks.append(next_task)
+        running_tasks.add(next_task)
+        next_task.add_done_callback(task_done)
         # next_task has been appended to running_tasks so next_task is ok to
         # start.
         next_ok_to_start.set()
-        assert len(running_tasks) == this_index + 2
         # Prepare place to put this coroutine's exceptions if not won
         exceptions.append(None)
         assert len(exceptions) == this_index + 1
@@ -120,31 +139,36 @@ async def run_one_coro(ok_to_start, previous_failed) -> None:
             # up as done() == True, cancelled() == False, exception() ==
             # asyncio.CancelledError. This behavior is specified in
             # https://bugs.python.org/issue30048
-            for i, t in enumerate(running_tasks):
-                if i != this_index:
+            current_task = tasks.current_task(loop)
+            for t in running_tasks:
+                if t is not current_task:
                     t.cancel()
 
-    ok_to_start = locks.Event()
-    first_task = loop.create_task(run_one_coro(ok_to_start, None))
-    running_tasks.append(first_task)
-    # first_task has been appended to running_tasks so first_task is ok to start.
-    ok_to_start.set()
+    propagate_cancellation_error = None
     try:
-        # Wait for a growing list of tasks to all finish: poor man's version of
-        # curio's TaskGroup or trio's nursery
-        done_count = 0
-        while done_count != len(running_tasks):
-            done, _ = await tasks.wait(running_tasks)
-            done_count = len(done)
+        ok_to_start = locks.Event()
+        first_task = loop.create_task(run_one_coro(ok_to_start, None))
+        running_tasks.add(first_task)
+        first_task.add_done_callback(task_done)
+        # first_task has been appended to running_tasks so first_task is ok to start.
+        ok_to_start.set()
+        propagate_cancellation_error = None
+        # Make sure no tasks are left running if we leave this function
+        while running_tasks:
+            on_completed_fut = loop.create_future()
+            try:
+                await on_completed_fut
+            except exceptions_mod.CancelledError as ex:
+                propagate_cancellation_error = ex
+                for task in running_tasks:
+                    task.cancel(*ex.args)
+            on_completed_fut = None
+        if __debug__ and unhandled_exceptions:
             # If run_one_coro raises an unhandled exception, it's probably a
             # programming error, and I want to see it.
-            if __debug__:
-                for d in done:
-                    if d.done() and not d.cancelled() and d.exception():
-                        raise d.exception()
+            raise ExceptionGroup("staggered race failed", unhandled_exceptions)
+        if propagate_cancellation_error is not None:
+            raise propagate_cancellation_error
         return winner_result, winner_index, exceptions
     finally:
-        del exceptions
-        # Make sure no tasks are left running if we leave this function
-        for t in running_tasks:
-            t.cancel()
+        del exceptions, propagate_cancellation_error, unhandled_exceptions
diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py
index 74941f704c4890..40455a3804e3dd 100644
--- a/Lib/test/test_asyncio/test_staggered.py
+++ b/Lib/test/test_asyncio/test_staggered.py
@@ -122,3 +122,30 @@ async def do_set():
         self.assertIsNone(excs[0], None)
         self.assertIsInstance(excs[1], asyncio.CancelledError)
         self.assertIsInstance(excs[2], asyncio.CancelledError)
+
+
+    async def test_cancelled(self):
+        log = []
+        with self.assertRaises(TimeoutError):
+            async with asyncio.timeout(None) as cs_outer, asyncio.timeout(None) as cs_inner:
+                async def coro_fn():
+                    cs_inner.reschedule(-1)
+                    await asyncio.sleep(0)
+                    try:
+                        await asyncio.sleep(0)
+                    except asyncio.CancelledError:
+                        log.append("cancelled 1")
+
+                    cs_outer.reschedule(-1)
+                    await asyncio.sleep(0)
+                    try:
+                        await asyncio.sleep(0)
+                    except asyncio.CancelledError:
+                        log.append("cancelled 2")
+                try:
+                    await staggered_race([coro_fn], delay=None)
+                except asyncio.CancelledError:
+                    log.append("cancelled 3")
+                    raise
+
+        self.assertListEqual(log, ["cancelled 1", "cancelled 2", "cancelled 3"])
diff --git a/Misc/NEWS.d/next/Library/2025-01-04-11-10-04.gh-issue-128479.jvOrF-.rst b/Misc/NEWS.d/next/Library/2025-01-04-11-10-04.gh-issue-128479.jvOrF-.rst
new file mode 100644
index 00000000000000..fc3b4d5a5273a6
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2025-01-04-11-10-04.gh-issue-128479.jvOrF-.rst
@@ -0,0 +1 @@
+Fix :func:`!asyncio.staggered.staggered_race` leaking tasks and issuing an unhandled exception.