Skip to content

Commit 76c699b

Browse files
sleiderryouknowone
authored andcommitted
Update contextlib from CPython 3.12
1 parent c901bc0 commit 76c699b

File tree

2 files changed

+136
-13
lines changed

2 files changed

+136
-13
lines changed

Lib/contextlib.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,17 @@ def __exit__(self, typ, value, traceback):
145145
except StopIteration:
146146
return False
147147
else:
148-
raise RuntimeError("generator didn't stop")
148+
try:
149+
raise RuntimeError("generator didn't stop")
150+
finally:
151+
self.gen.close()
149152
else:
150153
if value is None:
151154
# Need to force instantiation so we can reliably
152155
# tell if we get the same exception back
153156
value = typ()
154157
try:
155-
self.gen.throw(typ, value, traceback)
158+
self.gen.throw(value)
156159
except StopIteration as exc:
157160
# Suppress StopIteration *unless* it's the same exception that
158161
# was passed to throw(). This prevents a StopIteration
@@ -187,7 +190,10 @@ def __exit__(self, typ, value, traceback):
187190
raise
188191
exc.__traceback__ = traceback
189192
return False
190-
raise RuntimeError("generator didn't stop after throw()")
193+
try:
194+
raise RuntimeError("generator didn't stop after throw()")
195+
finally:
196+
self.gen.close()
191197

192198
class _AsyncGeneratorContextManager(
193199
_GeneratorContextManagerBase,
@@ -212,14 +218,17 @@ async def __aexit__(self, typ, value, traceback):
212218
except StopAsyncIteration:
213219
return False
214220
else:
215-
raise RuntimeError("generator didn't stop")
221+
try:
222+
raise RuntimeError("generator didn't stop")
223+
finally:
224+
await self.gen.aclose()
216225
else:
217226
if value is None:
218227
# Need to force instantiation so we can reliably
219228
# tell if we get the same exception back
220229
value = typ()
221230
try:
222-
await self.gen.athrow(typ, value, traceback)
231+
await self.gen.athrow(value)
223232
except StopAsyncIteration as exc:
224233
# Suppress StopIteration *unless* it's the same exception that
225234
# was passed to throw(). This prevents a StopIteration
@@ -254,7 +263,10 @@ async def __aexit__(self, typ, value, traceback):
254263
raise
255264
exc.__traceback__ = traceback
256265
return False
257-
raise RuntimeError("generator didn't stop after athrow()")
266+
try:
267+
raise RuntimeError("generator didn't stop after athrow()")
268+
finally:
269+
await self.gen.aclose()
258270

259271

260272
def contextmanager(func):
@@ -441,7 +453,16 @@ def __exit__(self, exctype, excinst, exctb):
441453
# exactly reproduce the limitations of the CPython interpreter.
442454
#
443455
# See http://bugs.python.org/issue12029 for more details
444-
return exctype is not None and issubclass(exctype, self._exceptions)
456+
if exctype is None:
457+
return
458+
if issubclass(exctype, self._exceptions):
459+
return True
460+
if issubclass(exctype, BaseExceptionGroup):
461+
match, rest = excinst.split(self._exceptions)
462+
if rest is None:
463+
return True
464+
raise rest
465+
return False
445466

446467

447468
class _BaseExitStack:

Lib/test/test_contextlib.py

+108-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from contextlib import * # Tests __all__
1111
from test import support
1212
from test.support import os_helper
13+
from test.support.testcase import ExceptionIsLikeMixin
1314
import weakref
1415

1516

@@ -158,9 +159,45 @@ def whoo():
158159
yield
159160
ctx = whoo()
160161
ctx.__enter__()
161-
self.assertRaises(
162-
RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
163-
)
162+
with self.assertRaises(RuntimeError):
163+
ctx.__exit__(TypeError, TypeError("foo"), None)
164+
if support.check_impl_detail(cpython=True):
165+
# The "gen" attribute is an implementation detail.
166+
self.assertFalse(ctx.gen.gi_suspended)
167+
168+
def test_contextmanager_trap_no_yield(self):
169+
@contextmanager
170+
def whoo():
171+
if False:
172+
yield
173+
ctx = whoo()
174+
with self.assertRaises(RuntimeError):
175+
ctx.__enter__()
176+
177+
def test_contextmanager_trap_second_yield(self):
178+
@contextmanager
179+
def whoo():
180+
yield
181+
yield
182+
ctx = whoo()
183+
ctx.__enter__()
184+
with self.assertRaises(RuntimeError):
185+
ctx.__exit__(None, None, None)
186+
if support.check_impl_detail(cpython=True):
187+
# The "gen" attribute is an implementation detail.
188+
self.assertFalse(ctx.gen.gi_suspended)
189+
190+
def test_contextmanager_non_normalised(self):
191+
@contextmanager
192+
def whoo():
193+
try:
194+
yield
195+
except RuntimeError:
196+
raise SyntaxError
197+
ctx = whoo()
198+
ctx.__enter__()
199+
with self.assertRaises(SyntaxError):
200+
ctx.__exit__(RuntimeError, None, None)
164201

165202
def test_contextmanager_except(self):
166203
state = []
@@ -241,6 +278,23 @@ def test_issue29692():
241278
self.assertEqual(ex.args[0], 'issue29692:Unchained')
242279
self.assertIsNone(ex.__cause__)
243280

281+
def test_contextmanager_wrap_runtimeerror(self):
282+
@contextmanager
283+
def woohoo():
284+
try:
285+
yield
286+
except Exception as exc:
287+
raise RuntimeError(f'caught {exc}') from exc
288+
with self.assertRaises(RuntimeError):
289+
with woohoo():
290+
1 / 0
291+
# If the context manager wrapped StopIteration in a RuntimeError,
292+
# we also unwrap it, because we can't tell whether the wrapping was
293+
# done by the generator machinery or by the generator itself.
294+
with self.assertRaises(StopIteration):
295+
with woohoo():
296+
raise StopIteration
297+
244298
def _create_contextmanager_attribs(self):
245299
def attribs(**kw):
246300
def decorate(func):
@@ -252,6 +306,7 @@ def decorate(func):
252306
@attribs(foo='bar')
253307
def baz(spam):
254308
"""Whee!"""
309+
yield
255310
return baz
256311

257312
def test_contextmanager_attribs(self):
@@ -308,8 +363,11 @@ def woohoo(a, *, b):
308363

309364
def test_recursive(self):
310365
depth = 0
366+
ncols = 0
311367
@contextmanager
312368
def woohoo():
369+
nonlocal ncols
370+
ncols += 1
313371
nonlocal depth
314372
before = depth
315373
depth += 1
@@ -323,6 +381,7 @@ def recursive():
323381
recursive()
324382

325383
recursive()
384+
self.assertEqual(ncols, 10)
326385
self.assertEqual(depth, 0)
327386

328387

@@ -374,12 +433,10 @@ class FileContextTestCase(unittest.TestCase):
374433
def testWithOpen(self):
375434
tfn = tempfile.mktemp()
376435
try:
377-
f = None
378436
with open(tfn, "w", encoding="utf-8") as f:
379437
self.assertFalse(f.closed)
380438
f.write("Booh\n")
381439
self.assertTrue(f.closed)
382-
f = None
383440
with self.assertRaises(ZeroDivisionError):
384441
with open(tfn, "r", encoding="utf-8") as f:
385442
self.assertFalse(f.closed)
@@ -1160,7 +1217,7 @@ class TestRedirectStderr(TestRedirectStream, unittest.TestCase):
11601217
orig_stream = "stderr"
11611218

11621219

1163-
class TestSuppress(unittest.TestCase):
1220+
class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase):
11641221

11651222
@support.requires_docstrings
11661223
def test_instance_docs(self):
@@ -1214,6 +1271,51 @@ def test_cm_is_reentrant(self):
12141271
1/0
12151272
self.assertTrue(outer_continued)
12161273

1274+
# TODO: RUSTPYTHON
1275+
@unittest.expectedFailure
1276+
def test_exception_groups(self):
1277+
eg_ve = lambda: ExceptionGroup(
1278+
"EG with ValueErrors only",
1279+
[ValueError("ve1"), ValueError("ve2"), ValueError("ve3")],
1280+
)
1281+
eg_all = lambda: ExceptionGroup(
1282+
"EG with many types of exceptions",
1283+
[ValueError("ve1"), KeyError("ke1"), ValueError("ve2"), KeyError("ke2")],
1284+
)
1285+
with suppress(ValueError):
1286+
raise eg_ve()
1287+
with suppress(ValueError, KeyError):
1288+
raise eg_all()
1289+
with self.assertRaises(ExceptionGroup) as eg1:
1290+
with suppress(ValueError):
1291+
raise eg_all()
1292+
self.assertExceptionIsLike(
1293+
eg1.exception,
1294+
ExceptionGroup(
1295+
"EG with many types of exceptions",
1296+
[KeyError("ke1"), KeyError("ke2")],
1297+
),
1298+
)
1299+
1300+
# Check handling of BaseExceptionGroup, using GeneratorExit so that
1301+
# we don't accidentally discard a ctrl-c with KeyboardInterrupt.
1302+
with suppress(GeneratorExit):
1303+
raise BaseExceptionGroup("message", [GeneratorExit()])
1304+
# If we raise a BaseException group, we can still suppress parts
1305+
with self.assertRaises(BaseExceptionGroup) as eg1:
1306+
with suppress(KeyError):
1307+
raise BaseExceptionGroup("message", [GeneratorExit("g"), KeyError("k")])
1308+
self.assertExceptionIsLike(
1309+
eg1.exception, BaseExceptionGroup("message", [GeneratorExit("g")]),
1310+
)
1311+
# If we suppress all the leaf BaseExceptions, we get a non-base ExceptionGroup
1312+
with self.assertRaises(ExceptionGroup) as eg1:
1313+
with suppress(GeneratorExit):
1314+
raise BaseExceptionGroup("message", [GeneratorExit("g"), KeyError("k")])
1315+
self.assertExceptionIsLike(
1316+
eg1.exception, ExceptionGroup("message", [KeyError("k")]),
1317+
)
1318+
12171319

12181320
class TestChdir(unittest.TestCase):
12191321
def make_relative_path(self, *parts):

0 commit comments

Comments
 (0)