Skip to content

Commit 5216178

Browse files
GH-111693: Propagate correct asyncio.CancelledError instance out of asyncio.Condition.wait() (#111694)
Also fix a race condition in `asyncio.Semaphore.acquire()` when cancelled.
1 parent c6ca562 commit 5216178

File tree

4 files changed

+153
-25
lines changed

4 files changed

+153
-25
lines changed

Lib/asyncio/futures.py

-3
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,6 @@ def _make_cancelled_error(self):
138138
exc = exceptions.CancelledError()
139139
else:
140140
exc = exceptions.CancelledError(self._cancel_message)
141-
exc.__context__ = self._cancelled_exc
142-
# Remove the reference since we don't need this anymore.
143-
self._cancelled_exc = None
144141
return exc
145142

146143
def cancel(self, msg=None):

Lib/asyncio/locks.py

+39-22
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ async def acquire(self):
9595
This method blocks until the lock is unlocked, then sets it to
9696
locked and returns True.
9797
"""
98+
# Implement fair scheduling, where each task always waits
99+
# its turn. Jumping the queue if all are cancelled is an optimization.
98100
if (not self._locked and (self._waiters is None or
99101
all(w.cancelled() for w in self._waiters))):
100102
self._locked = True
@@ -105,19 +107,22 @@ async def acquire(self):
105107
fut = self._get_loop().create_future()
106108
self._waiters.append(fut)
107109

108-
# Finally block should be called before the CancelledError
109-
# handling as we don't want CancelledError to call
110-
# _wake_up_first() and attempt to wake up itself.
111110
try:
112111
try:
113112
await fut
114113
finally:
115114
self._waiters.remove(fut)
116115
except exceptions.CancelledError:
116+
# Currently the only exception designed to be able to occur here.
117+
118+
# Ensure the lock invariant: If lock is not claimed (or about
119+
# to be claimed by us) and there is a Task in waiters,
120+
# ensure that the Task at the head will run.
117121
if not self._locked:
118122
self._wake_up_first()
119123
raise
120124

125+
# assert self._locked is False
121126
self._locked = True
122127
return True
123128

@@ -139,17 +144,15 @@ def release(self):
139144
raise RuntimeError('Lock is not acquired.')
140145

141146
def _wake_up_first(self):
142-
"""Wake up the first waiter if it isn't done."""
147+
"""Ensure that the first waiter will wake up."""
143148
if not self._waiters:
144149
return
145150
try:
146151
fut = next(iter(self._waiters))
147152
except StopIteration:
148153
return
149154

150-
# .done() necessarily means that a waiter will wake up later on and
151-
# either take the lock, or, if it was cancelled and lock wasn't
152-
# taken already, will hit this again and wake up a new waiter.
155+
# .done() means that the waiter is already set to wake up.
153156
if not fut.done():
154157
fut.set_result(True)
155158

@@ -269,17 +272,22 @@ async def wait(self):
269272
self._waiters.remove(fut)
270273

271274
finally:
272-
# Must reacquire lock even if wait is cancelled
273-
cancelled = False
275+
# Must re-acquire lock even if wait is cancelled.
276+
# We only catch CancelledError here, since we don't want any
277+
# other (fatal) errors with the future to cause us to spin.
278+
err = None
274279
while True:
275280
try:
276281
await self.acquire()
277282
break
278-
except exceptions.CancelledError:
279-
cancelled = True
283+
except exceptions.CancelledError as e:
284+
err = e
280285

281-
if cancelled:
282-
raise exceptions.CancelledError
286+
if err:
287+
try:
288+
raise err # Re-raise most recent exception instance.
289+
finally:
290+
err = None # Break reference cycles.
283291

284292
async def wait_for(self, predicate):
285293
"""Wait until a predicate becomes true.
@@ -357,6 +365,7 @@ def __repr__(self):
357365

358366
def locked(self):
359367
"""Returns True if semaphore cannot be acquired immediately."""
368+
# Due to state, or FIFO rules (must allow others to run first).
360369
return self._value == 0 or (
361370
any(not w.cancelled() for w in (self._waiters or ())))
362371

@@ -370,6 +379,7 @@ async def acquire(self):
370379
True.
371380
"""
372381
if not self.locked():
382+
# Maintain FIFO, wait for others to start even if _value > 0.
373383
self._value -= 1
374384
return True
375385

@@ -378,22 +388,27 @@ async def acquire(self):
378388
fut = self._get_loop().create_future()
379389
self._waiters.append(fut)
380390

381-
# Finally block should be called before the CancelledError
382-
# handling as we don't want CancelledError to call
383-
# _wake_up_first() and attempt to wake up itself.
384391
try:
385392
try:
386393
await fut
387394
finally:
388395
self._waiters.remove(fut)
389396
except exceptions.CancelledError:
390-
if not fut.cancelled():
397+
# Currently the only exception designed to be able to occur here.
398+
if fut.done() and not fut.cancelled():
399+
# Our Future was successfully set to True via _wake_up_next(),
400+
# but we are not about to successfully acquire(). Therefore we
401+
# must undo the bookkeeping already done and attempt to wake
402+
# up someone else.
391403
self._value += 1
392-
self._wake_up_next()
393404
raise
394405

395-
if self._value > 0:
396-
self._wake_up_next()
406+
finally:
407+
# New waiters may have arrived but had to wait due to FIFO.
408+
# Wake up as many as are allowed.
409+
while self._value > 0:
410+
if not self._wake_up_next():
411+
break # There was no-one to wake up.
397412
return True
398413

399414
def release(self):
@@ -408,13 +423,15 @@ def release(self):
408423
def _wake_up_next(self):
409424
"""Wake up the first waiter that isn't done."""
410425
if not self._waiters:
411-
return
426+
return False
412427

413428
for fut in self._waiters:
414429
if not fut.done():
415430
self._value -= 1
416431
fut.set_result(True)
417-
return
432+
# `fut` is now `done()` and not `cancelled()`.
433+
return True
434+
return False
418435

419436

420437
class BoundedSemaphore(Semaphore):

Lib/test/test_asyncio/test_locks.py

+113
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,63 @@ async def test_timeout_in_block(self):
758758
with self.assertRaises(asyncio.TimeoutError):
759759
await asyncio.wait_for(condition.wait(), timeout=0.5)
760760

761+
async def test_cancelled_error_wakeup(self):
762+
# Test that a cancelled error, received when awaiting wakeup,
763+
# will be re-raised un-modified.
764+
wake = False
765+
raised = None
766+
cond = asyncio.Condition()
767+
768+
async def func():
769+
nonlocal raised
770+
async with cond:
771+
with self.assertRaises(asyncio.CancelledError) as err:
772+
await cond.wait_for(lambda: wake)
773+
raised = err.exception
774+
raise raised
775+
776+
task = asyncio.create_task(func())
777+
await asyncio.sleep(0)
778+
# Task is waiting on the condition, cancel it there.
779+
task.cancel(msg="foo")
780+
with self.assertRaises(asyncio.CancelledError) as err:
781+
await task
782+
self.assertEqual(err.exception.args, ("foo",))
783+
# We should have got the _same_ exception instance as the one
784+
# originally raised.
785+
self.assertIs(err.exception, raised)
786+
787+
async def test_cancelled_error_re_aquire(self):
788+
# Test that a cancelled error, received when re-acquiring lock,
789+
# will be re-raised un-modified.
790+
wake = False
791+
raised = None
792+
cond = asyncio.Condition()
793+
794+
async def func():
795+
nonlocal raised
796+
async with cond:
797+
with self.assertRaises(asyncio.CancelledError) as err:
798+
await cond.wait_for(lambda: wake)
799+
raised = err.exception
800+
raise raised
801+
802+
task = asyncio.create_task(func())
803+
await asyncio.sleep(0)
804+
# Task is waiting on the condition
805+
await cond.acquire()
806+
wake = True
807+
cond.notify()
808+
await asyncio.sleep(0)
809+
# Task is now trying to re-acquire the lock, cancel it there.
810+
task.cancel(msg="foo")
811+
cond.release()
812+
with self.assertRaises(asyncio.CancelledError) as err:
813+
await task
814+
self.assertEqual(err.exception.args, ("foo",))
815+
# We should have got the _same_ exception instance as the one
816+
# originally raised.
817+
self.assertIs(err.exception, raised)
761818

762819
class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
763820

@@ -1044,6 +1101,62 @@ async def c3(result):
10441101
await asyncio.gather(*tasks, return_exceptions=True)
10451102
self.assertEqual([2, 3], result)
10461103

1104+
async def test_acquire_fifo_order_4(self):
1105+
# Test that a successful `acquire()` will wake up multiple Tasks
1106+
# that were waiting in the Semaphore queue due to FIFO rules.
1107+
sem = asyncio.Semaphore(0)
1108+
result = []
1109+
count = 0
1110+
1111+
async def c1(result):
1112+
# First task immediately waits for semaphore. It will be awoken by c2.
1113+
self.assertEqual(sem._value, 0)
1114+
await sem.acquire()
1115+
# We should have woken up all waiting tasks now.
1116+
self.assertEqual(sem._value, 0)
1117+
# Create a fourth task. It should run after c3, not c2.
1118+
nonlocal t4
1119+
t4 = asyncio.create_task(c4(result))
1120+
result.append(1)
1121+
return True
1122+
1123+
async def c2(result):
1124+
# The second task begins by releasing semaphore three times,
1125+
# for c1, c2, and c3.
1126+
sem.release()
1127+
sem.release()
1128+
sem.release()
1129+
self.assertEqual(sem._value, 2)
1130+
# It is locked, because c1 hasn't woken up yet.
1131+
self.assertTrue(sem.locked())
1132+
await sem.acquire()
1133+
result.append(2)
1134+
return True
1135+
1136+
async def c3(result):
1137+
await sem.acquire()
1138+
self.assertTrue(sem.locked())
1139+
result.append(3)
1140+
return True
1141+
1142+
async def c4(result):
1143+
result.append(4)
1144+
return True
1145+
1146+
t1 = asyncio.create_task(c1(result))
1147+
t2 = asyncio.create_task(c2(result))
1148+
t3 = asyncio.create_task(c3(result))
1149+
t4 = None
1150+
1151+
await asyncio.sleep(0)
1152+
# Three tasks are in the queue, the first hasn't woken up yet.
1153+
self.assertEqual(sem._value, 2)
1154+
self.assertEqual(len(sem._waiters), 3)
1155+
await asyncio.sleep(0)
1156+
1157+
tasks = [t1, t2, t3, t4]
1158+
await asyncio.gather(*tasks)
1159+
self.assertEqual([1, 2, 3, 4], result)
10471160

10481161
class BarrierTests(unittest.IsolatedAsyncioTestCase):
10491162

Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:func:`asyncio.Condition.wait()` now re-raises the same :exc:`CancelledError` instance that may have caused it to be interrupted. Fixed race condition in :func:`asyncio.Semaphore.acquire` when interrupted with a :exc:`CancelledError`.

0 commit comments

Comments
 (0)