|
| 1 | +import asyncio |
1 | 2 | import concurrent.futures
|
| 3 | +import contextlib |
2 | 4 | import contextvars
|
3 | 5 | import functools
|
4 | 6 | import gc
|
5 | 7 | import random
|
| 8 | +import threading |
6 | 9 | import time
|
7 | 10 | import unittest
|
8 | 11 | import weakref
|
@@ -369,6 +372,88 @@ def sub(num):
|
369 | 372 | tp.shutdown()
|
370 | 373 | self.assertEqual(results, list(range(10)))
|
371 | 374 |
|
| 375 | + @isolated_context |
| 376 | + def test_context_manager(self): |
| 377 | + cvar = contextvars.ContextVar('cvar', default='initial') |
| 378 | + self.assertEqual(cvar.get(), 'initial') |
| 379 | + with contextvars.copy_context(): |
| 380 | + self.assertEqual(cvar.get(), 'initial') |
| 381 | + cvar.set('updated') |
| 382 | + self.assertEqual(cvar.get(), 'updated') |
| 383 | + self.assertEqual(cvar.get(), 'initial') |
| 384 | + |
| 385 | + def test_context_manager_as_binding(self): |
| 386 | + ctx = contextvars.copy_context() |
| 387 | + with ctx as ctx_as_binding: |
| 388 | + self.assertIs(ctx_as_binding, ctx) |
| 389 | + |
| 390 | + @isolated_context |
| 391 | + def test_context_manager_nested(self): |
| 392 | + cvar = contextvars.ContextVar('cvar', default='default') |
| 393 | + with contextvars.copy_context() as outer_ctx: |
| 394 | + cvar.set('outer') |
| 395 | + with contextvars.copy_context() as inner_ctx: |
| 396 | + self.assertIsNot(outer_ctx, inner_ctx) |
| 397 | + self.assertEqual(cvar.get(), 'outer') |
| 398 | + cvar.set('inner') |
| 399 | + self.assertEqual(outer_ctx[cvar], 'outer') |
| 400 | + self.assertEqual(cvar.get(), 'inner') |
| 401 | + self.assertEqual(cvar.get(), 'outer') |
| 402 | + self.assertEqual(cvar.get(), 'default') |
| 403 | + |
| 404 | + @isolated_context |
| 405 | + def test_context_manager_enter_again_after_exit(self): |
| 406 | + cvar = contextvars.ContextVar('cvar', default='initial') |
| 407 | + self.assertEqual(cvar.get(), 'initial') |
| 408 | + with contextvars.copy_context() as ctx: |
| 409 | + cvar.set('updated') |
| 410 | + self.assertEqual(cvar.get(), 'updated') |
| 411 | + self.assertEqual(cvar.get(), 'initial') |
| 412 | + with ctx: |
| 413 | + self.assertEqual(cvar.get(), 'updated') |
| 414 | + self.assertEqual(cvar.get(), 'initial') |
| 415 | + |
| 416 | + @threading_helper.requires_working_threading() |
| 417 | + def test_context_manager_rejects_exit_from_different_thread(self): |
| 418 | + ctx = contextvars.copy_context() |
| 419 | + thread = threading.Thread(target=ctx.__enter__) |
| 420 | + thread.start() |
| 421 | + thread.join() |
| 422 | + with self.assertRaises(RuntimeError): |
| 423 | + ctx.__exit__(None, None, None) |
| 424 | + |
| 425 | + def test_context_manager_is_not_reentrant(self): |
| 426 | + with self.subTest('context manager then context manager'): |
| 427 | + with contextvars.copy_context() as ctx: |
| 428 | + with self.assertRaises(RuntimeError): |
| 429 | + with ctx: |
| 430 | + pass |
| 431 | + with self.subTest('context manager then run method'): |
| 432 | + with contextvars.copy_context() as ctx: |
| 433 | + with self.assertRaises(RuntimeError): |
| 434 | + ctx.run(lambda: None) |
| 435 | + with self.subTest('run method then context manager'): |
| 436 | + ctx = contextvars.copy_context() |
| 437 | + |
| 438 | + def fn(): |
| 439 | + with self.assertRaises(RuntimeError): |
| 440 | + with ctx: |
| 441 | + pass |
| 442 | + |
| 443 | + ctx.run(fn) |
| 444 | + |
| 445 | + def test_context_manager_rejects_noncurrent_exit(self): |
| 446 | + with contextvars.copy_context() as outer_ctx: |
| 447 | + with contextvars.copy_context() as inner_ctx: |
| 448 | + self.assertIsNot(outer_ctx, inner_ctx) |
| 449 | + with self.assertRaises(RuntimeError): |
| 450 | + outer_ctx.__exit__(None, None, None) |
| 451 | + |
| 452 | + def test_context_manager_rejects_nonentered_exit(self): |
| 453 | + ctx = contextvars.copy_context() |
| 454 | + with self.assertRaises(RuntimeError): |
| 455 | + ctx.__exit__(None, None, None) |
| 456 | + |
372 | 457 |
|
373 | 458 | class GeneratorContextTest(unittest.TestCase):
|
374 | 459 | def test_default_is_none(self):
|
@@ -641,6 +726,145 @@ def makegen_outer():
|
641 | 726 | self.assertEqual(gen_outer.send(ctx_outer), ('inner', 'outer'))
|
642 | 727 | self.assertEqual(cvar.get(), 'updated')
|
643 | 728 |
|
| 729 | + |
| 730 | +class AsyncContextTest(unittest.IsolatedAsyncioTestCase): |
| 731 | + async def test_asyncio_independent_contexts(self): |
| 732 | + """Check that coroutines are run with independent contexts. |
| 733 | +
|
| 734 | + Changes to context variables outside a coroutine should not affect the |
| 735 | + values seen inside the coroutine and vice-versa. (This might be |
| 736 | + implemented by manually setting the context before executing each step |
| 737 | + of (send to) a coroutine, or by ensuring that the coroutine is an |
| 738 | + independent coroutine before executing any steps.) |
| 739 | + """ |
| 740 | + cvar = contextvars.ContextVar('cvar', default='A') |
| 741 | + updated1 = asyncio.Event() |
| 742 | + updated2 = asyncio.Event() |
| 743 | + |
| 744 | + async def task1(): |
| 745 | + self.assertIs(cvar.get(), 'A') |
| 746 | + await asyncio.sleep(0) |
| 747 | + cvar.set('B') |
| 748 | + await asyncio.sleep(0) |
| 749 | + updated1.set() |
| 750 | + await updated2.wait() |
| 751 | + self.assertIs(cvar.get(), 'B') |
| 752 | + |
| 753 | + async def task2(): |
| 754 | + await updated1.wait() |
| 755 | + self.assertIs(cvar.get(), 'A') |
| 756 | + await asyncio.sleep(0) |
| 757 | + cvar.set('C') |
| 758 | + await asyncio.sleep(0) |
| 759 | + updated2.set() |
| 760 | + await asyncio.sleep(0) |
| 761 | + self.assertIs(cvar.get(), 'C') |
| 762 | + |
| 763 | + async with asyncio.TaskGroup() as tg: |
| 764 | + tg.create_task(task1()) |
| 765 | + tg.create_task(task2()) |
| 766 | + |
| 767 | + self.assertIs(cvar.get(), 'A') |
| 768 | + |
| 769 | + async def test_asynccontextmanager_is_dependent_by_default(self): |
| 770 | + """Async generator in asynccontextmanager is dependent by default. |
| 771 | +
|
| 772 | + Context switches during the yield of a generator wrapped with |
| 773 | + contextlib.asynccontextmanager should be visible to the generator by |
| 774 | + default (for backwards compatibility). |
| 775 | + """ |
| 776 | + cvar = contextvars.ContextVar('cvar', default='A') |
| 777 | + |
| 778 | + @contextlib.asynccontextmanager |
| 779 | + async def makecm(): |
| 780 | + await asyncio.sleep(0) |
| 781 | + self.assertEqual(cvar.get(), 'A') |
| 782 | + await asyncio.sleep(0) |
| 783 | + # Everything above runs during __aenter__. |
| 784 | + yield cvar.get() |
| 785 | + # Everything below runs during __aexit__. |
| 786 | + await asyncio.sleep(0) |
| 787 | + self.assertEqual(cvar.get(), 'C') |
| 788 | + await asyncio.sleep(0) |
| 789 | + cvar.set('D') |
| 790 | + await asyncio.sleep(0) |
| 791 | + |
| 792 | + cm = makecm() |
| 793 | + val = await cm.__aenter__() |
| 794 | + self.assertEqual(val, 'A') |
| 795 | + self.assertEqual(cvar.get(), 'A') |
| 796 | + cvar.set('B') |
| 797 | + |
| 798 | + with contextvars.copy_context(): |
| 799 | + cvar.set('C') |
| 800 | + await cm.__aexit__(None, None, None) |
| 801 | + self.assertEqual(cvar.get(), 'D') |
| 802 | + self.assertEqual(cvar.get(), 'B') |
| 803 | + |
| 804 | + async def test_asynccontextmanager_independent(self): |
| 805 | + cvar = contextvars.ContextVar('cvar', default='A') |
| 806 | + |
| 807 | + @contextlib.asynccontextmanager |
| 808 | + async def makecm(): |
| 809 | + # Context.__enter__ called from a generator makes the generator |
| 810 | + # independent while the `with` statement suite runs. |
| 811 | + # (Alternatively we could have set the generator's _context |
| 812 | + # property.) |
| 813 | + with contextvars.copy_context(): |
| 814 | + await asyncio.sleep(0) |
| 815 | + self.assertEqual(cvar.get(), 'A') |
| 816 | + await asyncio.sleep(0) |
| 817 | + # Everything above runs during __aenter__. |
| 818 | + yield cvar.get() |
| 819 | + # Everything below runs during __aexit__. |
| 820 | + await asyncio.sleep(0) |
| 821 | + self.assertEqual(cvar.get(), 'A') |
| 822 | + await asyncio.sleep(0) |
| 823 | + cvar.set('D') |
| 824 | + await asyncio.sleep(0) |
| 825 | + |
| 826 | + cm = makecm() |
| 827 | + val = await cm.__aenter__() |
| 828 | + self.assertEqual(val, 'A') |
| 829 | + self.assertEqual(cvar.get(), 'A') |
| 830 | + cvar.set('B') |
| 831 | + with contextvars.copy_context(): |
| 832 | + cvar.set('C') |
| 833 | + await cm.__aexit__(None, None, None) |
| 834 | + self.assertEqual(cvar.get(), 'C') |
| 835 | + self.assertEqual(cvar.get(), 'B') |
| 836 | + |
| 837 | + async def test_generator_switch_between_independent_dependent(self): |
| 838 | + cvar = contextvars.ContextVar('cvar', default='default') |
| 839 | + with contextvars.copy_context() as ctx1: |
| 840 | + cvar.set('in ctx1') |
| 841 | + with contextvars.copy_context() as ctx2: |
| 842 | + cvar.set('in ctx2') |
| 843 | + with contextvars.copy_context() as ctx3: |
| 844 | + cvar.set('in ctx3') |
| 845 | + |
| 846 | + async def makegen(): |
| 847 | + await asyncio.sleep(0) |
| 848 | + yield cvar.get() |
| 849 | + await asyncio.sleep(0) |
| 850 | + yield cvar.get() |
| 851 | + await asyncio.sleep(0) |
| 852 | + with ctx2: |
| 853 | + yield cvar.get() |
| 854 | + await asyncio.sleep(0) |
| 855 | + yield cvar.get() |
| 856 | + await asyncio.sleep(0) |
| 857 | + yield cvar.get() |
| 858 | + |
| 859 | + gen = makegen() |
| 860 | + self.assertEqual(await anext(gen), 'default') |
| 861 | + with ctx1: |
| 862 | + self.assertEqual(await anext(gen), 'in ctx1') |
| 863 | + self.assertEqual(await anext(gen), 'in ctx2') |
| 864 | + with ctx3: |
| 865 | + self.assertEqual(await anext(gen), 'in ctx2') |
| 866 | + self.assertEqual(await anext(gen), 'in ctx1') |
| 867 | + |
644 | 868 | # HAMT Tests
|
645 | 869 |
|
646 | 870 |
|
|
0 commit comments