Skip to content

Update unittest from CPython 3.11 #4560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions Lib/unittest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,30 @@ def testMultiply(self):
'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
'expectedFailure', 'TextTestResult', 'installHandler',
'registerResult', 'removeResult', 'removeHandler',
'addModuleCleanup']
'addModuleCleanup', 'doModuleCleanups', 'enterModuleContext']

# Expose obsolete functions for backwards compatibility
# bpo-5846: Deprecated in Python 3.11, scheduled for removal in Python 3.13.
__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])

__unittest = True

from .result import TestResult
from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip,
skipIf, skipUnless, expectedFailure)
skipIf, skipUnless, expectedFailure, doModuleCleanups,
enterModuleContext)
from .suite import BaseTestSuite, TestSuite
from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames,
findTestCases)
from .loader import TestLoader, defaultTestLoader
from .main import TestProgram, main
from .runner import TextTestRunner, TextTestResult
from .signals import installHandler, registerResult, removeResult, removeHandler
# IsolatedAsyncioTestCase will be imported lazily.
from .loader import makeSuite, getTestCaseNames, findTestCases

# deprecated
_TextTestResult = TextTestResult


# There are no tests here, so don't try to run anything discovered from
# introspecting the symbols (e.g. FunctionTestCase). Instead, all our
# tests come from within unittest.test.
Expand Down
148 changes: 60 additions & 88 deletions Lib/unittest/async_case.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import contextvars
import inspect
import warnings

from .case import TestCase

Expand Down Expand Up @@ -32,8 +34,8 @@ class IsolatedAsyncioTestCase(TestCase):

def __init__(self, methodName='runTest'):
super().__init__(methodName)
self._asyncioTestLoop = None
self._asyncioCallsQueue = None
self._asyncioRunner = None
self._asyncioTestContext = contextvars.copy_context()

async def asyncSetUp(self):
pass
Expand All @@ -56,115 +58,85 @@ def addAsyncCleanup(self, func, /, *args, **kwargs):
# 3. Regular "def func()" that returns awaitable object
self.addCleanup(*(func, *args), **kwargs)

async def enterAsyncContext(self, cm):
"""Enters the supplied asynchronous context manager.

If successful, also adds its __aexit__ method as a cleanup
function and returns the result of the __aenter__ method.
"""
# We look up the special methods on the type to match the with
# statement.
cls = type(cm)
try:
enter = cls.__aenter__
exit = cls.__aexit__
except AttributeError:
raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
f"not support the asynchronous context manager protocol"
) from None
result = await enter(cm)
self.addAsyncCleanup(exit, cm, None, None, None)
return result

def _callSetUp(self):
self.setUp()
# Force loop to be initialized and set as the current loop
# so that setUp functions can use get_event_loop() and get the
# correct loop instance.
self._asyncioRunner.get_loop()
self._asyncioTestContext.run(self.setUp)
self._callAsync(self.asyncSetUp)

def _callTestMethod(self, method):
self._callMaybeAsync(method)
if self._callMaybeAsync(method) is not None:
warnings.warn(f'It is deprecated to return a value that is not None from a '
f'test case ({method})', DeprecationWarning, stacklevel=4)

def _callTearDown(self):
self._callAsync(self.asyncTearDown)
self.tearDown()
self._asyncioTestContext.run(self.tearDown)

def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)

def _callAsync(self, func, /, *args, **kwargs):
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
ret = func(*args, **kwargs)
assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable'
fut = self._asyncioTestLoop.create_future()
self._asyncioCallsQueue.put_nowait((fut, ret))
return self._asyncioTestLoop.run_until_complete(fut)
assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function'
return self._asyncioRunner.run(
func(*args, **kwargs),
context=self._asyncioTestContext
)

def _callMaybeAsync(self, func, /, *args, **kwargs):
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
ret = func(*args, **kwargs)
if inspect.isawaitable(ret):
fut = self._asyncioTestLoop.create_future()
self._asyncioCallsQueue.put_nowait((fut, ret))
return self._asyncioTestLoop.run_until_complete(fut)
assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
if inspect.iscoroutinefunction(func):
return self._asyncioRunner.run(
func(*args, **kwargs),
context=self._asyncioTestContext,
)
else:
return ret

async def _asyncioLoopRunner(self, fut):
self._asyncioCallsQueue = queue = asyncio.Queue()
fut.set_result(None)
while True:
query = await queue.get()
queue.task_done()
if query is None:
return
fut, awaitable = query
try:
ret = await awaitable
if not fut.cancelled():
fut.set_result(ret)
except (SystemExit, KeyboardInterrupt):
raise
except (BaseException, asyncio.CancelledError) as ex:
if not fut.cancelled():
fut.set_exception(ex)

def _setupAsyncioLoop(self):
assert self._asyncioTestLoop is None, 'asyncio test loop already initialized'
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.set_debug(True)
self._asyncioTestLoop = loop
fut = loop.create_future()
self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
loop.run_until_complete(fut)

def _tearDownAsyncioLoop(self):
assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
loop = self._asyncioTestLoop
self._asyncioTestLoop = None
self._asyncioCallsQueue.put_nowait(None)
loop.run_until_complete(self._asyncioCallsQueue.join())
return self._asyncioTestContext.run(func, *args, **kwargs)

try:
# cancel all tasks
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(
asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during test shutdown',
'exception': task.exception(),
'task': task,
})
# shutdown asyncgens
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
# Prevent our executor environment from leaking to future tests.
loop.run_until_complete(loop.shutdown_default_executor())
asyncio.set_event_loop(None)
loop.close()
def _setupAsyncioRunner(self):
assert self._asyncioRunner is None, 'asyncio runner is already initialized'
runner = asyncio.Runner(debug=True)
self._asyncioRunner = runner

def _tearDownAsyncioRunner(self):
runner = self._asyncioRunner
runner.close()

def run(self, result=None):
self._setupAsyncioLoop()
self._setupAsyncioRunner()
try:
return super().run(result)
finally:
self._tearDownAsyncioLoop()
self._tearDownAsyncioRunner()

def debug(self):
self._setupAsyncioLoop()
self._setupAsyncioRunner()
super().debug()
self._tearDownAsyncioLoop()
self._tearDownAsyncioRunner()

def __del__(self):
if self._asyncioTestLoop is not None:
self._tearDownAsyncioLoop()
if self._asyncioRunner is not None:
self._tearDownAsyncioRunner()
Loading