From d37554a7e20c196330d88fb9a33db71eb6144150 Mon Sep 17 00:00:00 2001 From: Padraic Fanning Date: Sun, 5 Sep 2021 21:56:00 -0400 Subject: [PATCH 1/4] Update unittest to CPython 3.8 --- Lib/unittest/__init__.py | 10 +- Lib/unittest/async_case.py | 160 ++++++ Lib/unittest/case.py | 110 +++- Lib/unittest/loader.py | 4 +- Lib/unittest/mock.py | 1015 +++++++++++++++++++++++++----------- Lib/unittest/runner.py | 4 +- Lib/unittest/suite.py | 60 ++- 7 files changed, 1035 insertions(+), 328 deletions(-) create mode 100644 Lib/unittest/async_case.py diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py index c55d563e0c..ace3a6fb1d 100644 --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -44,11 +44,12 @@ def testMultiply(self): SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. """ -__all__ = ['TestResult', 'TestCase', 'TestSuite', +__all__ = ['TestResult', 'TestCase', 'IsolatedAsyncioTestCase', 'TestSuite', 'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless', 'expectedFailure', 'TextTestResult', 'installHandler', - 'registerResult', 'removeResult', 'removeHandler'] + 'registerResult', 'removeResult', 'removeHandler', + 'addModuleCleanup'] # Expose obsolete functions for backwards compatibility __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) @@ -56,8 +57,9 @@ def testMultiply(self): __unittest = True from .result import TestResult -from .case import (TestCase, FunctionTestCase, SkipTest, skip, skipIf, - skipUnless, expectedFailure) +from .async_case import IsolatedAsyncioTestCase +from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip, + skipIf, skipUnless, expectedFailure) from .suite import BaseTestSuite, TestSuite from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames, findTestCases) diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py new file mode 100644 index 0000000000..520213c372 --- /dev/null +++ b/Lib/unittest/async_case.py @@ -0,0 +1,160 @@ +import asyncio +import inspect + +from .case import TestCase + + + +class IsolatedAsyncioTestCase(TestCase): + # Names intentionally have a long prefix + # to reduce a chance of clashing with user-defined attributes + # from inherited test case + # + # The class doesn't call loop.run_until_complete(self.setUp()) and family + # but uses a different approach: + # 1. create a long-running task that reads self.setUp() + # awaitable from queue along with a future + # 2. await the awaitable object passing in and set the result + # into the future object + # 3. Outer code puts the awaitable and the future object into a queue + # with waiting for the future + # The trick is necessary because every run_until_complete() call + # creates a new task with embedded ContextVar context. + # To share contextvars between setUp(), test and tearDown() we need to execute + # them inside the same task. + + # Note: the test case modifies event loop policy if the policy was not instantiated + # yet. + # asyncio.get_event_loop_policy() creates a default policy on demand but never + # returns None + # I believe this is not an issue in user level tests but python itself for testing + # should reset a policy in every test module + # by calling asyncio.set_event_loop_policy(None) in tearDownModule() + + def __init__(self, methodName='runTest'): + super().__init__(methodName) + self._asyncioTestLoop = None + self._asyncioCallsQueue = None + + async def asyncSetUp(self): + pass + + async def asyncTearDown(self): + pass + + def addAsyncCleanup(self, func, /, *args, **kwargs): + # A trivial trampoline to addCleanup() + # the function exists because it has a different semantics + # and signature: + # addCleanup() accepts regular functions + # but addAsyncCleanup() accepts coroutines + # + # We intentionally don't add inspect.iscoroutinefunction() check + # for func argument because there is no way + # to check for async function reliably: + # 1. It can be "async def func()" iself + # 2. Class can implement "async def __call__()" method + # 3. Regular "def func()" that returns awaitable object + self.addCleanup(*(func, *args), **kwargs) + + def _callSetUp(self): + self.setUp() + self._callAsync(self.asyncSetUp) + + def _callTestMethod(self, method): + self._callMaybeAsync(method) + + def _callTearDown(self): + self._callAsync(self.asyncTearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert self._asyncioTestLoop is not None + ret = func(*args, **kwargs) + assert inspect.isawaitable(ret) + fut = self._asyncioTestLoop.create_future() + self._asyncioCallsQueue.put_nowait((fut, ret)) + return self._asyncioTestLoop.run_until_complete(fut) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + assert self._asyncioTestLoop is not None + ret = func(*args, **kwargs) + if inspect.isawaitable(ret): + fut = self._asyncioTestLoop.create_future() + self._asyncioCallsQueue.put_nowait((fut, ret)) + return self._asyncioTestLoop.run_until_complete(fut) + else: + return ret + + async def _asyncioLoopRunner(self, fut): + self._asyncioCallsQueue = queue = asyncio.Queue() + fut.set_result(None) + while True: + query = await queue.get() + queue.task_done() + if query is None: + return + fut, awaitable = query + try: + ret = await awaitable + if not fut.cancelled(): + fut.set_result(ret) + except (SystemExit, KeyboardInterrupt): + raise + except (BaseException, asyncio.CancelledError) as ex: + if not fut.cancelled(): + fut.set_exception(ex) + + def _setupAsyncioLoop(self): + assert self._asyncioTestLoop is None + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.set_debug(True) + self._asyncioTestLoop = loop + fut = loop.create_future() + self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut)) + loop.run_until_complete(fut) + + def _tearDownAsyncioLoop(self): + assert self._asyncioTestLoop is not None + loop = self._asyncioTestLoop + self._asyncioTestLoop = None + self._asyncioCallsQueue.put_nowait(None) + loop.run_until_complete(self._asyncioCallsQueue.join()) + + try: + # cancel all tasks + to_cancel = asyncio.all_tasks(loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete( + asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler({ + 'message': 'unhandled exception during test shutdown', + 'exception': task.exception(), + 'task': task, + }) + # shutdown asyncgens + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() + + def run(self, result=None): + self._setupAsyncioLoop() + try: + return super().run(result) + finally: + self._tearDownAsyncioLoop() diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index c0170d1825..3223c0bff6 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -10,6 +10,7 @@ import collections import contextlib import traceback +import types from . import result from .util import (strclass, safe_repr, _count_diff_all_purpose, @@ -84,6 +85,30 @@ def testPartExecutor(self, test_case, isTest=False): def _id(obj): return obj + +_module_cleanups = [] +def addModuleCleanup(function, /, *args, **kwargs): + """Same as addCleanup, except the cleanup items are called even if + setUpModule fails (unlike tearDownModule).""" + _module_cleanups.append((function, args, kwargs)) + + +def doModuleCleanups(): + """Execute all module cleanup functions. Normally called for you after + tearDownModule.""" + exceptions = [] + while _module_cleanups: + function, args, kwargs = _module_cleanups.pop() + try: + function(*args, **kwargs) + except Exception as exc: + exceptions.append(exc) + if exceptions: + # Swallows all but first exception. If a multi-exception handler + # gets written we should use that here instead. + raise exceptions[0] + + def skip(reason): """ Unconditionally skip a test. @@ -98,6 +123,10 @@ def skip_wrapper(*args, **kwargs): test_item.__unittest_skip__ = True test_item.__unittest_skip_why__ = reason return test_item + if isinstance(reason, types.FunctionType): + test_item = reason + reason = '' + return decorator(test_item) return decorator def skipIf(condition, reason): @@ -157,16 +186,11 @@ def handle(self, name, args, kwargs): if not _is_subtype(self.expected, self._base_type): raise TypeError('%s() arg 1 must be %s' % (name, self._base_type_str)) - if args and args[0] is None: - warnings.warn("callable is None", - DeprecationWarning, 3) - args = () if not args: self.msg = kwargs.pop('msg', None) if kwargs: - warnings.warn('%r is an invalid keyword argument for ' - 'this function' % next(iter(kwargs)), - DeprecationWarning, 3) + raise TypeError('%r is an invalid keyword argument for ' + 'this function' % (next(iter(kwargs)),)) return self callable_obj, *args = args @@ -227,7 +251,7 @@ class _AssertWarnsContext(_AssertRaisesBaseContext): def __enter__(self): # The __warningregistry__'s need to be in a pristine state for tests # to work properly. - for v in sys.modules.values(): + for v in list(sys.modules.values()): if getattr(v, '__warningregistry__', None): v.__warningregistry__ = {} self.warnings_manager = warnings.catch_warnings(record=True) @@ -395,6 +419,8 @@ class TestCase(object): _classSetupFailed = False + _class_cleanups = [] + def __init__(self, methodName='runTest'): """Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does @@ -442,13 +468,36 @@ def addTypeEqualityFunc(self, typeobj, function): """ self._type_equality_funcs[typeobj] = function - def addCleanup(self, function, *args, **kwargs): + def addCleanup(*args, **kwargs): """Add a function, with arguments, to be called when the test is completed. Functions added are called on a LIFO basis and are called after tearDown on test failure or success. Cleanup items are called even if setUp fails (unlike tearDown).""" + if len(args) >= 2: + self, function, *args = args + elif not args: + raise TypeError("descriptor 'addCleanup' of 'TestCase' object " + "needs an argument") + elif 'function' in kwargs: + function = kwargs.pop('function') + self, *args = args + import warnings + warnings.warn("Passing 'function' as keyword argument is deprecated", + DeprecationWarning, stacklevel=2) + else: + raise TypeError('addCleanup expected at least 1 positional ' + 'argument, got %d' % (len(args)-1)) + args = tuple(args) + self._cleanups.append((function, args, kwargs)) + addCleanup.__text_signature__ = '($self, function, /, *args, **kwargs)' + + @classmethod + def addClassCleanup(cls, function, /, *args, **kwargs): + """Same as addCleanup, except the cleanup items are called even if + setUpClass fails (unlike tearDownClass).""" + cls._class_cleanups.append((function, args, kwargs)) def setUp(self): "Hook method for setting up the test fixture before exercising it." @@ -480,7 +529,7 @@ def shortDescription(self): the specified test method's docstring. """ doc = self._testMethodDoc - return doc and doc.split("\n")[0].strip() or None + return doc.strip().split("\n")[0].strip() if doc else None def id(self): @@ -519,7 +568,7 @@ def subTest(self, msg=_subtest_msg_sentinel, **params): case as failed but resumes execution at the end of the enclosed block, allowing further test code to be executed. """ - if not self._outcome.result_supports_subtests: + if self._outcome is None or not self._outcome.result_supports_subtests: yield return parent = self._subtest @@ -577,6 +626,18 @@ def _addUnexpectedSuccess(self, result): else: addUnexpectedSuccess(self) + def _callSetUp(self): + self.setUp() + + def _callTestMethod(self, method): + method() + + def _callTearDown(self): + self.tearDown() + + def _callCleanup(self, function, /, *args, **kwargs): + function(*args, **kwargs) + def run(self, result=None): orig_result = result if result is None: @@ -608,14 +669,14 @@ def run(self, result=None): self._outcome = outcome with outcome.testPartExecutor(self): - self.setUp() + self._callSetUp() if outcome.success: outcome.expecting_failure = expecting_failure with outcome.testPartExecutor(self, isTest=True): - testMethod() + self._callTestMethod(testMethod) outcome.expecting_failure = False with outcome.testPartExecutor(self): - self.tearDown() + self._callTearDown() self.doCleanups() for test, reason in outcome.skipped: @@ -653,12 +714,24 @@ def doCleanups(self): while self._cleanups: function, args, kwargs = self._cleanups.pop() with outcome.testPartExecutor(self): - function(*args, **kwargs) + self._callCleanup(function, *args, **kwargs) # return this for backwards compatibility - # even though we no longer us it internally + # even though we no longer use it internally return outcome.success + @classmethod + def doClassCleanups(cls): + """Execute all class cleanup functions. Normally called for you after + tearDownClass.""" + cls.tearDown_exceptions = [] + while cls._class_cleanups: + function, args, kwargs = cls._class_cleanups.pop() + try: + function(*args, **kwargs) + except Exception as exc: + cls.tearDown_exceptions.append(sys.exc_info()) + def __call__(self, *args, **kwds): return self.run(*args, **kwds) @@ -1167,9 +1240,8 @@ def assertDictContainsSubset(self, subset, dictionary, msg=None): def assertCountEqual(self, first, second, msg=None): - """An unordered sequence comparison asserting that the same elements, - regardless of order. If the same element occurs more than once, - it verifies that the elements occur the same number of times. + """Asserts that two iterables have the same elements, the same number of + times, without regard to order. self.assertEqual(Counter(list(first)), Counter(list(second))) diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index d936a96e73..ba7105e1ad 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -229,7 +229,9 @@ def shouldIncludeMethod(attrname): testFunc = getattr(testCaseClass, attrname) if not callable(testFunc): return False - fullName = '%s.%s' % (testCaseClass.__module__, testFunc.__qualname__) + fullName = f'%s.%s.%s' % ( + testCaseClass.__module__, testCaseClass.__qualname__, attrname + ) return self.testNamePatterns is None or \ any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns) testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass))) diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index 382696d6c7..3629cf6109 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -13,6 +13,7 @@ 'ANY', 'call', 'create_autospec', + 'AsyncMock', 'FILTER_DIR', 'NonCallableMock', 'NonCallableMagicMock', @@ -24,30 +25,41 @@ __version__ = '1.0' - +import asyncio +import contextlib +import io import inspect import pprint import sys import builtins -from types import ModuleType +from types import CodeType, ModuleType, MethodType +from unittest.util import safe_repr from functools import wraps, partial _builtins = {name for name in dir(builtins) if not name.startswith('_')} -BaseExceptions = (BaseException,) -if 'java' in sys.platform: - # jython - import java - BaseExceptions = (BaseException, java.lang.Throwable) - - FILTER_DIR = True # Workaround for issue #12370 # Without this, the __class__ properties wouldn't be set correctly _safe_super = super +def _is_async_obj(obj): + if _is_instance_mock(obj) and not isinstance(obj, AsyncMock): + return False + if hasattr(obj, '__func__'): + obj = getattr(obj, '__func__') + return asyncio.iscoroutinefunction(obj) or inspect.isawaitable(obj) + + +def _is_async_func(func): + if getattr(func, '__code__', None): + return asyncio.iscoroutinefunction(func) + else: + return False + + def _is_instance_mock(obj): # can't use isinstance on Mock objects because they override __class__ # The base class for all mocks is NonCallableMock @@ -56,11 +68,20 @@ def _is_instance_mock(obj): def _is_exception(obj): return ( - isinstance(obj, BaseExceptions) or - isinstance(obj, type) and issubclass(obj, BaseExceptions) + isinstance(obj, BaseException) or + isinstance(obj, type) and issubclass(obj, BaseException) ) +def _extract_mock(obj): + # Autospecced functions will return a FunctionType with "mock" attribute + # which is the actual mock object that needs to be used. + if isinstance(obj, FunctionTypes) and hasattr(obj, 'mock'): + return obj.mock + else: + return obj + + def _get_signature_object(func, as_instance, eat_self): """ Given an arbitrary, possibly callable object, try to create a suitable @@ -69,10 +90,7 @@ def _get_signature_object(func, as_instance, eat_self): """ if isinstance(func, type) and not as_instance: # If it's a type and should be modelled as a type, use __init__. - try: - func = func.__init__ - except AttributeError: - return None + func = func.__init__ # Skip the `self` argument in __init__ eat_self = True elif not isinstance(func, FunctionTypes): @@ -98,10 +116,11 @@ def _check_signature(func, mock, skipfirst, instance=False): if sig is None: return func, sig = sig - def checksig(_mock_self, *args, **kwargs): + def checksig(self, /, *args, **kwargs): sig.bind(*args, **kwargs) _copy_func_details(func, checksig) type(mock)._mock_check_sig = checksig + type(mock).__signature__ = sig def _copy_func_details(func, funcopy): @@ -120,6 +139,8 @@ def _copy_func_details(func, funcopy): def _callable(obj): if isinstance(obj, type): return True + if isinstance(obj, (staticmethod, classmethod, MethodType)): + return _callable(obj.__func__) if getattr(obj, '__call__', None) is not None: return True return False @@ -150,8 +171,6 @@ def _set_signature(mock, original, instance=False): # creates a function with signature (*args, **kwargs) that delegates to a # mock. It still does signature checking by calling a lambda with the same # signature as the original. - if not _callable(original): - return skipfirst = isinstance(original, type) result = _get_signature_object(original, instance, skipfirst) @@ -171,17 +190,13 @@ def checksig(*args, **kwargs): return mock(*args, **kwargs)""" % name exec (src, context) funcopy = context[name] - _setup_func(funcopy, mock) + _setup_func(funcopy, mock, sig) return funcopy -def _setup_func(funcopy, mock): +def _setup_func(funcopy, mock, sig): funcopy.mock = mock - # can't use isinstance with mocks - if not _is_instance_mock(mock): - return - def assert_called_with(*args, **kwargs): return mock.assert_called_with(*args, **kwargs) def assert_called(*args, **kwargs): @@ -223,10 +238,38 @@ def reset_mock(): funcopy.assert_called = assert_called funcopy.assert_not_called = assert_not_called funcopy.assert_called_once = assert_called_once + funcopy.__signature__ = sig mock._mock_delegate = funcopy +def _setup_async_mock(mock): + mock._is_coroutine = asyncio.coroutines._is_coroutine + mock.await_count = 0 + mock.await_args = None + mock.await_args_list = _CallList() + + # Mock is not configured yet so the attributes are set + # to a function and then the corresponding mock helper function + # is called when the helper is accessed similar to _setup_func. + def wrapper(attr, /, *args, **kwargs): + return getattr(mock.mock, attr)(*args, **kwargs) + + for attribute in ('assert_awaited', + 'assert_awaited_once', + 'assert_awaited_with', + 'assert_awaited_once_with', + 'assert_any_await', + 'assert_has_awaits', + 'assert_not_awaited'): + + # setattr(mock, attribute, wrapper) causes late binding + # hence attribute will always be the last value in the loop + # Use partial(wrapper, attribute) to ensure the attribute is bound + # correctly. + setattr(mock, attribute, partial(wrapper, attribute)) + + def _is_magic(name): return '__%s__' % name[2:-2] == name @@ -265,12 +308,6 @@ def __reduce__(self): _deleted = sentinel.DELETED -def _copy(value): - if type(value) in (dict, list, tuple, set): - return type(value)(value) - return value - - _allowed_names = { 'return_value', '_mock_return_value', 'side_effect', '_mock_side_effect', '_mock_parent', '_mock_new_parent', @@ -318,6 +355,8 @@ def __repr__(self): def _check_and_set_parent(parent, value, name, new_name): + value = _extract_mock(value) + if not _is_instance_mock(value): return False if ((value._mock_name or value._mock_new_name) or @@ -345,15 +384,13 @@ def _check_and_set_parent(parent, value, name, new_name): class _MockIter(object): def __init__(self, obj): self.obj = iter(obj) - def __iter__(self): - return self def __next__(self): return next(self.obj) class Base(object): _mock_return_value = DEFAULT _mock_side_effect = None - def __init__(self, *args, **kwargs): + def __init__(self, /, *args, **kwargs): pass @@ -361,12 +398,25 @@ def __init__(self, *args, **kwargs): class NonCallableMock(Base): """A non-callable version of `Mock`""" - def __new__(cls, *args, **kw): + def __new__(cls, /, *args, **kw): # every instance has its own class # so we can create magic methods on the # class without stomping on other mocks - new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__}) - instance = object.__new__(new) + bases = (cls,) + if not issubclass(cls, AsyncMock): + # Check if spec is an async object or function + sig = inspect.signature(NonCallableMock.__init__) + bound_args = sig.bind_partial(cls, *args, **kw).arguments + spec_arg = [ + arg for arg in bound_args.keys() + if arg.startswith('spec') + ] + if spec_arg: + # what if spec_set is different than spec? + if _is_async_obj(bound_args[spec_arg[0]]): + bases = (AsyncMockMixin, cls,) + new = type(cls.__name__, bases, {'__doc__': cls.__doc__}) + instance = _safe_super(NonCallableMock, cls).__new__(new) return instance @@ -420,10 +470,12 @@ def attach_mock(self, mock, attribute): Attach a mock as an attribute of this one, replacing its name and parent. Calls to the attached mock will be recorded in the `method_calls` and `mock_calls` attributes of this one.""" - mock._mock_parent = None - mock._mock_new_parent = None - mock._mock_name = '' - mock._mock_new_name = None + inner_mock = _extract_mock(mock) + + inner_mock._mock_parent = None + inner_mock._mock_new_parent = None + inner_mock._mock_name = '' + inner_mock._mock_new_name = None setattr(self, attribute, mock) @@ -441,12 +493,17 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False, _eat_self=False): _spec_class = None _spec_signature = None + _spec_asyncs = [] + + for attr in dir(spec): + if asyncio.iscoroutinefunction(getattr(spec, attr, None)): + _spec_asyncs.append(attr) if spec is not None and not _is_list(spec): if isinstance(spec, type): _spec_class = spec else: - _spec_class = _get_class(spec) + _spec_class = type(spec) res = _get_signature_object(spec, _spec_as_instance, _eat_self) _spec_signature = res and res[1] @@ -458,7 +515,7 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False, __dict__['_spec_set'] = spec_set __dict__['_spec_signature'] = _spec_signature __dict__['_mock_methods'] = spec - + __dict__['_spec_asyncs'] = _spec_asyncs def __get_return_value(self): ret = self._mock_return_value @@ -541,7 +598,7 @@ def reset_mock(self, visited=None,*, return_value=False, side_effect=False): self._mock_side_effect = None for child in self._mock_children.values(): - if isinstance(child, _SpecState): + if isinstance(child, _SpecState) or child is _deleted: continue child.reset_mock(visited) @@ -550,7 +607,7 @@ def reset_mock(self, visited=None,*, return_value=False, side_effect=False): ret.reset_mock(visited) - def configure_mock(self, **kwargs): + def configure_mock(self, /, **kwargs): """Set attributes on the mock through keyword arguments. Attributes plus return values and side effects can be set on child @@ -582,7 +639,8 @@ def __getattr__(self, name): raise AttributeError(name) if not self._mock_unsafe: if name.startswith(('assert', 'assret')): - raise AttributeError(name) + raise AttributeError("Attributes cannot start with 'assert' " + "or 'assret'") result = self._mock_children.get(name) if result is _deleted: @@ -618,7 +676,7 @@ def _extract_mock_name(self): dot = '.' if _name_list == ['()']: dot = '' - seen = set() + while _parent is not None: last = _parent @@ -629,11 +687,6 @@ def _extract_mock_name(self): _parent = _parent._mock_new_parent - # use ids here so as not to call __hash__ on the mocks - if id(_parent) in seen: - break - seen.add(id(_parent)) - _name_list = list(reversed(_name_list)) _first = last._mock_name or 'mock' if len(_name_list) > 1: @@ -671,12 +724,14 @@ def __dir__(self): extras = self._mock_methods or [] from_type = dir(type(self)) from_dict = list(self.__dict__) + from_child_mocks = [ + m_name for m_name, m_value in self._mock_children.items() + if m_value is not _deleted] from_type = [e for e in from_type if not e.startswith('_')] from_dict = [e for e in from_dict if not e.startswith('_') or _is_magic(e)] - return sorted(set(extras + from_type + from_dict + - list(self._mock_children))) + return sorted(set(extras + from_type + from_dict + from_child_mocks)) def __setattr__(self, name, value): @@ -726,11 +781,10 @@ def __delattr__(self, name): # not set on the instance itself return - if name in self.__dict__: - object.__delattr__(self, name) - obj = self._mock_children.get(name, _missing) - if obj is _deleted: + if name in self.__dict__: + _safe_super(NonCallableMock, self).__delattr__(name) + elif obj is _deleted: raise AttributeError(name) if obj is not _missing: del self._mock_children[name] @@ -742,14 +796,45 @@ def _format_mock_call_signature(self, args, kwargs): return _format_call_signature(name, args, kwargs) - def _format_mock_failure_message(self, args, kwargs): - message = 'Expected call: %s\nActual call: %s' + def _format_mock_failure_message(self, args, kwargs, action='call'): + message = 'expected %s not found.\nExpected: %s\nActual: %s' expected_string = self._format_mock_call_signature(args, kwargs) call_args = self.call_args - if len(call_args) == 3: - call_args = call_args[1:] actual_string = self._format_mock_call_signature(*call_args) - return message % (expected_string, actual_string) + return message % (action, expected_string, actual_string) + + + def _get_call_signature_from_name(self, name): + """ + * If call objects are asserted against a method/function like obj.meth1 + then there could be no name for the call object to lookup. Hence just + return the spec_signature of the method/function being asserted against. + * If the name is not empty then remove () and split by '.' to get + list of names to iterate through the children until a potential + match is found. A child mock is created only during attribute access + so if we get a _SpecState then no attributes of the spec were accessed + and can be safely exited. + """ + if not name: + return self._spec_signature + + sig = None + names = name.replace('()', '').split('.') + children = self._mock_children + + for name in names: + child = children.get(name) + if child is None or isinstance(child, _SpecState): + break + else: + # If an autospecced object is attached using attach_mock the + # child would be a function with mock object as attribute from + # which signature has to be derived. + child = _extract_mock(child) + children = child._mock_children + sig = child._spec_signature + + return sig def _call_matcher(self, _call): @@ -759,7 +844,12 @@ def _call_matcher(self, _call): This is a best effort method which relies on the spec's signature, if available, or falls back on the arguments themselves. """ - sig = self._spec_signature + + if isinstance(_call, tuple) and len(_call) > 2: + sig = self._get_call_signature_from_name(_call[0]) + else: + sig = self._spec_signature + if sig is not None: if len(_call) == 2: name = '' @@ -773,42 +863,45 @@ def _call_matcher(self, _call): else: return _call - def assert_not_called(_mock_self): + def assert_not_called(self): """assert that the mock was never called. """ - self = _mock_self if self.call_count != 0: - msg = ("Expected '%s' to not have been called. Called %s times." % - (self._mock_name or 'mock', self.call_count)) + msg = ("Expected '%s' to not have been called. Called %s times.%s" + % (self._mock_name or 'mock', + self.call_count, + self._calls_repr())) raise AssertionError(msg) - def assert_called(_mock_self): + def assert_called(self): """assert that the mock was called at least once """ - self = _mock_self if self.call_count == 0: msg = ("Expected '%s' to have been called." % - self._mock_name or 'mock') + (self._mock_name or 'mock')) raise AssertionError(msg) - def assert_called_once(_mock_self): + def assert_called_once(self): """assert that the mock was called only once. """ - self = _mock_self if not self.call_count == 1: - msg = ("Expected '%s' to have been called once. Called %s times." % - (self._mock_name or 'mock', self.call_count)) + msg = ("Expected '%s' to have been called once. Called %s times.%s" + % (self._mock_name or 'mock', + self.call_count, + self._calls_repr())) raise AssertionError(msg) - def assert_called_with(_mock_self, *args, **kwargs): - """assert that the mock was called with the specified arguments. + def assert_called_with(self, /, *args, **kwargs): + """assert that the last call was made with the specified arguments. Raises an AssertionError if the args and keyword args passed in are different to the last call to the mock.""" - self = _mock_self if self.call_args is None: expected = self._format_mock_call_signature(args, kwargs) - raise AssertionError('Expected call: %s\nNot called' % (expected,)) + actual = 'not called.' + error_message = ('expected call not found.\nExpected: %s\nActual: %s' + % (expected, actual)) + raise AssertionError(error_message) def _error_message(): msg = self._format_mock_failure_message(args, kwargs) @@ -820,13 +913,14 @@ def _error_message(): raise AssertionError(_error_message()) from cause - def assert_called_once_with(_mock_self, *args, **kwargs): + def assert_called_once_with(self, /, *args, **kwargs): """assert that the mock was called exactly once and that that call was with the specified arguments.""" - self = _mock_self if not self.call_count == 1: - msg = ("Expected '%s' to be called once. Called %s times." % - (self._mock_name or 'mock', self.call_count)) + msg = ("Expected '%s' to be called once. Called %s times.%s" + % (self._mock_name or 'mock', + self.call_count, + self._calls_repr())) raise AssertionError(msg) return self.assert_called_with(*args, **kwargs) @@ -842,13 +936,21 @@ def assert_has_calls(self, calls, any_order=False): If `any_order` is True then the calls can be in any order, but they must all appear in `mock_calls`.""" expected = [self._call_matcher(c) for c in calls] - cause = expected if isinstance(expected, Exception) else None + cause = next((e for e in expected if isinstance(e, Exception)), None) all_calls = _CallList(self._call_matcher(c) for c in self.mock_calls) if not any_order: if expected not in all_calls: + if cause is None: + problem = 'Calls not found.' + else: + problem = ('Error processing expected calls.\n' + 'Errors: {}').format( + [e if isinstance(e, Exception) else None + for e in expected]) raise AssertionError( - 'Calls not found.\nExpected: %r\n' - 'Actual: %r' % (_CallList(calls), self.mock_calls) + f'{problem}\n' + f'Expected: {_CallList(calls)}' + f'{self._calls_repr(prefix="Actual").rstrip(".")}' ) from cause return @@ -862,11 +964,13 @@ def assert_has_calls(self, calls, any_order=False): not_found.append(kall) if not_found: raise AssertionError( - '%r not all found in call list' % (tuple(not_found),) + '%r does not contain all of %r in its call list, ' + 'found %r instead' % (self._mock_name or 'mock', + tuple(not_found), all_calls) ) from cause - def assert_any_call(self, *args, **kwargs): + def assert_any_call(self, /, *args, **kwargs): """assert the mock has been called with the specified arguments. The assert passes if the mock has *ever* been called, unlike @@ -882,7 +986,7 @@ def assert_any_call(self, *args, **kwargs): ) from cause - def _get_child_mock(self, **kw): + def _get_child_mock(self, /, **kw): """Create the child mocks for attributes and return value. By default child mocks will be the same type as the parent. Subclasses of Mock may want to override this to customize the way @@ -890,11 +994,25 @@ def _get_child_mock(self, **kw): For non-callable mocks the callable variant will be used (rather than any custom subclass).""" + _new_name = kw.get("_new_name") + if _new_name in self.__dict__['_spec_asyncs']: + return AsyncMock(**kw) + _type = type(self) - if not issubclass(_type, CallableMixin): + if issubclass(_type, MagicMock) and _new_name in _async_method_magics: + # Any asynchronous magic becomes an AsyncMock + klass = AsyncMock + elif issubclass(_type, AsyncMockMixin): + if (_new_name in _all_sync_magics or + self._mock_methods and _new_name in self._mock_methods): + # Any synchronous method on AsyncMock becomes a MagicMock + klass = MagicMock + else: + klass = AsyncMock + elif not issubclass(_type, CallableMixin): if issubclass(_type, NonCallableMagicMock): klass = MagicMock - elif issubclass(_type, NonCallableMock) : + elif issubclass(_type, NonCallableMock): klass = Mock else: klass = _type.__mro__[1] @@ -907,6 +1025,19 @@ def _get_child_mock(self, **kw): return klass(**kw) + def _calls_repr(self, prefix="Calls"): + """Renders self.mock_calls as a string. + + Example: "\nCalls: [call(1), call(2)]." + + If self.mock_calls is empty, an empty string is returned. The + output will be truncated if very long. + """ + if not self.mock_calls: + return "" + return f"\n{prefix}: {safe_repr(self.mock_calls)}." + + def _try_iter(obj): if obj is None: @@ -923,14 +1054,12 @@ def _try_iter(obj): return obj - class CallableMixin(Base): def __init__(self, spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, parent=None, _spec_state=None, _new_name='', _new_parent=None, **kwargs): self.__dict__['_mock_return_value'] = return_value - _safe_super(CallableMixin, self).__init__( spec, wraps, name, spec_set, parent, _spec_state, _new_name, _new_parent, **kwargs @@ -939,89 +1068,93 @@ def __init__(self, spec=None, side_effect=None, return_value=DEFAULT, self.side_effect = side_effect - def _mock_check_sig(self, *args, **kwargs): + def _mock_check_sig(self, /, *args, **kwargs): # stub method that can be replaced with one with a specific signature pass - def __call__(_mock_self, *args, **kwargs): + def __call__(self, /, *args, **kwargs): # can't use self in-case a function / method we are mocking uses self # in the signature - _mock_self._mock_check_sig(*args, **kwargs) - return _mock_self._mock_call(*args, **kwargs) + self._mock_check_sig(*args, **kwargs) + self._increment_mock_call(*args, **kwargs) + return self._mock_call(*args, **kwargs) + + def _mock_call(self, /, *args, **kwargs): + return self._execute_mock_call(*args, **kwargs) - def _mock_call(_mock_self, *args, **kwargs): - self = _mock_self + def _increment_mock_call(self, /, *args, **kwargs): self.called = True self.call_count += 1 - _new_name = self._mock_new_name - _new_parent = self._mock_new_parent + # handle call_args + # needs to be set here so assertions on call arguments pass before + # execution in the case of awaited calls _call = _Call((args, kwargs), two=True) self.call_args = _call self.call_args_list.append(_call) - self.mock_calls.append(_Call(('', args, kwargs))) - seen = set() - skip_next_dot = _new_name == '()' + # initial stuff for method_calls: do_method_calls = self._mock_parent is not None - name = self._mock_name - while _new_parent is not None: - this_mock_call = _Call((_new_name, args, kwargs)) - if _new_parent._mock_new_name: - dot = '.' - if skip_next_dot: - dot = '' + method_call_name = self._mock_name - skip_next_dot = False - if _new_parent._mock_new_name == '()': - skip_next_dot = True + # initial stuff for mock_calls: + mock_call_name = self._mock_new_name + is_a_call = mock_call_name == '()' + self.mock_calls.append(_Call(('', args, kwargs))) - _new_name = _new_parent._mock_new_name + dot + _new_name + # follow up the chain of mocks: + _new_parent = self._mock_new_parent + while _new_parent is not None: + # handle method_calls: if do_method_calls: - if _new_name == name: - this_method_call = this_mock_call - else: - this_method_call = _Call((name, args, kwargs)) - _new_parent.method_calls.append(this_method_call) - + _new_parent.method_calls.append(_Call((method_call_name, args, kwargs))) do_method_calls = _new_parent._mock_parent is not None if do_method_calls: - name = _new_parent._mock_name + '.' + name + method_call_name = _new_parent._mock_name + '.' + method_call_name + # handle mock_calls: + this_mock_call = _Call((mock_call_name, args, kwargs)) _new_parent.mock_calls.append(this_mock_call) + + if _new_parent._mock_new_name: + if is_a_call: + dot = '' + else: + dot = '.' + is_a_call = _new_parent._mock_new_name == '()' + mock_call_name = _new_parent._mock_new_name + dot + mock_call_name + + # follow the parental chain: _new_parent = _new_parent._mock_new_parent - # use ids here so as not to call __hash__ on the mocks - _new_parent_id = id(_new_parent) - if _new_parent_id in seen: - break - seen.add(_new_parent_id) + def _execute_mock_call(self, /, *args, **kwargs): + # separate from _increment_mock_call so that awaited functions are + # executed separately from their call, also AsyncMock overrides this method - ret_val = DEFAULT effect = self.side_effect if effect is not None: if _is_exception(effect): raise effect - - if not _callable(effect): + elif not _callable(effect): result = next(effect) if _is_exception(result): raise result - if result is DEFAULT: - result = self.return_value + else: + result = effect(*args, **kwargs) + + if result is not DEFAULT: return result - ret_val = effect(*args, **kwargs) + if self._mock_return_value is not DEFAULT: + return self.return_value - if (self._mock_wraps is not None and - self._mock_return_value is DEFAULT): + if self._mock_wraps is not None: return self._mock_wraps(*args, **kwargs) - if ret_val is DEFAULT: - ret_val = self.return_value - return ret_val + + return self.return_value @@ -1077,7 +1210,6 @@ class or instance) that acts as the specification for the mock object. If """ - def _dot_lookup(thing, comp, import_path): try: return getattr(thing, comp) @@ -1097,11 +1229,6 @@ def _importer(target): return thing -def _is_started(patcher): - # XXXX horrible - return hasattr(patcher, 'is_local') - - class _patch(object): attribute_name = None @@ -1150,6 +1277,8 @@ def copy(self): def __call__(self, func): if isinstance(func, type): return self.decorate_class(func) + if inspect.iscoroutinefunction(func): + return self.decorate_async_callable(func) return self.decorate_callable(func) @@ -1167,41 +1296,50 @@ def decorate_class(self, klass): return klass + @contextlib.contextmanager + def decoration_helper(self, patched, args, keywargs): + extra_args = [] + with contextlib.ExitStack() as exit_stack: + for patching in patched.patchings: + arg = exit_stack.enter_context(patching) + if patching.attribute_name is not None: + keywargs.update(arg) + elif patching.new is DEFAULT: + extra_args.append(arg) + + args += tuple(extra_args) + yield (args, keywargs) + + def decorate_callable(self, func): + # NB. Keep the method in sync with decorate_async_callable() if hasattr(func, 'patchings'): func.patchings.append(self) return func @wraps(func) def patched(*args, **keywargs): - extra_args = [] - entered_patchers = [] + with self.decoration_helper(patched, + args, + keywargs) as (newargs, newkeywargs): + return func(*newargs, **newkeywargs) - exc_info = tuple() - try: - for patching in patched.patchings: - arg = patching.__enter__() - entered_patchers.append(patching) - if patching.attribute_name is not None: - keywargs.update(arg) - elif patching.new is DEFAULT: - extra_args.append(arg) - - args += tuple(extra_args) - return func(*args, **keywargs) - except: - if (patching not in entered_patchers and - _is_started(patching)): - # the patcher may have been started, but an exception - # raised whilst entering one of its additional_patchers - entered_patchers.append(patching) - # Pass the exception to __exit__ - exc_info = sys.exc_info() - # re-raise the exception - raise - finally: - for patching in reversed(entered_patchers): - patching.__exit__(*exc_info) + patched.patchings = [self] + return patched + + + def decorate_async_callable(self, func): + # NB. Keep the method in sync with decorate_callable() + if hasattr(func, 'patchings'): + func.patchings.append(self) + return func + + @wraps(func) + async def patched(*args, **keywargs): + with self.decoration_helper(patched, + args, + keywargs) as (newargs, newkeywargs): + return await func(*newargs, **newkeywargs) patched.patchings = [self] return patched @@ -1275,8 +1413,10 @@ def __enter__(self): if isinstance(original, type): # If we're patching out a class and there is a spec inherit = True - - Klass = MagicMock + if spec is None and _is_async_obj(original): + Klass = AsyncMock + else: + Klass = MagicMock _kwargs = {} if new_callable is not None: Klass = new_callable @@ -1288,7 +1428,9 @@ def __enter__(self): not_callable = '__call__' not in this_spec else: not_callable = not callable(this_spec) - if not_callable: + if _is_async_obj(this_spec): + Klass = AsyncMock + elif not_callable: Klass = NonCallableMagicMock if spec is not None: @@ -1343,25 +1485,26 @@ def __enter__(self): self.temp_original = original self.is_local = local - setattr(self.target, self.attribute, new_attr) - if self.attribute_name is not None: - extra_args = {} - if self.new is DEFAULT: - extra_args[self.attribute_name] = new - for patching in self.additional_patchers: - arg = patching.__enter__() - if patching.new is DEFAULT: - extra_args.update(arg) - return extra_args - - return new - + self._exit_stack = contextlib.ExitStack() + try: + setattr(self.target, self.attribute, new_attr) + if self.attribute_name is not None: + extra_args = {} + if self.new is DEFAULT: + extra_args[self.attribute_name] = new + for patching in self.additional_patchers: + arg = self._exit_stack.enter_context(patching) + if patching.new is DEFAULT: + extra_args.update(arg) + return extra_args + + return new + except: + if not self.__exit__(*sys.exc_info()): + raise def __exit__(self, *exc_info): """Undo the patch.""" - if not _is_started(self): - raise RuntimeError('stop called on unstarted patcher') - if self.is_local and self.temp_original is not DEFAULT: setattr(self.target, self.attribute, self.temp_original) else: @@ -1376,9 +1519,9 @@ def __exit__(self, *exc_info): del self.temp_original del self.is_local del self.target - for patcher in reversed(self.additional_patchers): - if _is_started(patcher): - patcher.__exit__(*exc_info) + exit_stack = self._exit_stack + del self._exit_stack + return exit_stack.__exit__(*exc_info) def start(self): @@ -1394,9 +1537,9 @@ def stop(self): self._active_patches.remove(self) except ValueError: # If the patch hasn't been started this will fail - pass + return None - return self.__exit__() + return self.__exit__(None, None, None) @@ -1428,6 +1571,10 @@ def _patch_object( When used as a class decorator `patch.object` honours `patch.TEST_PREFIX` for choosing which methods to wrap. """ + if type(target) is str: + raise TypeError( + f"{target!r} must be the actual object to be patched, not a str" + ) getter = lambda: target return _patch( getter, attribute, new, spec, create, @@ -1494,8 +1641,9 @@ def patch( is patched with a `new` object. When the function/with statement exits the patch is undone. - If `new` is omitted, then the target is replaced with a - `MagicMock`. If `patch` is used as a decorator and `new` is + If `new` is omitted, then the target is replaced with an + `AsyncMock if the patched object is an async function or a + `MagicMock` otherwise. If `patch` is used as a decorator and `new` is omitted, the created mock is passed in as an extra argument to the decorated function. If `patch` is used as a context manager the created mock is returned by the context manager. @@ -1513,8 +1661,8 @@ def patch( patch to pass in the object being mocked as the spec/spec_set object. `new_callable` allows you to specify a different class, or callable object, - that will be called to create the `new` object. By default `MagicMock` is - used. + that will be called to create the `new` object. By default `AsyncMock` is + used for async functions and `MagicMock` for the rest. A more powerful form of `spec` is `autospec`. If you set `autospec=True` then the mock will be created with a spec from the object being replaced. @@ -1590,8 +1738,6 @@ class _patch_dict(object): """ def __init__(self, in_dict, values=(), clear=False, **kwargs): - if isinstance(in_dict, str): - in_dict = _importer(in_dict) self.in_dict = in_dict # support any argument supported by dict(...) constructor self.values = dict(values) @@ -1628,10 +1774,13 @@ def decorate_class(self, klass): def __enter__(self): """Patch the dict.""" self._patch_dict() + return self.in_dict def _patch_dict(self): values = self.values + if isinstance(self.in_dict, str): + self.in_dict = _importer(self.in_dict) in_dict = self.in_dict clear = self.clear @@ -1709,8 +1858,10 @@ def _patch_stopall(): # because there is no idivmod "divmod rdivmod neg pos abs invert " "complex int float index " - "trunc floor ceil " + "round trunc floor ceil " "bool next " + "fspath " + "aiter " ) numerics = ( @@ -1734,7 +1885,7 @@ def _patch_stopall(): def _get_method(name, func): "Turns a callable object (like a mock) into a real function" - def method(self, *args, **kw): + def method(self, /, *args, **kw): return func(self, *args, **kw) method.__name__ = name return method @@ -1745,11 +1896,18 @@ def method(self, *args, **kw): ' '.join([magic_methods, numerics, inplace, right]).split() } -_all_magics = _magics | _non_defaults +# Magic methods used for async `with` statements +_async_method_magics = {"__aenter__", "__aexit__", "__anext__"} +# Magic methods that are only used with async calls but are synchronous functions themselves +_sync_async_magics = {"__aiter__"} +_async_magics = _async_method_magics | _sync_async_magics + +_all_sync_magics = _magics | _non_defaults +_all_magics = _all_sync_magics | _async_magics _unsupported_magics = { '__getattr__', '__setattr__', - '__init__', '__new__', '__prepare__' + '__init__', '__new__', '__prepare__', '__instancecheck__', '__subclasscheck__', '__del__' } @@ -1758,6 +1916,7 @@ def method(self, *args, **kw): '__hash__': lambda self: object.__hash__(self), '__str__': lambda self: object.__str__(self), '__sizeof__': lambda self: object.__sizeof__(self), + '__fspath__': lambda self: f"{type(self).__name__}/{self._extract_mock_name()}/{id(self)}", } _return_values = { @@ -1773,6 +1932,7 @@ def method(self, *args, **kw): '__float__': 1.0, '__bool__': True, '__index__': 1, + '__aexit__': False, } @@ -1805,10 +1965,19 @@ def __iter__(): return iter(ret_val) return __iter__ +def _get_async_iter(self): + def __aiter__(): + ret_val = self.__aiter__._mock_return_value + if ret_val is DEFAULT: + return _AsyncIterator(iter([])) + return _AsyncIterator(iter(ret_val)) + return __aiter__ + _side_effect_methods = { '__eq__': _get_eq, '__ne__': _get_ne, '__iter__': _get_iter, + '__aiter__': _get_async_iter } @@ -1819,14 +1988,9 @@ def _set_return_value(mock, method, name): method.return_value = fixed return - return_calulator = _calculate_return_value.get(name) - if return_calulator is not None: - try: - return_value = return_calulator(mock) - except AttributeError: - # XXXX why do we return AttributeError here? - # set it as a side_effect instead? - return_value = AttributeError(name) + return_calculator = _calculate_return_value.get(name) + if return_calculator is not None: + return_value = return_calculator(mock) method.return_value = return_value return @@ -1836,21 +2000,22 @@ def _set_return_value(mock, method, name): -class MagicMixin(object): - def __init__(self, *args, **kw): +class MagicMixin(Base): + def __init__(self, /, *args, **kw): self._mock_set_magics() # make magic work for kwargs in init _safe_super(MagicMixin, self).__init__(*args, **kw) self._mock_set_magics() # fix magic broken by upper level init def _mock_set_magics(self): - these_magics = _magics + orig_magics = _magics | _async_method_magics + these_magics = orig_magics if getattr(self, "_mock_methods", None) is not None: - these_magics = _magics.intersection(self._mock_methods) + these_magics = orig_magics.intersection(self._mock_methods) remove_magics = set() - remove_magics = _magics - these_magics + remove_magics = orig_magics - these_magics for entry in remove_magics: if entry in type(self).__dict__: @@ -1878,6 +2043,11 @@ def mock_add_spec(self, spec, spec_set=False): self._mock_set_magics() +class AsyncMagicMixin(MagicMixin): + def __init__(self, /, *args, **kw): + self._mock_set_magics() # make magic work for kwargs in init + _safe_super(AsyncMagicMixin, self).__init__(*args, **kw) + self._mock_set_magics() # fix magic broken by upper level init class MagicMock(MagicMixin, Mock): """ @@ -1901,15 +2071,11 @@ def mock_add_spec(self, spec, spec_set=False): -class MagicProxy(object): +class MagicProxy(Base): def __init__(self, name, parent): self.name = name self.parent = parent - def __call__(self, *args, **kwargs): - m = self.create_mock() - return m(*args, **kwargs) - def create_mock(self): entry = self.name parent = self.parent @@ -1923,6 +2089,231 @@ def __get__(self, obj, _type=None): return self.create_mock() +class AsyncMockMixin(Base): + await_count = _delegating_property('await_count') + await_args = _delegating_property('await_args') + await_args_list = _delegating_property('await_args_list') + + def __init__(self, /, *args, **kwargs): + super().__init__(*args, **kwargs) + # asyncio.iscoroutinefunction() checks _is_coroutine property to say if an + # object is a coroutine. Without this check it looks to see if it is a + # function/method, which in this case it is not (since it is an + # AsyncMock). + # It is set through __dict__ because when spec_set is True, this + # attribute is likely undefined. + self.__dict__['_is_coroutine'] = asyncio.coroutines._is_coroutine + self.__dict__['_mock_await_count'] = 0 + self.__dict__['_mock_await_args'] = None + self.__dict__['_mock_await_args_list'] = _CallList() + code_mock = NonCallableMock(spec_set=CodeType) + code_mock.co_flags = inspect.CO_COROUTINE + self.__dict__['__code__'] = code_mock + + async def _execute_mock_call(self, /, *args, **kwargs): + # This is nearly just like super(), except for sepcial handling + # of coroutines + + _call = _Call((args, kwargs), two=True) + self.await_count += 1 + self.await_args = _call + self.await_args_list.append(_call) + + effect = self.side_effect + if effect is not None: + if _is_exception(effect): + raise effect + elif not _callable(effect): + try: + result = next(effect) + except StopIteration: + # It is impossible to propogate a StopIteration + # through coroutines because of PEP 479 + raise StopAsyncIteration + if _is_exception(result): + raise result + elif asyncio.iscoroutinefunction(effect): + result = await effect(*args, **kwargs) + else: + result = effect(*args, **kwargs) + + if result is not DEFAULT: + return result + + if self._mock_return_value is not DEFAULT: + return self.return_value + + if self._mock_wraps is not None: + if asyncio.iscoroutinefunction(self._mock_wraps): + return await self._mock_wraps(*args, **kwargs) + return self._mock_wraps(*args, **kwargs) + + return self.return_value + + def assert_awaited(self): + """ + Assert that the mock was awaited at least once. + """ + if self.await_count == 0: + msg = f"Expected {self._mock_name or 'mock'} to have been awaited." + raise AssertionError(msg) + + def assert_awaited_once(self): + """ + Assert that the mock was awaited exactly once. + """ + if not self.await_count == 1: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + + def assert_awaited_with(self, /, *args, **kwargs): + """ + Assert that the last await was with the specified arguments. + """ + if self.await_args is None: + expected = self._format_mock_call_signature(args, kwargs) + raise AssertionError(f'Expected await: {expected}\nNot awaited') + + def _error_message(): + msg = self._format_mock_failure_message(args, kwargs, action='await') + return msg + + expected = self._call_matcher((args, kwargs)) + actual = self._call_matcher(self.await_args) + if expected != actual: + cause = expected if isinstance(expected, Exception) else None + raise AssertionError(_error_message()) from cause + + def assert_awaited_once_with(self, /, *args, **kwargs): + """ + Assert that the mock was awaited exactly once and with the specified + arguments. + """ + if not self.await_count == 1: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + return self.assert_awaited_with(*args, **kwargs) + + def assert_any_await(self, /, *args, **kwargs): + """ + Assert the mock has ever been awaited with the specified arguments. + """ + expected = self._call_matcher((args, kwargs)) + actual = [self._call_matcher(c) for c in self.await_args_list] + if expected not in actual: + cause = expected if isinstance(expected, Exception) else None + expected_string = self._format_mock_call_signature(args, kwargs) + raise AssertionError( + '%s await not found' % expected_string + ) from cause + + def assert_has_awaits(self, calls, any_order=False): + """ + Assert the mock has been awaited with the specified calls. + The :attr:`await_args_list` list is checked for the awaits. + + If `any_order` is False (the default) then the awaits must be + sequential. There can be extra calls before or after the + specified awaits. + + If `any_order` is True then the awaits can be in any order, but + they must all appear in :attr:`await_args_list`. + """ + expected = [self._call_matcher(c) for c in calls] + cause = next((e for e in expected if isinstance(e, Exception)), None) + all_awaits = _CallList(self._call_matcher(c) for c in self.await_args_list) + if not any_order: + if expected not in all_awaits: + if cause is None: + problem = 'Awaits not found.' + else: + problem = ('Error processing expected awaits.\n' + 'Errors: {}').format( + [e if isinstance(e, Exception) else None + for e in expected]) + raise AssertionError( + f'{problem}\n' + f'Expected: {_CallList(calls)}\n' + f'Actual: {self.await_args_list}' + ) from cause + return + + all_awaits = list(all_awaits) + + not_found = [] + for kall in expected: + try: + all_awaits.remove(kall) + except ValueError: + not_found.append(kall) + if not_found: + raise AssertionError( + '%r not all found in await list' % (tuple(not_found),) + ) from cause + + def assert_not_awaited(self): + """ + Assert that the mock was never awaited. + """ + if self.await_count != 0: + msg = (f"Expected {self._mock_name or 'mock'} to not have been awaited." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + + def reset_mock(self, /, *args, **kwargs): + """ + See :func:`.Mock.reset_mock()` + """ + super().reset_mock(*args, **kwargs) + self.await_count = 0 + self.await_args = None + self.await_args_list = _CallList() + + +class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock): + """ + Enhance :class:`Mock` with features allowing to mock + an async function. + + The :class:`AsyncMock` object will behave so the object is + recognized as an async function, and the result of a call is an awaitable: + + >>> mock = AsyncMock() + >>> asyncio.iscoroutinefunction(mock) + True + >>> inspect.isawaitable(mock()) + True + + + The result of ``mock()`` is an async function which will have the outcome + of ``side_effect`` or ``return_value``: + + - if ``side_effect`` is a function, the async function will return the + result of that function, + - if ``side_effect`` is an exception, the async function will raise the + exception, + - if ``side_effect`` is an iterable, the async function will return the + next value of the iterable, however, if the sequence of result is + exhausted, ``StopIteration`` is raised immediately, + - if ``side_effect`` is not defined, the async function will return the + value defined by ``return_value``, hence, by default, the async function + returns a new :class:`AsyncMock` object. + + If the outcome of ``side_effect`` or ``return_value`` is an async function, + the mock async function obtained when the mock object is called will be this + async function itself (and not an async function returning an async + function). + + The test author can also specify a wrapped object with ``wraps``. In this + case, the :class:`Mock` object behavior is the same as with an + :class:`.Mock` object: the wrapped object may have methods + defined as async function functions. + + Based on Martin Richard's asynctest project. + """ + class _ANY(object): "A helper object that compares equal to everything." @@ -1945,7 +2336,7 @@ def _format_call_signature(name, args, kwargs): formatted_args = '' args_string = ', '.join([repr(arg) for arg in args]) kwargs_string = ', '.join([ - '%s=%r' % (key, value) for key, value in sorted(kwargs.items()) + '%s=%r' % (key, value) for key, value in kwargs.items() ]) if args_string: formatted_args = args_string @@ -2011,9 +2402,9 @@ def __new__(cls, value=(), name='', parent=None, two=False, def __init__(self, value=(), name=None, parent=None, two=False, from_kall=True): - self.name = name - self.parent = parent - self.from_kall = from_kall + self._mock_name = name + self._mock_parent = parent + self._mock_from_kall = from_kall def __eq__(self, other): @@ -2030,6 +2421,10 @@ def __eq__(self, other): else: self_name, self_args, self_kwargs = self + if (getattr(self, '_mock_parent', None) and getattr(other, '_mock_parent', None) + and self._mock_parent != other._mock_parent): + return False + other_name = '' if len_other == 0: other_args, other_kwargs = (), {} @@ -2070,30 +2465,52 @@ def __eq__(self, other): __ne__ = object.__ne__ - def __call__(self, *args, **kwargs): - if self.name is None: + def __call__(self, /, *args, **kwargs): + if self._mock_name is None: return _Call(('', args, kwargs), name='()') - name = self.name + '()' - return _Call((self.name, args, kwargs), name=name, parent=self) + name = self._mock_name + '()' + return _Call((self._mock_name, args, kwargs), name=name, parent=self) def __getattr__(self, attr): - if self.name is None: + if self._mock_name is None: return _Call(name=attr, from_kall=False) - name = '%s.%s' % (self.name, attr) + name = '%s.%s' % (self._mock_name, attr) return _Call(name=name, parent=self, from_kall=False) - def count(self, *args, **kwargs): + def __getattribute__(self, attr): + if attr in tuple.__dict__: + raise AttributeError + return tuple.__getattribute__(self, attr) + + + def count(self, /, *args, **kwargs): return self.__getattr__('count')(*args, **kwargs) - def index(self, *args, **kwargs): + def index(self, /, *args, **kwargs): return self.__getattr__('index')(*args, **kwargs) + def _get_call_arguments(self): + if len(self) == 2: + args, kwargs = self + else: + name, args, kwargs = self + + return args, kwargs + + @property + def args(self): + return self._get_call_arguments()[0] + + @property + def kwargs(self): + return self._get_call_arguments()[1] + def __repr__(self): - if not self.from_kall: - name = self.name or 'call' + if not self._mock_from_kall: + name = self._mock_name or 'call' if name.startswith('()'): name = 'call%s' % name return name @@ -2119,16 +2536,15 @@ def call_list(self): vals = [] thing = self while thing is not None: - if thing.from_kall: + if thing._mock_from_kall: vals.append(thing) - thing = thing.parent + thing = thing._mock_parent return _CallList(reversed(vals)) call = _Call(from_kall=False) - def create_autospec(spec, spec_set=False, instance=False, _parent=None, _name=None, **kwargs): """Create a mock object using another object as a spec. Attributes on the @@ -2154,7 +2570,7 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, spec = type(spec) is_type = isinstance(spec, type) - + is_async_func = _is_async_func(spec) _kwargs = {'spec': spec} if spec_set: _kwargs = {'spec_set': spec} @@ -2171,6 +2587,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, # descriptors don't have a spec # because we don't know what type they return _kwargs = {} + elif is_async_func: + if instance: + raise RuntimeError("Instance can not be True when create_autospec " + "is mocking an async function") + Klass = AsyncMock elif not _callable(spec): Klass = NonCallableMagicMock elif is_type and instance and not _instance_callable(spec): @@ -2190,6 +2611,8 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, # should only happen at the top level because we don't # recurse for functions mock = _set_signature(mock, spec) + if is_async_func: + _setup_async_mock(mock) else: _check_signature(spec, mock, is_type, instance) @@ -2233,9 +2656,13 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, skipfirst = _must_skip(spec, entry, is_type) kwargs['_eat_self'] = skipfirst - new = MagicMock(parent=parent, name=entry, _new_name=entry, - _new_parent=parent, - **kwargs) + if asyncio.iscoroutinefunction(original): + child_klass = AsyncMock + else: + child_klass = MagicMock + new = child_klass(parent=parent, name=entry, _new_name=entry, + _new_parent=parent, + **kwargs) mock._mock_children[entry] = new _check_signature(original, new, skipfirst=skipfirst) @@ -2266,26 +2693,17 @@ def _must_skip(spec, entry, is_type): continue if isinstance(result, (staticmethod, classmethod)): return False - elif isinstance(getattr(result, '__get__', None), MethodWrapperTypes): + elif isinstance(result, FunctionTypes): # Normal method => skip if looked up on type # (if looked up on instance, self is already skipped) return is_type else: return False - # shouldn't get here unless function is a dynamically provided attribute - # XXXX untested behaviour + # function is a dynamically provided attribute return is_type -def _get_class(obj): - try: - return obj.__class__ - except AttributeError: - # it is possible for objects to have no __class__ - return type(obj) - - class _SpecState(object): def __init__(self, spec, spec_set=False, parent=None, @@ -2305,32 +2723,15 @@ def __init__(self, spec, spec_set=False, parent=None, type(ANY.__eq__), ) -MethodWrapperTypes = ( - type(ANY.__eq__.__get__), -) - file_spec = None -def _iterate_read_data(read_data): - # Helper for mock_open: - # Retrieve lines from read_data via a generator so that separate calls to - # readline, read, and readlines are properly interleaved - sep = b'\n' if isinstance(read_data, bytes) else '\n' - data_as_list = [l + sep for l in read_data.split(sep)] - - if data_as_list[-1] == sep: - # If the last line ended in a newline, the list comprehension will have an - # extra entry that's just a newline. Remove this. - data_as_list = data_as_list[:-1] - else: - # If there wasn't an extra newline by itself, then the file being - # emulated doesn't have a newline to end the last line remove the - # newline that our naive format() added - data_as_list[-1] = data_as_list[-1][:-1] - for line in data_as_list: - yield line +def _to_stream(read_data): + if isinstance(read_data, bytes): + return io.BytesIO(read_data) + else: + return io.StringIO(read_data) def mock_open(mock=None, read_data=''): @@ -2342,28 +2743,38 @@ def mock_open(mock=None, read_data=''): default) then a `MagicMock` will be created for you, with the API limited to methods or attributes available on standard file handles. - `read_data` is a string for the `read` methoddline`, and `readlines` of the + `read_data` is a string for the `read`, `readline` and `readlines` of the file handle to return. This is an empty string by default. """ + _read_data = _to_stream(read_data) + _state = [_read_data, None] + def _readlines_side_effect(*args, **kwargs): if handle.readlines.return_value is not None: return handle.readlines.return_value - return list(_state[0]) + return _state[0].readlines(*args, **kwargs) def _read_side_effect(*args, **kwargs): if handle.read.return_value is not None: return handle.read.return_value - return type(read_data)().join(_state[0]) + return _state[0].read(*args, **kwargs) - def _readline_side_effect(): + def _readline_side_effect(*args, **kwargs): + yield from _iter_side_effect() + while True: + yield _state[0].readline(*args, **kwargs) + + def _iter_side_effect(): if handle.readline.return_value is not None: while True: yield handle.readline.return_value for line in _state[0]: yield line - while True: - yield type(read_data)() + def _next_side_effect(): + if handle.readline.return_value is not None: + return handle.readline.return_value + return next(_state[0]) global file_spec if file_spec is None: @@ -2376,8 +2787,6 @@ def _readline_side_effect(): handle = MagicMock(spec=file_spec) handle.__enter__.return_value = handle - _state = [_iterate_read_data(read_data), None] - handle.write.return_value = None handle.read.return_value = None handle.readline.return_value = None @@ -2387,9 +2796,11 @@ def _readline_side_effect(): _state[1] = _readline_side_effect() handle.readline.side_effect = _state[1] handle.readlines.side_effect = _readlines_side_effect + handle.__iter__.side_effect = _iter_side_effect + handle.__next__.side_effect = _next_side_effect def reset_data(*args, **kwargs): - _state[0] = _iterate_read_data(read_data) + _state[0] = _to_stream(read_data) if handle.readline.side_effect == _state[1]: # Only reset the side effect if the user hasn't overridden it. _state[1] = _readline_side_effect() @@ -2410,25 +2821,24 @@ class PropertyMock(Mock): Fetching a `PropertyMock` instance from an object calls the mock, with no args. Setting it calls the mock with the value being set. """ - def _get_child_mock(self, **kwargs): + def _get_child_mock(self, /, **kwargs): return MagicMock(**kwargs) - def __get__(self, obj, obj_type): + def __get__(self, obj, obj_type=None): return self() def __set__(self, obj, val): self(val) def seal(mock): - """Disable the automatic generation of "submocks" + """Disable the automatic generation of child mocks. Given an input Mock, seals it to ensure no further mocks will be generated when accessing an attribute that was not already defined. - Submocks are defined as all mocks which were created DIRECTLY from the - parent. If a mock is assigned to an attribute of an existing mock, - it is not considered a submock. - + The operation recursively seals the mock passed in, meaning that + the mock itself, any mocks generated by accessing one of its attributes, + and all assigned mocks without a name or spec will be sealed. """ mock._mock_sealed = True for attr in dir(mock): @@ -2440,3 +2850,24 @@ def seal(mock): continue if m._mock_new_parent is mock: seal(m) + + +class _AsyncIterator: + """ + Wraps an iterator in an asynchronous iterator. + """ + def __init__(self, iterator): + self.iterator = iterator + code_mock = NonCallableMock(spec_set=CodeType) + code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE + self.__dict__['__code__'] = code_mock + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iterator) + except StopIteration: + pass + raise StopAsyncIteration diff --git a/Lib/unittest/runner.py b/Lib/unittest/runner.py index 2c5ea4ab07..45e7e4c045 100644 --- a/Lib/unittest/runner.py +++ b/Lib/unittest/runner.py @@ -168,7 +168,7 @@ def run(self, test): warnings.filterwarnings('module', category=DeprecationWarning, message=r'Please use assert\w+ instead.') - startTime = time.time() + startTime = time.perf_counter() startTestRun = getattr(result, 'startTestRun', None) if startTestRun is not None: startTestRun() @@ -178,7 +178,7 @@ def run(self, test): stopTestRun = getattr(result, 'stopTestRun', None) if stopTestRun is not None: stopTestRun() - stopTime = time.time() + stopTime = time.perf_counter() timeTaken = stopTime - startTime result.printErrors() if hasattr(result, 'separator2'): diff --git a/Lib/unittest/suite.py b/Lib/unittest/suite.py index 353d4a17b9..41993f9cf6 100644 --- a/Lib/unittest/suite.py +++ b/Lib/unittest/suite.py @@ -166,10 +166,18 @@ def _handleClassSetUp(self, test, result): raise currentClass._classSetupFailed = True className = util.strclass(currentClass) - errorName = 'setUpClass (%s)' % className - self._addClassOrModuleLevelException(result, e, errorName) + self._createClassOrModuleLevelException(result, e, + 'setUpClass', + className) finally: _call_if_exists(result, '_restoreStdout') + if currentClass._classSetupFailed is True: + currentClass.doClassCleanups() + if len(currentClass.tearDown_exceptions) > 0: + for exc in currentClass.tearDown_exceptions: + self._createClassOrModuleLevelException( + result, exc[1], 'setUpClass', className, + info=exc) def _get_previous_module(self, result): previousModule = None @@ -199,21 +207,37 @@ def _handleModuleFixture(self, test, result): try: setUpModule() except Exception as e: + try: + case.doModuleCleanups() + except Exception as exc: + self._createClassOrModuleLevelException(result, exc, + 'setUpModule', + currentModule) if isinstance(result, _DebugResult): raise result._moduleSetUpFailed = True - errorName = 'setUpModule (%s)' % currentModule - self._addClassOrModuleLevelException(result, e, errorName) + self._createClassOrModuleLevelException(result, e, + 'setUpModule', + currentModule) finally: _call_if_exists(result, '_restoreStdout') - def _addClassOrModuleLevelException(self, result, exception, errorName): + def _createClassOrModuleLevelException(self, result, exc, method_name, + parent, info=None): + errorName = f'{method_name} ({parent})' + self._addClassOrModuleLevelException(result, exc, errorName, info) + + def _addClassOrModuleLevelException(self, result, exception, errorName, + info=None): error = _ErrorHolder(errorName) addSkip = getattr(result, 'addSkip', None) if addSkip is not None and isinstance(exception, case.SkipTest): addSkip(error, str(exception)) else: - result.addError(error, sys.exc_info()) + if not info: + result.addError(error, sys.exc_info()) + else: + result.addError(error, info) def _handleModuleTearDown(self, result): previousModule = self._get_previous_module(result) @@ -235,10 +259,17 @@ def _handleModuleTearDown(self, result): except Exception as e: if isinstance(result, _DebugResult): raise - errorName = 'tearDownModule (%s)' % previousModule - self._addClassOrModuleLevelException(result, e, errorName) + self._createClassOrModuleLevelException(result, e, + 'tearDownModule', + previousModule) finally: _call_if_exists(result, '_restoreStdout') + try: + case.doModuleCleanups() + except Exception as e: + self._createClassOrModuleLevelException(result, e, + 'tearDownModule', + previousModule) def _tearDownPreviousClass(self, test, result): previousClass = getattr(result, '_previousTestClass', None) @@ -261,10 +292,19 @@ def _tearDownPreviousClass(self, test, result): if isinstance(result, _DebugResult): raise className = util.strclass(previousClass) - errorName = 'tearDownClass (%s)' % className - self._addClassOrModuleLevelException(result, e, errorName) + self._createClassOrModuleLevelException(result, e, + 'tearDownClass', + className) finally: _call_if_exists(result, '_restoreStdout') + previousClass.doClassCleanups() + if len(previousClass.tearDown_exceptions) > 0: + for exc in previousClass.tearDown_exceptions: + className = util.strclass(previousClass) + self._createClassOrModuleLevelException(result, exc[1], + 'tearDownClass', + className, + info=exc) class _ErrorHolder(object): From 07bfa844493c651eb228337f7e96062a7be6568a Mon Sep 17 00:00:00 2001 From: Padraic Fanning Date: Sun, 5 Sep 2021 22:03:55 -0400 Subject: [PATCH 2/4] Update unittest/test to CPython 3.8 --- Lib/unittest/test/test_async_case.py | 222 ++++ Lib/unittest/test/test_break.py | 15 +- Lib/unittest/test/test_case.py | 60 +- Lib/unittest/test/test_discovery.py | 6 + Lib/unittest/test/test_loader.py | 16 + Lib/unittest/test/test_runner.py | 670 ++++++++++- Lib/unittest/test/test_skipping.py | 11 + Lib/unittest/test/testmock/support.py | 13 +- Lib/unittest/test/testmock/testasync.py | 1045 +++++++++++++++++ Lib/unittest/test/testmock/testcallable.py | 5 +- Lib/unittest/test/testmock/testhelpers.py | 308 +++-- .../test/testmock/testmagicmethods.py | 53 +- Lib/unittest/test/testmock/testmock.py | 640 +++++++++- Lib/unittest/test/testmock/testpatch.py | 242 ++-- Lib/unittest/test/testmock/testsealable.py | 9 +- Lib/unittest/test/testmock/testwith.py | 60 +- 16 files changed, 3068 insertions(+), 307 deletions(-) create mode 100644 Lib/unittest/test/test_async_case.py create mode 100644 Lib/unittest/test/testmock/testasync.py diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py new file mode 100644 index 0000000000..d01864b693 --- /dev/null +++ b/Lib/unittest/test/test_async_case.py @@ -0,0 +1,222 @@ +import asyncio +import unittest + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class TestAsyncCase(unittest.TestCase): + def test_full_cycle(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.assertEqual(events, []) + events.append('setUp') + + async def asyncSetUp(self): + self.assertEqual(events, ['setUp']) + events.append('asyncSetUp') + + async def test_func(self): + self.assertEqual(events, ['setUp', + 'asyncSetUp']) + events.append('test') + self.addAsyncCleanup(self.on_cleanup) + + async def asyncTearDown(self): + self.assertEqual(events, ['setUp', + 'asyncSetUp', + 'test']) + events.append('asyncTearDown') + + def tearDown(self): + self.assertEqual(events, ['setUp', + 'asyncSetUp', + 'test', + 'asyncTearDown']) + events.append('tearDown') + + async def on_cleanup(self): + self.assertEqual(events, ['setUp', + 'asyncSetUp', + 'test', + 'asyncTearDown', + 'tearDown']) + events.append('cleanup') + + test = Test("test_func") + test.run() + self.assertEqual(events, ['setUp', + 'asyncSetUp', + 'test', + 'asyncTearDown', + 'tearDown', + 'cleanup']) + + def test_exception_in_setup(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + events.append('asyncSetUp') + raise Exception() + + async def test_func(self): + events.append('test') + self.addAsyncCleanup(self.on_cleanup) + + async def asyncTearDown(self): + events.append('asyncTearDown') + + async def on_cleanup(self): + events.append('cleanup') + + + test = Test("test_func") + test.run() + self.assertEqual(events, ['asyncSetUp']) + + def test_exception_in_test(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + events.append('asyncSetUp') + + async def test_func(self): + events.append('test') + raise Exception() + self.addAsyncCleanup(self.on_cleanup) + + async def asyncTearDown(self): + events.append('asyncTearDown') + + async def on_cleanup(self): + events.append('cleanup') + + test = Test("test_func") + test.run() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown']) + + def test_exception_in_test_after_adding_cleanup(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + events.append('asyncSetUp') + + async def test_func(self): + events.append('test') + self.addAsyncCleanup(self.on_cleanup) + raise Exception() + + async def asyncTearDown(self): + events.append('asyncTearDown') + + async def on_cleanup(self): + events.append('cleanup') + + test = Test("test_func") + test.run() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup']) + + def test_exception_in_tear_down(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + events.append('asyncSetUp') + + async def test_func(self): + events.append('test') + self.addAsyncCleanup(self.on_cleanup) + + async def asyncTearDown(self): + events.append('asyncTearDown') + raise Exception() + + async def on_cleanup(self): + events.append('cleanup') + + test = Test("test_func") + test.run() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup']) + + + def test_exception_in_tear_clean_up(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + events.append('asyncSetUp') + + async def test_func(self): + events.append('test') + self.addAsyncCleanup(self.on_cleanup) + + async def asyncTearDown(self): + events.append('asyncTearDown') + + async def on_cleanup(self): + events.append('cleanup') + raise Exception() + + test = Test("test_func") + test.run() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup']) + + def test_cleanups_interleave_order(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def test_func(self): + self.addAsyncCleanup(self.on_sync_cleanup, 1) + self.addAsyncCleanup(self.on_async_cleanup, 2) + self.addAsyncCleanup(self.on_sync_cleanup, 3) + self.addAsyncCleanup(self.on_async_cleanup, 4) + + async def on_sync_cleanup(self, val): + events.append(f'sync_cleanup {val}') + + async def on_async_cleanup(self, val): + events.append(f'async_cleanup {val}') + + test = Test("test_func") + test.run() + self.assertEqual(events, ['async_cleanup 4', + 'sync_cleanup 3', + 'async_cleanup 2', + 'sync_cleanup 1']) + + def test_base_exception_from_async_method(self): + events = [] + class Test(unittest.IsolatedAsyncioTestCase): + async def test_base(self): + events.append("test_base") + raise BaseException() + events.append("not it") + + async def test_no_err(self): + events.append("test_no_err") + + async def test_cancel(self): + raise asyncio.CancelledError() + + test = Test("test_base") + output = test.run() + self.assertFalse(output.wasSuccessful()) + + test = Test("test_no_err") + test.run() + self.assertEqual(events, ['test_base', 'test_no_err']) + + test = Test("test_cancel") + output = test.run() + self.assertFalse(output.wasSuccessful()) + + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/unittest/test/test_break.py b/Lib/unittest/test/test_break.py index 681bd535cf..db207863ab 100644 --- a/Lib/unittest/test/test_break.py +++ b/Lib/unittest/test/test_break.py @@ -39,16 +39,13 @@ def testInstallHandler(self): def testRegisterResult(self): result = unittest.TestResult() - unittest.registerResult(result) - - for ref in unittest.signals._results: - if ref is result: - break - elif ref is not result: - self.fail("odd object in result set") - else: - self.fail("result not found") + self.assertNotIn(result, unittest.signals._results) + unittest.registerResult(result) + try: + self.assertIn(result, unittest.signals._results) + finally: + unittest.removeResult(result) def testInterruptCaught(self): default_handler = signal.getsignal(signal.SIGINT) diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py index 63539ce498..6a2ca12485 100644 --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -8,6 +8,7 @@ import warnings import weakref import inspect +import types from copy import deepcopy from test import support @@ -425,6 +426,20 @@ def test_c(self): expected = ['a1', 'a2', 'b1'] self.assertEqual(events, expected) + def test_subtests_debug(self): + # Test debug() with a test that uses subTest() (bpo-34900) + events = [] + + class Foo(unittest.TestCase): + def test_a(self): + events.append('test case') + with self.subTest(): + events.append('subtest 1') + + Foo('test_a').debug() + + self.assertEqual(events, ['test case', 'subtest 1']) + # "This class attribute gives the exception raised by the test() method. # If a test framework needs to use a specialized exception, possibly to # carry additional information, it must subclass this exception in @@ -596,6 +611,15 @@ def testShortDescriptionWithMultiLineDocstring(self): 'Tests shortDescription() for a method with a longer ' 'docstring.') + def testShortDescriptionWhitespaceTrimming(self): + """ + Tests shortDescription() whitespace is trimmed, so that the first + line of nonwhite-space text becomes the docstring. + """ + self.assertEqual( + self.shortDescription(), + 'Tests shortDescription() whitespace is trimmed, so that the first') + def testAddTypeEqualityFunc(self): class SadSnake(object): """Dummy class for test_addTypeEqualityFunc.""" @@ -606,7 +630,7 @@ def AllSnakesCreatedEqual(a, b, msg=None): self.addTypeEqualityFunc(SadSnake, AllSnakesCreatedEqual) self.assertEqual(s1, s2) # No this doesn't clean up and remove the SadSnake equality func - # from this TestCase instance but since its a local nothing else + # from this TestCase instance but since it's local nothing else # will ever notice that. def testAssertIs(self): @@ -1224,7 +1248,7 @@ def Stub(): with self.assertRaises(self.failureException): self.assertRaises(ExceptionMock, lambda: 0) # Failure when the function is None - with self.assertWarns(DeprecationWarning): + with self.assertRaises(TypeError): self.assertRaises(ExceptionMock, None) # Failure when another exception is raised with self.assertRaises(ExceptionMock): @@ -1255,8 +1279,7 @@ def Stub(): with self.assertRaises(ExceptionMock, msg='foobar'): pass # Invalid keyword argument - with self.assertWarnsRegex(DeprecationWarning, 'foobar'), \ - self.assertRaises(AssertionError): + with self.assertRaisesRegex(TypeError, 'foobar'): with self.assertRaises(ExceptionMock, foobar=42): pass # Failure when another exception is raised @@ -1297,7 +1320,7 @@ def Stub(): self.assertRaisesRegex(ExceptionMock, re.compile('expect$'), Stub) self.assertRaisesRegex(ExceptionMock, 'expect$', Stub) - with self.assertWarns(DeprecationWarning): + with self.assertRaises(TypeError): self.assertRaisesRegex(ExceptionMock, 'expect$', None) def testAssertNotRaisesRegex(self): @@ -1314,8 +1337,7 @@ def testAssertNotRaisesRegex(self): with self.assertRaisesRegex(Exception, 'expect', msg='foobar'): pass # Invalid keyword argument - with self.assertWarnsRegex(DeprecationWarning, 'foobar'), \ - self.assertRaises(AssertionError): + with self.assertRaisesRegex(TypeError, 'foobar'): with self.assertRaisesRegex(Exception, 'expect', foobar=42): pass @@ -1331,6 +1353,20 @@ class MyWarn(Warning): pass self.assertRaises(TypeError, self.assertWarnsRegex, MyWarn, lambda: True) + def testAssertWarnsModifySysModules(self): + # bpo-29620: handle modified sys.modules during iteration + class Foo(types.ModuleType): + @property + def __warningregistry__(self): + sys.modules['@bar@'] = 'bar' + + sys.modules['@foo@'] = Foo('foo') + try: + self.assertWarns(UserWarning, warnings.warn, 'expected') + finally: + del sys.modules['@foo@'] + del sys.modules['@bar@'] + def testAssertRaisesRegexMismatch(self): def Stub(): raise Exception('Unexpected') @@ -1390,7 +1426,7 @@ def _runtime_warn(): with self.assertRaises(self.failureException): self.assertWarns(RuntimeWarning, lambda: 0) # Failure when the function is None - with self.assertWarns(DeprecationWarning): + with self.assertRaises(TypeError): self.assertWarns(RuntimeWarning, None) # Failure when another warning is triggered with warnings.catch_warnings(): @@ -1436,8 +1472,7 @@ def _runtime_warn(): with self.assertWarns(RuntimeWarning, msg='foobar'): pass # Invalid keyword argument - with self.assertWarnsRegex(DeprecationWarning, 'foobar'), \ - self.assertRaises(AssertionError): + with self.assertRaisesRegex(TypeError, 'foobar'): with self.assertWarns(RuntimeWarning, foobar=42): pass # Failure when another warning is triggered @@ -1478,7 +1513,7 @@ def _runtime_warn(msg): self.assertWarnsRegex(RuntimeWarning, "o+", lambda: 0) # Failure when the function is None - with self.assertWarns(DeprecationWarning): + with self.assertRaises(TypeError): self.assertWarnsRegex(RuntimeWarning, "o+", None) # Failure when another warning is triggered with warnings.catch_warnings(): @@ -1522,8 +1557,7 @@ def _runtime_warn(msg): with self.assertWarnsRegex(RuntimeWarning, 'o+', msg='foobar'): pass # Invalid keyword argument - with self.assertWarnsRegex(DeprecationWarning, 'foobar'), \ - self.assertRaises(AssertionError): + with self.assertRaisesRegex(TypeError, 'foobar'): with self.assertWarnsRegex(RuntimeWarning, 'o+', foobar=42): pass # Failure when another warning is triggered diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py index 204043b493..16e081e1fb 100644 --- a/Lib/unittest/test/test_discovery.py +++ b/Lib/unittest/test/test_discovery.py @@ -723,11 +723,13 @@ class Module(object): original_listdir = os.listdir original_isfile = os.path.isfile original_isdir = os.path.isdir + original_realpath = os.path.realpath def cleanup(): os.listdir = original_listdir os.path.isfile = original_isfile os.path.isdir = original_isdir + os.path.realpath = original_realpath del sys.modules['foo'] if full_path in sys.path: sys.path.remove(full_path) @@ -742,6 +744,10 @@ def isdir(_): os.listdir = listdir os.path.isfile = isfile os.path.isdir = isdir + if os.name == 'nt': + # ntpath.realpath may inject path prefixes when failing to + # resolve real files, so we substitute abspath() here instead. + os.path.realpath = os.path.abspath return full_path def test_detect_module_clash(self): diff --git a/Lib/unittest/test/test_loader.py b/Lib/unittest/test/test_loader.py index bfd722940b..bc54bf0553 100644 --- a/Lib/unittest/test/test_loader.py +++ b/Lib/unittest/test/test_loader.py @@ -1,3 +1,4 @@ +import functools import sys import types import warnings @@ -1575,5 +1576,20 @@ def test_suiteClass__default_value(self): self.assertIs(loader.suiteClass, unittest.TestSuite) + def test_partial_functions(self): + def noop(arg): + pass + + class Foo(unittest.TestCase): + pass + + setattr(Foo, 'test_partial', functools.partial(noop, None)) + + loader = unittest.TestLoader() + + test_names = ['test_partial'] + self.assertEqual(loader.getTestCaseNames(Foo), test_names) + + if __name__ == "__main__": unittest.main() diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py index 3759696043..58ac98c2d0 100644 --- a/Lib/unittest/test/test_runner.py +++ b/Lib/unittest/test/test_runner.py @@ -11,8 +11,41 @@ ResultWithNoStartTestRunStopTestRun) -class TestCleanUp(unittest.TestCase): +def resultFactory(*_): + return unittest.TestResult() + + +def getRunner(): + return unittest.TextTestRunner(resultclass=resultFactory, + stream=io.StringIO()) + + +def runTests(*cases): + suite = unittest.TestSuite() + for case in cases: + tests = unittest.defaultTestLoader.loadTestsFromTestCase(case) + suite.addTests(tests) + + runner = getRunner() + # creating a nested suite exposes some potential bugs + realSuite = unittest.TestSuite() + realSuite.addTest(suite) + # adding empty suites to the end exposes potential bugs + suite.addTest(unittest.TestSuite()) + realSuite.addTest(unittest.TestSuite()) + return runner.run(realSuite) + + +def cleanup(ordering, blowUp=False): + if not blowUp: + ordering.append('cleanup_good') + else: + ordering.append('cleanup_exc') + raise Exception('CleanUpExc') + + +class TestCleanUp(unittest.TestCase): def testCleanUp(self): class TestableTest(unittest.TestCase): def testNothing(self): @@ -47,10 +80,10 @@ def testNothing(self): test = TestableTest('testNothing') outcome = test._outcome = _Outcome() - exc1 = Exception('foo') + CleanUpExc = Exception('foo') exc2 = Exception('bar') def cleanup1(): - raise exc1 + raise CleanUpExc def cleanup2(): raise exc2 @@ -63,7 +96,7 @@ def cleanup2(): ((_, (Type1, instance1, _)), (_, (Type2, instance2, _))) = reversed(outcome.errors) - self.assertEqual((Type1, instance1), (Exception, exc1)) + self.assertEqual((Type1, instance1), (Exception, CleanUpExc)) self.assertEqual((Type2, instance2), (Exception, exc2)) def testCleanupInRun(self): @@ -135,6 +168,634 @@ def cleanup2(): self.assertEqual(ordering, ['setUp', 'test', 'tearDown', 'cleanup1', 'cleanup2']) +class TestClassCleanup(unittest.TestCase): + def test_addClassCleanUp(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + test = TestableTest('testNothing') + self.assertEqual(test._class_cleanups, []) + class_cleanups = [] + + def class_cleanup1(*args, **kwargs): + class_cleanups.append((3, args, kwargs)) + + def class_cleanup2(*args, **kwargs): + class_cleanups.append((4, args, kwargs)) + + TestableTest.addClassCleanup(class_cleanup1, 1, 2, 3, + four='hello', five='goodbye') + TestableTest.addClassCleanup(class_cleanup2) + + self.assertEqual(test._class_cleanups, + [(class_cleanup1, (1, 2, 3), + dict(four='hello', five='goodbye')), + (class_cleanup2, (), {})]) + + TestableTest.doClassCleanups() + self.assertEqual(class_cleanups, [(4, (), {}), (3, (1, 2, 3), + dict(four='hello', five='goodbye'))]) + + def test_run_class_cleanUp(self): + ordering = [] + blowUp = True + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering) + if blowUp: + raise Exception() + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + runTests(TestableTest) + self.assertEqual(ordering, ['setUpClass', 'cleanup_good']) + + ordering = [] + blowUp = False + runTests(TestableTest) + self.assertEqual(ordering, + ['setUpClass', 'test', 'tearDownClass', 'cleanup_good']) + + def test_debug_executes_classCleanUp(self): + ordering = [] + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering) + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + suite.debug() + self.assertEqual(ordering, + ['setUpClass', 'test', 'tearDownClass', 'cleanup_good']) + + def test_doClassCleanups_with_errors_addClassCleanUp(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + def cleanup1(): + raise Exception('cleanup1') + + def cleanup2(): + raise Exception('cleanup2') + + TestableTest.addClassCleanup(cleanup1) + TestableTest.addClassCleanup(cleanup2) + with self.assertRaises(Exception) as e: + TestableTest.doClassCleanups() + self.assertEqual(e, 'cleanup1') + + def test_with_errors_addCleanUp(self): + ordering = [] + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering) + def setUp(self): + ordering.append('setUp') + self.addCleanup(cleanup, ordering, blowUp=True) + def testNothing(self): + pass + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'setUp', 'cleanup_exc', + 'tearDownClass', 'cleanup_good']) + + def test_run_with_errors_addClassCleanUp(self): + ordering = [] + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering, blowUp=True) + def setUp(self): + ordering.append('setUp') + self.addCleanup(cleanup, ordering) + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'setUp', 'test', 'cleanup_good', + 'tearDownClass', 'cleanup_exc']) + + def test_with_errors_in_addClassCleanup_and_setUps(self): + ordering = [] + class_blow_up = False + method_blow_up = False + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering, blowUp=True) + if class_blow_up: + raise Exception('ClassExc') + def setUp(self): + ordering.append('setUp') + if method_blow_up: + raise Exception('MethodExc') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'setUp', 'test', + 'tearDownClass', 'cleanup_exc']) + ordering = [] + class_blow_up = True + method_blow_up = False + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: ClassExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'cleanup_exc']) + + ordering = [] + class_blow_up = False + method_blow_up = True + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: MethodExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'setUp', 'tearDownClass', + 'cleanup_exc']) + + +class TestModuleCleanUp(unittest.TestCase): + def test_add_and_do_ModuleCleanup(self): + module_cleanups = [] + + def module_cleanup1(*args, **kwargs): + module_cleanups.append((3, args, kwargs)) + + def module_cleanup2(*args, **kwargs): + module_cleanups.append((4, args, kwargs)) + + class Module(object): + unittest.addModuleCleanup(module_cleanup1, 1, 2, 3, + four='hello', five='goodbye') + unittest.addModuleCleanup(module_cleanup2) + + self.assertEqual(unittest.case._module_cleanups, + [(module_cleanup1, (1, 2, 3), + dict(four='hello', five='goodbye')), + (module_cleanup2, (), {})]) + + unittest.case.doModuleCleanups() + self.assertEqual(module_cleanups, [(4, (), {}), (3, (1, 2, 3), + dict(four='hello', five='goodbye'))]) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_doModuleCleanup_with_errors_in_addModuleCleanup(self): + module_cleanups = [] + + def module_cleanup_good(*args, **kwargs): + module_cleanups.append((3, args, kwargs)) + + def module_cleanup_bad(*args, **kwargs): + raise Exception('CleanUpExc') + + class Module(object): + unittest.addModuleCleanup(module_cleanup_good, 1, 2, 3, + four='hello', five='goodbye') + unittest.addModuleCleanup(module_cleanup_bad) + self.assertEqual(unittest.case._module_cleanups, + [(module_cleanup_good, (1, 2, 3), + dict(four='hello', five='goodbye')), + (module_cleanup_bad, (), {})]) + with self.assertRaises(Exception) as e: + unittest.case.doModuleCleanups() + self.assertEqual(str(e.exception), 'CleanUpExc') + self.assertEqual(unittest.case._module_cleanups, []) + + def test_addModuleCleanup_arg_errors(self): + cleanups = [] + def cleanup(*args, **kwargs): + cleanups.append((args, kwargs)) + + class Module(object): + unittest.addModuleCleanup(cleanup, 1, 2, function='hello') + with self.assertRaises(TypeError): + unittest.addModuleCleanup(function=cleanup, arg='hello') + with self.assertRaises(TypeError): + unittest.addModuleCleanup() + unittest.case.doModuleCleanups() + self.assertEqual(cleanups, + [((1, 2), {'function': 'hello'})]) + + def test_run_module_cleanUp(self): + blowUp = True + ordering = [] + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + if blowUp: + raise Exception('setUpModule Exc') + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + result = runTests(TestableTest) + self.assertEqual(ordering, ['setUpModule', 'cleanup_good']) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: setUpModule Exc') + + ordering = [] + blowUp = False + runTests(TestableTest) + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'test', 'tearDownClass', + 'tearDownModule', 'cleanup_good']) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_run_multiple_module_cleanUp(self): + blowUp = True + blowUp2 = False + ordering = [] + class Module1(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + if blowUp: + raise Exception() + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class Module2(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule2') + unittest.addModuleCleanup(cleanup, ordering) + if blowUp2: + raise Exception() + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule2') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + class TestableTest2(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass2') + def testNothing(self): + ordering.append('test2') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass2') + + TestableTest.__module__ = 'Module1' + sys.modules['Module1'] = Module1 + TestableTest2.__module__ = 'Module2' + sys.modules['Module2'] = Module2 + runTests(TestableTest, TestableTest2) + self.assertEqual(ordering, ['setUpModule', 'cleanup_good', + 'setUpModule2', 'setUpClass2', 'test2', + 'tearDownClass2', 'tearDownModule2', + 'cleanup_good']) + ordering = [] + blowUp = False + blowUp2 = True + runTests(TestableTest, TestableTest2) + self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test', + 'tearDownClass', 'tearDownModule', + 'cleanup_good', 'setUpModule2', + 'cleanup_good']) + + ordering = [] + blowUp = False + blowUp2 = False + runTests(TestableTest, TestableTest2) + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'test', 'tearDownClass', + 'tearDownModule', 'cleanup_good', 'setUpModule2', + 'setUpClass2', 'test2', 'tearDownClass2', + 'tearDownModule2', 'cleanup_good']) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_debug_module_executes_cleanUp(self): + ordering = [] + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + suite.debug() + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'test', 'tearDownClass', + 'tearDownModule', 'cleanup_good']) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_addClassCleanup_arg_errors(self): + cleanups = [] + def cleanup(*args, **kwargs): + cleanups.append((args, kwargs)) + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.addClassCleanup(cleanup, 1, 2, function=3, cls=4) + with self.assertRaises(TypeError): + cls.addClassCleanup(function=cleanup, arg='hello') + def testNothing(self): + pass + + with self.assertRaises(TypeError): + TestableTest.addClassCleanup() + with self.assertRaises(TypeError): + unittest.TestCase.addCleanup(cls=TestableTest(), function=cleanup) + runTests(TestableTest) + self.assertEqual(cleanups, + [((1, 2), {'function': 3, 'cls': 4})]) + + def test_addCleanup_arg_errors(self): + cleanups = [] + def cleanup(*args, **kwargs): + cleanups.append((args, kwargs)) + + class TestableTest(unittest.TestCase): + def setUp(self2): + self2.addCleanup(cleanup, 1, 2, function=3, self=4) + with self.assertWarns(DeprecationWarning): + self2.addCleanup(function=cleanup, arg='hello') + def testNothing(self): + pass + + with self.assertRaises(TypeError): + TestableTest().addCleanup() + with self.assertRaises(TypeError): + unittest.TestCase.addCleanup(self=TestableTest(), function=cleanup) + runTests(TestableTest) + self.assertEqual(cleanups, + [((), {'arg': 'hello'}), + ((1, 2), {'function': 3, 'self': 4})]) + + def test_with_errors_in_addClassCleanup(self): + ordering = [] + + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering, blowUp=True) + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'test', 'tearDownClass', + 'cleanup_exc', 'tearDownModule', 'cleanup_good']) + + def test_with_errors_in_addCleanup(self): + ordering = [] + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + def setUp(self): + ordering.append('setUp') + self.addCleanup(cleanup, ordering, blowUp=True) + def testNothing(self): + ordering.append('test') + def tearDown(self): + ordering.append('tearDown') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpModule', 'setUp', 'test', 'tearDown', + 'cleanup_exc', 'tearDownModule', 'cleanup_good']) + + def test_with_errors_in_addModuleCleanup_and_setUps(self): + ordering = [] + module_blow_up = False + class_blow_up = False + method_blow_up = False + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering, blowUp=True) + if module_blow_up: + raise Exception('ModuleExc') + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + if class_blow_up: + raise Exception('ClassExc') + def setUp(self): + ordering.append('setUp') + if method_blow_up: + raise Exception('MethodExc') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'setUp', 'test', + 'tearDownClass', 'tearDownModule', + 'cleanup_exc']) + + ordering = [] + module_blow_up = True + class_blow_up = False + method_blow_up = False + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: ModuleExc') + self.assertEqual(ordering, ['setUpModule', 'cleanup_exc']) + + ordering = [] + module_blow_up = False + class_blow_up = True + method_blow_up = False + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: ClassExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, ['setUpModule', 'setUpClass', + 'tearDownModule', 'cleanup_exc']) + + ordering = [] + module_blow_up = False + class_blow_up = False + method_blow_up = True + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: MethodExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'setUp', + 'tearDownClass', 'tearDownModule', + 'cleanup_exc']) + + def test_module_cleanUp_with_multiple_classes(self): + ordering =[] + def cleanup1(): + ordering.append('cleanup1') + + def cleanup2(): + ordering.append('cleanup2') + + def cleanup3(): + ordering.append('cleanup3') + + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup1) + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + def setUp(self): + ordering.append('setUp') + self.addCleanup(cleanup2) + def testNothing(self): + ordering.append('test') + def tearDown(self): + ordering.append('tearDown') + + class OtherTestableTest(unittest.TestCase): + def setUp(self): + ordering.append('setUp2') + self.addCleanup(cleanup3) + def testNothing(self): + ordering.append('test2') + def tearDown(self): + ordering.append('tearDown2') + + TestableTest.__module__ = 'Module' + OtherTestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + runTests(TestableTest, OtherTestableTest) + self.assertEqual(ordering, + ['setUpModule', 'setUp', 'test', 'tearDown', + 'cleanup2', 'setUp2', 'test2', 'tearDown2', + 'cleanup3', 'tearDownModule', 'cleanup1']) + + class Test_TextTestRunner(unittest.TestCase): """Tests for TextTestRunner.""" @@ -276,7 +937,6 @@ def MockResultClass(*args): expectedresult = (runner.stream, DESCRIPTIONS, VERBOSITY) self.assertEqual(runner._makeResult(), expectedresult) - def test_warnings(self): """ Check that warnings argument of TextTestRunner correctly affects the diff --git a/Lib/unittest/test/test_skipping.py b/Lib/unittest/test/test_skipping.py index 71f7b70e47..1c178a95f7 100644 --- a/Lib/unittest/test/test_skipping.py +++ b/Lib/unittest/test/test_skipping.py @@ -255,6 +255,17 @@ def test_1(self): suite.run(result) self.assertEqual(result.skipped, [(test, "testing")]) + def test_skip_without_reason(self): + class Foo(unittest.TestCase): + @unittest.skip + def test_1(self): + pass + + result = unittest.TestResult() + test = Foo("test_1") + suite = unittest.TestSuite([test]) + suite.run(result) + self.assertEqual(result.skipped, [(test, "")]) if __name__ == "__main__": unittest.main() diff --git a/Lib/unittest/test/testmock/support.py b/Lib/unittest/test/testmock/support.py index 205431adca..49986d65dc 100644 --- a/Lib/unittest/test/testmock/support.py +++ b/Lib/unittest/test/testmock/support.py @@ -1,3 +1,6 @@ +target = {'foo': 'FOO'} + + def is_instance(obj, klass): """Version of is_instance that doesn't access __class__""" return issubclass(type(obj), klass) @@ -6,16 +9,8 @@ def is_instance(obj, klass): class SomeClass(object): class_attribute = None - def wibble(self): - pass + def wibble(self): pass class X(object): pass - - -def examine_warnings(func): - def wrapper(): - with catch_warnings(record=True) as ws: - func(ws) - return wrapper diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py new file mode 100644 index 0000000000..e84c66c0bd --- /dev/null +++ b/Lib/unittest/test/testmock/testasync.py @@ -0,0 +1,1045 @@ +import asyncio +import inspect +import re +import unittest + +from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock, + create_autospec, sentinel, _CallList) + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class AsyncClass: + def __init__(self): + pass + async def async_method(self): + pass + def normal_method(self): + pass + + @classmethod + async def async_class_method(cls): + pass + + @staticmethod + async def async_static_method(): + pass + + +class AwaitableClass: + def __await__(self): + yield + +async def async_func(): + pass + +async def async_func_args(a, b, *, c): + pass + +def normal_func(): + pass + +class NormalClass(object): + def a(self): + pass + + +async_foo_name = f'{__name__}.AsyncClass' +normal_foo_name = f'{__name__}.NormalClass' + + +class AsyncPatchDecoratorTest(unittest.TestCase): + def test_is_coroutine_function_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + self.assertTrue(asyncio.iscoroutinefunction(mock_method)) + test_async() + + def test_is_async_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + @patch(f'{async_foo_name}.async_method') + def test_no_parent_attribute(mock_method): + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + test_async() + test_no_parent_attribute() + + def test_is_AsyncMock_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_is_AsyncMock_patch_staticmethod(self): + @patch.object(AsyncClass, 'async_static_method') + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_is_AsyncMock_patch_classmethod(self): + @patch.object(AsyncClass, 'async_class_method') + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_async_def_patch(self): + @patch(f"{__name__}.async_func", return_value=1) + @patch(f"{__name__}.async_func_args", return_value=2) + async def test_async(func_args_mock, func_mock): + self.assertEqual(func_args_mock._mock_name, "async_func_args") + self.assertEqual(func_mock._mock_name, "async_func") + + self.assertIsInstance(async_func, AsyncMock) + self.assertIsInstance(async_func_args, AsyncMock) + + self.assertEqual(await async_func(), 1) + self.assertEqual(await async_func_args(1, 2, c=3), 2) + + asyncio.run(test_async()) + self.assertTrue(inspect.iscoroutinefunction(async_func)) + + +class AsyncPatchCMTest(unittest.TestCase): + def test_is_async_function_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + self.assertTrue(asyncio.iscoroutinefunction(mock_method)) + + test_async() + + def test_is_async_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + test_async() + + def test_is_AsyncMock_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_async_def_cm(self): + async def test_async(): + with patch(f"{__name__}.async_func", AsyncMock()): + self.assertIsInstance(async_func, AsyncMock) + self.assertTrue(inspect.iscoroutinefunction(async_func)) + + asyncio.run(test_async()) + + +class AsyncMockTest(unittest.TestCase): + def test_iscoroutinefunction_default(self): + mock = AsyncMock() + self.assertTrue(asyncio.iscoroutinefunction(mock)) + + def test_iscoroutinefunction_function(self): + async def foo(): pass + mock = AsyncMock(foo) + self.assertTrue(asyncio.iscoroutinefunction(mock)) + self.assertTrue(inspect.iscoroutinefunction(mock)) + + def test_isawaitable(self): + mock = AsyncMock() + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + self.assertIn('assert_awaited', dir(mock)) + + def test_iscoroutinefunction_normal_function(self): + def foo(): pass + mock = AsyncMock(foo) + self.assertTrue(asyncio.iscoroutinefunction(mock)) + self.assertTrue(inspect.iscoroutinefunction(mock)) + + def test_future_isfuture(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + fut = asyncio.Future() + loop.stop() + loop.close() + mock = AsyncMock(fut) + self.assertIsInstance(mock, asyncio.Future) + + +class AsyncAutospecTest(unittest.TestCase): + def test_is_AsyncMock_patch(self): + @patch(async_foo_name, autospec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method.async_method, AsyncMock) + self.assertIsInstance(mock_method, MagicMock) + + @patch(async_foo_name, autospec=True) + def test_normal_method(mock_method): + self.assertIsInstance(mock_method.normal_method, MagicMock) + + test_async() + test_normal_method() + + def test_create_autospec_instance(self): + with self.assertRaises(RuntimeError): + create_autospec(async_func, instance=True) + + def test_create_autospec_awaitable_class(self): + awaitable_mock = create_autospec(spec=AwaitableClass()) + self.assertIsInstance(create_autospec(awaitable_mock), AsyncMock) + + def test_create_autospec(self): + spec = create_autospec(async_func_args) + awaitable = spec(1, 2, c=3) + async def main(): + await awaitable + + self.assertEqual(spec.await_count, 0) + self.assertIsNone(spec.await_args) + self.assertEqual(spec.await_args_list, []) + spec.assert_not_awaited() + + asyncio.run(main()) + + self.assertTrue(asyncio.iscoroutinefunction(spec)) + self.assertTrue(asyncio.iscoroutine(awaitable)) + self.assertEqual(spec.await_count, 1) + self.assertEqual(spec.await_args, call(1, 2, c=3)) + self.assertEqual(spec.await_args_list, [call(1, 2, c=3)]) + spec.assert_awaited_once() + spec.assert_awaited_once_with(1, 2, c=3) + spec.assert_awaited_with(1, 2, c=3) + spec.assert_awaited() + + def test_patch_with_autospec(self): + + async def test_async(): + with patch(f"{__name__}.async_func_args", autospec=True) as mock_method: + awaitable = mock_method(1, 2, c=3) + self.assertIsInstance(mock_method.mock, AsyncMock) + + self.assertTrue(asyncio.iscoroutinefunction(mock_method)) + self.assertTrue(asyncio.iscoroutine(awaitable)) + self.assertTrue(inspect.isawaitable(awaitable)) + + # Verify the default values during mock setup + self.assertEqual(mock_method.await_count, 0) + self.assertEqual(mock_method.await_args_list, []) + self.assertIsNone(mock_method.await_args) + mock_method.assert_not_awaited() + + await awaitable + + self.assertEqual(mock_method.await_count, 1) + self.assertEqual(mock_method.await_args, call(1, 2, c=3)) + self.assertEqual(mock_method.await_args_list, [call(1, 2, c=3)]) + mock_method.assert_awaited_once() + mock_method.assert_awaited_once_with(1, 2, c=3) + mock_method.assert_awaited_with(1, 2, c=3) + mock_method.assert_awaited() + + mock_method.reset_mock() + self.assertEqual(mock_method.await_count, 0) + self.assertIsNone(mock_method.await_args) + self.assertEqual(mock_method.await_args_list, []) + + asyncio.run(test_async()) + + +class AsyncSpecTest(unittest.TestCase): + def test_spec_normal_methods_on_class(self): + def inner_test(mock_type): + mock = mock_type(AsyncClass) + self.assertIsInstance(mock.async_method, AsyncMock) + self.assertIsInstance(mock.normal_method, MagicMock) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"test method types with {mock_type}"): + inner_test(mock_type) + + def test_spec_normal_methods_on_class_with_mock(self): + mock = Mock(AsyncClass) + self.assertIsInstance(mock.async_method, AsyncMock) + self.assertIsInstance(mock.normal_method, Mock) + + def test_spec_mock_type_kw(self): + def inner_test(mock_type): + async_mock = mock_type(spec=async_func) + self.assertIsInstance(async_mock, mock_type) + with self.assertWarns(RuntimeWarning): + # Will raise a warning because never awaited + self.assertTrue(inspect.isawaitable(async_mock())) + + sync_mock = mock_type(spec=normal_func) + self.assertIsInstance(sync_mock, mock_type) + + for mock_type in [AsyncMock, MagicMock, Mock]: + with self.subTest(f"test spec kwarg with {mock_type}"): + inner_test(mock_type) + + def test_spec_mock_type_positional(self): + def inner_test(mock_type): + async_mock = mock_type(async_func) + self.assertIsInstance(async_mock, mock_type) + with self.assertWarns(RuntimeWarning): + # Will raise a warning because never awaited + self.assertTrue(inspect.isawaitable(async_mock())) + + sync_mock = mock_type(normal_func) + self.assertIsInstance(sync_mock, mock_type) + + for mock_type in [AsyncMock, MagicMock, Mock]: + with self.subTest(f"test spec positional with {mock_type}"): + inner_test(mock_type) + + def test_spec_as_normal_kw_AsyncMock(self): + mock = AsyncMock(spec=normal_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_as_normal_positional_AsyncMock(self): + mock = AsyncMock(normal_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + asyncio.run(m) + + def test_spec_async_mock(self): + @patch.object(AsyncClass, 'async_method', spec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_spec_parent_not_async_attribute_is(self): + @patch(async_foo_name, spec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method, MagicMock) + self.assertIsInstance(mock_method.async_method, AsyncMock) + + test_async() + + def test_target_async_spec_not(self): + @patch.object(AsyncClass, 'async_method', spec=NormalClass.a) + def test_async_attribute(mock_method): + self.assertIsInstance(mock_method, MagicMock) + self.assertFalse(inspect.iscoroutine(mock_method)) + self.assertFalse(inspect.isawaitable(mock_method)) + + test_async_attribute() + + def test_target_not_async_spec_is(self): + @patch.object(NormalClass, 'a', spec=async_func) + def test_attribute_not_async_spec_is(mock_async_func): + self.assertIsInstance(mock_async_func, AsyncMock) + test_attribute_not_async_spec_is() + + def test_spec_async_attributes(self): + @patch(normal_foo_name, spec=AsyncClass) + def test_async_attributes_coroutines(MockNormalClass): + self.assertIsInstance(MockNormalClass.async_method, AsyncMock) + self.assertIsInstance(MockNormalClass, MagicMock) + + test_async_attributes_coroutines() + + +class AsyncSpecSetTest(unittest.TestCase): + def test_is_AsyncMock_patch(self): + @patch.object(AsyncClass, 'async_method', spec_set=True) + def test_async(async_method): + self.assertIsInstance(async_method, AsyncMock) + + def test_is_async_AsyncMock(self): + mock = AsyncMock(spec_set=AsyncClass.async_method) + self.assertTrue(asyncio.iscoroutinefunction(mock)) + self.assertIsInstance(mock, AsyncMock) + + def test_is_child_AsyncMock(self): + mock = MagicMock(spec_set=AsyncClass) + self.assertTrue(asyncio.iscoroutinefunction(mock.async_method)) + self.assertFalse(asyncio.iscoroutinefunction(mock.normal_method)) + self.assertIsInstance(mock.async_method, AsyncMock) + self.assertIsInstance(mock.normal_method, MagicMock) + self.assertIsInstance(mock, MagicMock) + + def test_magicmock_lambda_spec(self): + mock_obj = MagicMock() + mock_obj.mock_func = MagicMock(spec=lambda x: x) + + with patch.object(mock_obj, "mock_func") as cm: + self.assertIsInstance(cm, MagicMock) + + +class AsyncArguments(unittest.IsolatedAsyncioTestCase): + async def test_add_return_value(self): + async def addition(self, var): + return var + 1 + + mock = AsyncMock(addition, return_value=10) + output = await mock(5) + + self.assertEqual(output, 10) + + async def test_add_side_effect_exception(self): + async def addition(var): + return var + 1 + mock = AsyncMock(addition, side_effect=Exception('err')) + with self.assertRaises(Exception): + await mock(5) + + async def test_add_side_effect_coroutine(self): + async def addition(var): + return var + 1 + mock = AsyncMock(side_effect=addition) + result = await mock(5) + self.assertEqual(result, 6) + + async def test_add_side_effect_normal_function(self): + def addition(var): + return var + 1 + mock = AsyncMock(side_effect=addition) + result = await mock(5) + self.assertEqual(result, 6) + + async def test_add_side_effect_iterable(self): + vals = [1, 2, 3] + mock = AsyncMock(side_effect=vals) + for item in vals: + self.assertEqual(await mock(), item) + + with self.assertRaises(StopAsyncIteration) as e: + await mock() + + async def test_add_side_effect_exception_iterable(self): + class SampleException(Exception): + pass + + vals = [1, SampleException("foo")] + mock = AsyncMock(side_effect=vals) + self.assertEqual(await mock(), 1) + + with self.assertRaises(SampleException) as e: + await mock() + + async def test_return_value_AsyncMock(self): + value = AsyncMock(return_value=10) + mock = AsyncMock(return_value=value) + result = await mock() + self.assertIs(result, value) + + async def test_return_value_awaitable(self): + fut = asyncio.Future() + fut.set_result(None) + mock = AsyncMock(return_value=fut) + result = await mock() + self.assertIsInstance(result, asyncio.Future) + + async def test_side_effect_awaitable_values(self): + fut = asyncio.Future() + fut.set_result(None) + + mock = AsyncMock(side_effect=[fut]) + result = await mock() + self.assertIsInstance(result, asyncio.Future) + + with self.assertRaises(StopAsyncIteration): + await mock() + + async def test_side_effect_is_AsyncMock(self): + effect = AsyncMock(return_value=10) + mock = AsyncMock(side_effect=effect) + + result = await mock() + self.assertEqual(result, 10) + + async def test_wraps_coroutine(self): + value = asyncio.Future() + + ran = False + async def inner(): + nonlocal ran + ran = True + return value + + mock = AsyncMock(wraps=inner) + result = await mock() + self.assertEqual(result, value) + mock.assert_awaited() + self.assertTrue(ran) + + async def test_wraps_normal_function(self): + value = 1 + + ran = False + def inner(): + nonlocal ran + ran = True + return value + + mock = AsyncMock(wraps=inner) + result = await mock() + self.assertEqual(result, value) + mock.assert_awaited() + self.assertTrue(ran) + + async def test_await_args_list_order(self): + async_mock = AsyncMock() + mock2 = async_mock(2) + mock1 = async_mock(1) + await mock1 + await mock2 + async_mock.assert_has_awaits([call(1), call(2)]) + self.assertEqual(async_mock.await_args_list, [call(1), call(2)]) + self.assertEqual(async_mock.call_args_list, [call(2), call(1)]) + + +class AsyncMagicMethods(unittest.TestCase): + def test_async_magic_methods_return_async_mocks(self): + m_mock = MagicMock() + self.assertIsInstance(m_mock.__aenter__, AsyncMock) + self.assertIsInstance(m_mock.__aexit__, AsyncMock) + self.assertIsInstance(m_mock.__anext__, AsyncMock) + # __aiter__ is actually a synchronous object + # so should return a MagicMock + self.assertIsInstance(m_mock.__aiter__, MagicMock) + + def test_sync_magic_methods_return_magic_mocks(self): + a_mock = AsyncMock() + self.assertIsInstance(a_mock.__enter__, MagicMock) + self.assertIsInstance(a_mock.__exit__, MagicMock) + self.assertIsInstance(a_mock.__next__, MagicMock) + self.assertIsInstance(a_mock.__len__, MagicMock) + + def test_magicmock_has_async_magic_methods(self): + m_mock = MagicMock() + self.assertTrue(hasattr(m_mock, "__aenter__")) + self.assertTrue(hasattr(m_mock, "__aexit__")) + self.assertTrue(hasattr(m_mock, "__anext__")) + + def test_asyncmock_has_sync_magic_methods(self): + a_mock = AsyncMock() + self.assertTrue(hasattr(a_mock, "__enter__")) + self.assertTrue(hasattr(a_mock, "__exit__")) + self.assertTrue(hasattr(a_mock, "__next__")) + self.assertTrue(hasattr(a_mock, "__len__")) + + def test_magic_methods_are_async_functions(self): + m_mock = MagicMock() + self.assertIsInstance(m_mock.__aenter__, AsyncMock) + self.assertIsInstance(m_mock.__aexit__, AsyncMock) + # AsyncMocks are also coroutine functions + self.assertTrue(asyncio.iscoroutinefunction(m_mock.__aenter__)) + self.assertTrue(asyncio.iscoroutinefunction(m_mock.__aexit__)) + +class AsyncContextManagerTest(unittest.TestCase): + class WithAsyncContextManager: + async def __aenter__(self, *args, **kwargs): + self.entered = True + return self + + async def __aexit__(self, *args, **kwargs): + self.exited = True + + class WithSyncContextManager: + def __enter__(self, *args, **kwargs): + return self + + def __exit__(self, *args, **kwargs): + pass + + class ProductionCode: + # Example real-world(ish) code + def __init__(self): + self.session = None + + async def main(self): + async with self.session.post('https://python.org') as response: + val = await response.json() + return val + + def test_set_return_value_of_aenter(self): + def inner_test(mock_type): + pc = self.ProductionCode() + pc.session = MagicMock(name='sessionmock') + cm = mock_type(name='magic_cm') + response = AsyncMock(name='response') + response.json = AsyncMock(return_value={'json': 123}) + cm.__aenter__.return_value = response + pc.session.post.return_value = cm + result = asyncio.run(pc.main()) + self.assertEqual(result, {'json': 123}) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"test set return value of aenter with {mock_type}"): + inner_test(mock_type) + + def test_mock_supports_async_context_manager(self): + def inner_test(mock_type): + called = False + cm = self.WithAsyncContextManager() + cm_mock = mock_type(cm) + + async def use_context_manager(): + nonlocal called + async with cm_mock as result: + called = True + return result + + cm_result = asyncio.run(use_context_manager()) + self.assertTrue(called) + self.assertTrue(cm_mock.__aenter__.called) + self.assertTrue(cm_mock.__aexit__.called) + cm_mock.__aenter__.assert_awaited() + cm_mock.__aexit__.assert_awaited() + # We mock __aenter__ so it does not return self + self.assertIsNot(cm_mock, cm_result) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"test context manager magics with {mock_type}"): + inner_test(mock_type) + + def test_mock_customize_async_context_manager(self): + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + expected_result = object() + mock_instance.__aenter__.return_value = expected_result + + async def use_context_manager(): + async with mock_instance as result: + return result + + self.assertIs(asyncio.run(use_context_manager()), expected_result) + + def test_mock_customize_async_context_manager_with_coroutine(self): + enter_called = False + exit_called = False + + async def enter_coroutine(*args): + nonlocal enter_called + enter_called = True + + async def exit_coroutine(*args): + nonlocal exit_called + exit_called = True + + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + mock_instance.__aenter__ = enter_coroutine + mock_instance.__aexit__ = exit_coroutine + + async def use_context_manager(): + async with mock_instance: + pass + + asyncio.run(use_context_manager()) + self.assertTrue(enter_called) + self.assertTrue(exit_called) + + def test_context_manager_raise_exception_by_default(self): + async def raise_in(context_manager): + async with context_manager: + raise TypeError() + + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + with self.assertRaises(TypeError): + asyncio.run(raise_in(mock_instance)) + + +class AsyncIteratorTest(unittest.TestCase): + class WithAsyncIterator(object): + def __init__(self): + self.items = ["foo", "NormalFoo", "baz"] + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return self.items.pop() + except IndexError: + pass + + raise StopAsyncIteration + + def test_aiter_set_return_value(self): + mock_iter = AsyncMock(name="tester") + mock_iter.__aiter__.return_value = [1, 2, 3] + async def main(): + return [i async for i in mock_iter] + result = asyncio.run(main()) + self.assertEqual(result, [1, 2, 3]) + + def test_mock_aiter_and_anext_asyncmock(self): + def inner_test(mock_type): + instance = self.WithAsyncIterator() + mock_instance = mock_type(instance) + # Check that the mock and the real thing bahave the same + # __aiter__ is not actually async, so not a coroutinefunction + self.assertFalse(asyncio.iscoroutinefunction(instance.__aiter__)) + self.assertFalse(asyncio.iscoroutinefunction(mock_instance.__aiter__)) + # __anext__ is async + self.assertTrue(asyncio.iscoroutinefunction(instance.__anext__)) + self.assertTrue(asyncio.iscoroutinefunction(mock_instance.__anext__)) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"test aiter and anext corourtine with {mock_type}"): + inner_test(mock_type) + + + def test_mock_async_for(self): + async def iterate(iterator): + accumulator = [] + async for item in iterator: + accumulator.append(item) + + return accumulator + + expected = ["FOO", "BAR", "BAZ"] + def test_default(mock_type): + mock_instance = mock_type(self.WithAsyncIterator()) + self.assertEqual(asyncio.run(iterate(mock_instance)), []) + + + def test_set_return_value(mock_type): + mock_instance = mock_type(self.WithAsyncIterator()) + mock_instance.__aiter__.return_value = expected[:] + self.assertEqual(asyncio.run(iterate(mock_instance)), expected) + + def test_set_return_value_iter(mock_type): + mock_instance = mock_type(self.WithAsyncIterator()) + mock_instance.__aiter__.return_value = iter(expected[:]) + self.assertEqual(asyncio.run(iterate(mock_instance)), expected) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"default value with {mock_type}"): + test_default(mock_type) + + with self.subTest(f"set return_value with {mock_type}"): + test_set_return_value(mock_type) + + with self.subTest(f"set return_value iterator with {mock_type}"): + test_set_return_value_iter(mock_type) + + +class AsyncMockAssert(unittest.TestCase): + def setUp(self): + self.mock = AsyncMock() + + async def _runnable_test(self, *args, **kwargs): + await self.mock(*args, **kwargs) + + async def _await_coroutine(self, coroutine): + return await coroutine + + def test_assert_called_but_not_awaited(self): + mock = AsyncMock(AsyncClass) + with self.assertWarns(RuntimeWarning): + # Will raise a warning because never awaited + mock.async_method() + self.assertTrue(asyncio.iscoroutinefunction(mock.async_method)) + mock.async_method.assert_called() + mock.async_method.assert_called_once() + mock.async_method.assert_called_once_with() + with self.assertRaises(AssertionError): + mock.assert_awaited() + with self.assertRaises(AssertionError): + mock.async_method.assert_awaited() + + def test_assert_called_then_awaited(self): + mock = AsyncMock(AsyncClass) + mock_coroutine = mock.async_method() + mock.async_method.assert_called() + mock.async_method.assert_called_once() + mock.async_method.assert_called_once_with() + with self.assertRaises(AssertionError): + mock.async_method.assert_awaited() + + asyncio.run(self._await_coroutine(mock_coroutine)) + # Assert we haven't re-called the function + mock.async_method.assert_called_once() + mock.async_method.assert_awaited() + mock.async_method.assert_awaited_once() + mock.async_method.assert_awaited_once_with() + + def test_assert_called_and_awaited_at_same_time(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + + with self.assertRaises(AssertionError): + self.mock.assert_called() + + asyncio.run(self._runnable_test()) + self.mock.assert_called_once() + self.mock.assert_awaited_once() + + def test_assert_called_twice_and_awaited_once(self): + mock = AsyncMock(AsyncClass) + coroutine = mock.async_method() + with self.assertWarns(RuntimeWarning): + # The first call will be awaited so no warning there + # But this call will never get awaited, so it will warn here + mock.async_method() + with self.assertRaises(AssertionError): + mock.async_method.assert_awaited() + mock.async_method.assert_called() + asyncio.run(self._await_coroutine(coroutine)) + mock.async_method.assert_awaited() + mock.async_method.assert_awaited_once() + + def test_assert_called_once_and_awaited_twice(self): + mock = AsyncMock(AsyncClass) + coroutine = mock.async_method() + mock.async_method.assert_called_once() + asyncio.run(self._await_coroutine(coroutine)) + with self.assertRaises(RuntimeError): + # Cannot reuse already awaited coroutine + asyncio.run(self._await_coroutine(coroutine)) + mock.async_method.assert_awaited() + + def test_assert_awaited_but_not_called(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + with self.assertRaises(AssertionError): + self.mock.assert_called() + with self.assertRaises(TypeError): + # You cannot await an AsyncMock, it must be a coroutine + asyncio.run(self._await_coroutine(self.mock)) + + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + with self.assertRaises(AssertionError): + self.mock.assert_called() + + def test_assert_has_calls_not_awaits(self): + kalls = [call('foo')] + with self.assertWarns(RuntimeWarning): + # Will raise a warning because never awaited + self.mock('foo') + self.mock.assert_has_calls(kalls) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(kalls) + + def test_assert_has_mock_calls_on_async_mock_no_spec(self): + with self.assertWarns(RuntimeWarning): + # Will raise a warning because never awaited + self.mock() + kalls_empty = [('', (), {})] + self.assertEqual(self.mock.mock_calls, kalls_empty) + + with self.assertWarns(RuntimeWarning): + # Will raise a warning because never awaited + self.mock('foo') + self.mock('baz') + mock_kalls = ([call(), call('foo'), call('baz')]) + self.assertEqual(self.mock.mock_calls, mock_kalls) + + def test_assert_has_mock_calls_on_async_mock_with_spec(self): + a_class_mock = AsyncMock(AsyncClass) + with self.assertWarns(RuntimeWarning): + # Will raise a warning because never awaited + a_class_mock.async_method() + kalls_empty = [('', (), {})] + self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty) + self.assertEqual(a_class_mock.mock_calls, [call.async_method()]) + + with self.assertWarns(RuntimeWarning): + # Will raise a warning because never awaited + a_class_mock.async_method(1, 2, 3, a=4, b=5) + method_kalls = [call(), call(1, 2, 3, a=4, b=5)] + mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)] + self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls) + self.assertEqual(a_class_mock.mock_calls, mock_kalls) + + def test_async_method_calls_recorded(self): + with self.assertWarns(RuntimeWarning): + # Will raise warnings because never awaited + self.mock.something(3, fish=None) + self.mock.something_else.something(6, cake=sentinel.Cake) + + self.assertEqual(self.mock.method_calls, [ + ("something", (3,), {'fish': None}), + ("something_else.something", (6,), {'cake': sentinel.Cake}) + ], + "method calls not recorded correctly") + self.assertEqual(self.mock.something_else.method_calls, + [("something", (6,), {'cake': sentinel.Cake})], + "method calls not recorded correctly") + + def test_async_arg_lists(self): + def assert_attrs(mock): + names = ('call_args_list', 'method_calls', 'mock_calls') + for name in names: + attr = getattr(mock, name) + self.assertIsInstance(attr, _CallList) + self.assertIsInstance(attr, list) + self.assertEqual(attr, []) + + assert_attrs(self.mock) + with self.assertWarns(RuntimeWarning): + # Will raise warnings because never awaited + self.mock() + self.mock(1, 2) + self.mock(a=3) + + self.mock.reset_mock() + assert_attrs(self.mock) + + a_mock = AsyncMock(AsyncClass) + with self.assertWarns(RuntimeWarning): + # Will raise warnings because never awaited + a_mock.async_method() + a_mock.async_method(1, a=3) + + a_mock.reset_mock() + assert_attrs(a_mock) + + def test_assert_awaited(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + + asyncio.run(self._runnable_test()) + self.mock.assert_awaited() + + def test_assert_awaited_once(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once() + + asyncio.run(self._runnable_test()) + self.mock.assert_awaited_once() + + asyncio.run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once() + + def test_assert_awaited_with(self): + msg = 'Not awaited' + with self.assertRaisesRegex(AssertionError, msg): + self.mock.assert_awaited_with('foo') + + asyncio.run(self._runnable_test()) + msg = 'expected await not found' + with self.assertRaisesRegex(AssertionError, msg): + self.mock.assert_awaited_with('foo') + + asyncio.run(self._runnable_test('foo')) + self.mock.assert_awaited_with('foo') + + asyncio.run(self._runnable_test('SomethingElse')) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_with('foo') + + def test_assert_awaited_once_with(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once_with('foo') + + asyncio.run(self._runnable_test('foo')) + self.mock.assert_awaited_once_with('foo') + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once_with('foo') + + def test_assert_any_wait(self): + with self.assertRaises(AssertionError): + self.mock.assert_any_await('foo') + + asyncio.run(self._runnable_test('baz')) + with self.assertRaises(AssertionError): + self.mock.assert_any_await('foo') + + asyncio.run(self._runnable_test('foo')) + self.mock.assert_any_await('foo') + + asyncio.run(self._runnable_test('SomethingElse')) + self.mock.assert_any_await('foo') + + def test_assert_has_awaits_no_order(self): + calls = [call('foo'), call('baz')] + + with self.assertRaises(AssertionError) as cm: + self.mock.assert_has_awaits(calls) + self.assertEqual(len(cm.exception.args), 1) + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('baz')) + self.mock.assert_has_awaits(calls) + + asyncio.run(self._runnable_test('SomethingElse')) + self.mock.assert_has_awaits(calls) + + def test_assert_has_awaits_ordered(self): + calls = [call('foo'), call('baz')] + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('baz')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('bamf')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('foo')) + self.mock.assert_has_awaits(calls, any_order=True) + + asyncio.run(self._runnable_test('qux')) + self.mock.assert_has_awaits(calls, any_order=True) + + def test_assert_not_awaited(self): + self.mock.assert_not_awaited() + + asyncio.run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_not_awaited() + + def test_assert_has_awaits_not_matching_spec_error(self): + async def f(x=None): pass + + self.mock = AsyncMock(spec=f) + asyncio.run(self._runnable_test(1)) + + with self.assertRaisesRegex( + AssertionError, + '^{}$'.format( + re.escape('Awaits not found.\n' + 'Expected: [call()]\n' + 'Actual: [call(1)]'))) as cm: + self.mock.assert_has_awaits([call()]) + self.assertIsNone(cm.exception.__cause__) + + with self.assertRaisesRegex( + AssertionError, + '^{}$'.format( + re.escape( + 'Error processing expected awaits.\n' + "Errors: [None, TypeError('too many positional " + "arguments')]\n" + 'Expected: [call(), call(1, 2)]\n' + 'Actual: [call(1)]'))) as cm: + self.mock.assert_has_awaits([call(), call(1, 2)]) + self.assertIsInstance(cm.exception.__cause__, TypeError) diff --git a/Lib/unittest/test/testmock/testcallable.py b/Lib/unittest/test/testmock/testcallable.py index af1ce7ebba..5eadc00704 100644 --- a/Lib/unittest/test/testmock/testcallable.py +++ b/Lib/unittest/test/testmock/testcallable.py @@ -98,8 +98,7 @@ def test_patch_spec_set_instance(self): def test_patch_spec_callable_class(self): class CallableX(X): - def __call__(self): - pass + def __call__(self): pass class Sub(CallableX): pass @@ -128,7 +127,7 @@ class Multi(SomeClass, Sub): result.foo.assert_called_once_with(3, 2, 1) - def test_create_autopsec(self): + def test_create_autospec(self): mock = create_autospec(X) instance = mock() self.assertRaises(TypeError, instance) diff --git a/Lib/unittest/test/testmock/testhelpers.py b/Lib/unittest/test/testmock/testhelpers.py index 7919482ae9..f3c7acb98c 100644 --- a/Lib/unittest/test/testmock/testhelpers.py +++ b/Lib/unittest/test/testmock/testhelpers.py @@ -1,21 +1,20 @@ +import inspect import time import types import unittest from unittest.mock import ( call, _Call, create_autospec, MagicMock, - Mock, ANY, _CallList, patch, PropertyMock + Mock, ANY, _CallList, patch, PropertyMock, _callable ) from datetime import datetime +from functools import partial class SomeClass(object): - def one(self, a, b): - pass - def two(self): - pass - def three(self, a=None): - pass + def one(self, a, b): pass + def two(self): pass + def three(self, a=None): pass @@ -46,12 +45,9 @@ def test_any_and_datetime(self): def test_any_mock_calls_comparison_order(self): mock = Mock() - d = datetime.now() class Foo(object): - def __eq__(self, other): - return False - def __ne__(self, other): - return True + def __eq__(self, other): pass + def __ne__(self, other): pass for d in datetime.now(), Foo(): mock.reset_mock() @@ -144,6 +140,8 @@ def test_call_with_args(self): self.assertEqual(args, ('foo', (1, 2, 3))) self.assertEqual(args, ('foo', (1, 2, 3), {})) self.assertEqual(args, ((1, 2, 3), {})) + self.assertEqual(args.args, (1, 2, 3)) + self.assertEqual(args.kwargs, {}) def test_named_call_with_args(self): @@ -151,6 +149,8 @@ def test_named_call_with_args(self): self.assertEqual(args, ('foo', (1, 2, 3))) self.assertEqual(args, ('foo', (1, 2, 3), {})) + self.assertEqual(args.args, (1, 2, 3)) + self.assertEqual(args.kwargs, {}) self.assertNotEqual(args, ((1, 2, 3),)) self.assertNotEqual(args, ((1, 2, 3), {})) @@ -163,6 +163,8 @@ def test_call_with_kwargs(self): self.assertEqual(args, ('foo', dict(a=3, b=4))) self.assertEqual(args, ('foo', (), dict(a=3, b=4))) self.assertEqual(args, ((), dict(a=3, b=4))) + self.assertEqual(args.args, ()) + self.assertEqual(args.kwargs, dict(a=3, b=4)) def test_named_call_with_kwargs(self): @@ -170,6 +172,8 @@ def test_named_call_with_kwargs(self): self.assertEqual(args, ('foo', dict(a=3, b=4))) self.assertEqual(args, ('foo', (), dict(a=3, b=4))) + self.assertEqual(args.args, ()) + self.assertEqual(args.kwargs, dict(a=3, b=4)) self.assertNotEqual(args, (dict(a=3, b=4),)) self.assertNotEqual(args, ((), dict(a=3, b=4))) @@ -177,6 +181,7 @@ def test_named_call_with_kwargs(self): def test_call_with_args_call_empty_name(self): args = _Call(((1, 2, 3), {})) + self.assertEqual(args, call(1, 2, 3)) self.assertEqual(call(1, 2, 3), args) self.assertIn(call(1, 2, 3), [args]) @@ -269,6 +274,22 @@ def test_extended_call(self): self.assertEqual(mock.mock_calls, last_call.call_list()) + def test_extended_not_equal(self): + a = call(x=1).foo + b = call(x=2).foo + self.assertEqual(a, a) + self.assertEqual(b, b) + self.assertNotEqual(a, b) + + + def test_nested_calls_not_equal(self): + a = call(x=1).foo().bar + b = call(x=2).foo().bar + self.assertEqual(a, a) + self.assertEqual(b, b) + self.assertNotEqual(a, b) + + def test_call_list(self): mock = MagicMock() mock(1) @@ -313,6 +334,26 @@ def test_call_with_name(self): self.assertEqual(_Call((('bar', 'barz'),),)[0], '') self.assertEqual(_Call((('bar', 'barz'), {'hello': 'world'}),)[0], '') + def test_dunder_call(self): + m = MagicMock() + m().foo()['bar']() + self.assertEqual( + m.mock_calls, + [call(), call().foo(), call().foo().__getitem__('bar'), call().foo().__getitem__()()] + ) + m = MagicMock() + m().foo()['bar'] = 1 + self.assertEqual( + m.mock_calls, + [call(), call().foo(), call().foo().__setitem__('bar', 1)] + ) + m = MagicMock() + iter(m().foo()) + self.assertEqual( + m.mock_calls, + [call(), call().foo(), call().foo().__iter__()] + ) + class SpecSignatureTest(unittest.TestCase): @@ -351,8 +392,7 @@ def test_basic(self): def test_create_autospec_return_value(self): - def f(): - pass + def f(): pass mock = create_autospec(f, return_value='foo') self.assertEqual(mock(), 'foo') @@ -372,8 +412,7 @@ def test_autospec_reset_mock(self): def test_mocking_unbound_methods(self): class Foo(object): - def foo(self, foo): - pass + def foo(self, foo): pass p = patch.object(Foo, 'foo') mock_foo = p.start() Foo().foo(1) @@ -381,24 +420,6 @@ def foo(self, foo): mock_foo.assert_called_with(1) - def test_create_autospec_unbound_methods(self): - # see mock issue 128 - # this is expected to fail until the issue is fixed - return - class Foo(object): - def foo(self): - pass - - klass = create_autospec(Foo) - instance = klass() - self.assertRaises(TypeError, instance.foo, 1) - - # Note: no type checking on the "self" parameter - klass.foo(1) - klass.foo.assert_called_with(1) - self.assertRaises(TypeError, klass.foo) - - def test_create_autospec_keyword_arguments(self): class Foo(object): a = 3 @@ -407,8 +428,7 @@ class Foo(object): def test_create_autospec_keyword_only_arguments(self): - def foo(a, *, b=None): - pass + def foo(a, *, b=None): pass m = create_autospec(foo) m(1) @@ -421,8 +441,7 @@ def foo(a, *, b=None): def test_function_as_instance_attribute(self): obj = SomeClass() - def f(a): - pass + def f(a): pass obj.f = f mock = create_autospec(obj) @@ -458,13 +477,57 @@ class Sub(SomeClass): self._check_someclass_mock(mock) + def test_spec_has_descriptor_returning_function(self): + + class CrazyDescriptor(object): + + def __get__(self, obj, type_): + if obj is None: + return lambda x: None + + class MyClass(object): + + some_attr = CrazyDescriptor() + + mock = create_autospec(MyClass) + mock.some_attr(1) + with self.assertRaises(TypeError): + mock.some_attr() + with self.assertRaises(TypeError): + mock.some_attr(1, 2) + + + def test_spec_has_function_not_in_bases(self): + + class CrazyClass(object): + + def __dir__(self): + return super(CrazyClass, self).__dir__()+['crazy'] + + def __getattr__(self, item): + if item == 'crazy': + return lambda x: x + raise AttributeError(item) + + inst = CrazyClass() + with self.assertRaises(AttributeError): + inst.other + self.assertEqual(inst.crazy(42), 42) + + mock = create_autospec(inst) + mock.crazy(42) + with self.assertRaises(TypeError): + mock.crazy() + with self.assertRaises(TypeError): + mock.crazy(1, 2) + + def test_builtin_functions_types(self): # we could replace builtin functions / methods with a function # with *args / **kwargs signature. Using the builtin method type # as a spec seems to work fairly well though. class BuiltinSubclass(list): - def bar(self, arg): - pass + def bar(self, arg): pass sorted = sorted attr = {} @@ -538,17 +601,13 @@ class Sub(SomeClass): def test_descriptors(self): class Foo(object): @classmethod - def f(cls, a, b): - pass + def f(cls, a, b): pass @staticmethod - def g(a, b): - pass + def g(a, b): pass - class Bar(Foo): - pass + class Bar(Foo): pass - class Baz(SomeClass, Bar): - pass + class Baz(SomeClass, Bar): pass for spec in (Foo, Foo(), Bar, Bar(), Baz, Baz()): mock = create_autospec(spec) @@ -561,8 +620,7 @@ class Baz(SomeClass, Bar): def test_recursive(self): class A(object): - def a(self): - pass + def a(self): pass foo = 'foo bar baz' bar = foo @@ -584,11 +642,9 @@ def a(self): def test_spec_inheritance_for_classes(self): class Foo(object): - def a(self, x): - pass + def a(self, x): pass class Bar(object): - def f(self, y): - pass + def f(self, y): pass class_mock = create_autospec(Foo) @@ -668,8 +724,7 @@ def test_builtins(self): def test_function(self): - def f(a, b): - pass + def f(a, b): pass mock = create_autospec(f) self.assertRaises(TypeError, mock) @@ -699,9 +754,10 @@ class RaiserClass(object): def existing(a, b): return a + b + self.assertEqual(RaiserClass.existing(1, 2), 3) s = create_autospec(RaiserClass) self.assertRaises(TypeError, lambda x: s.existing(1, 2, 3)) - s.existing(1, 2) + self.assertEqual(s.existing(1, 2), s.existing.return_value) self.assertRaises(AttributeError, lambda: s.nonexisting) # check we can fetch the raiser attribute and it has no spec @@ -711,8 +767,7 @@ def existing(a, b): def test_signature_class(self): class Foo(object): - def __init__(self, a, b=3): - pass + def __init__(self, a, b=3): pass mock = create_autospec(Foo) @@ -738,10 +793,8 @@ class Foo(object): def test_signature_callable(self): class Callable(object): - def __init__(self, x, y): - pass - def __call__(self, a): - pass + def __init__(self, x, y): pass + def __call__(self, a): pass mock = create_autospec(Callable) mock(1, 2) @@ -797,8 +850,7 @@ class Foo(object): def test_autospec_functions_with_self_in_odd_place(self): class Foo(object): - def f(a, self): - pass + def f(a, self): pass a = create_autospec(Foo) a.f(10) @@ -815,12 +867,9 @@ def __init__(self, value): self.value = value def __get__(self, obj, cls=None): - if obj is None: - return self - return self.value + return self - def __set__(self, obj, value): - pass + def __set__(self, obj, value): pass class MyProperty(property): pass @@ -829,12 +878,10 @@ class Foo(object): __slots__ = ['slot'] @property - def prop(self): - return 3 + def prop(self): pass @MyProperty - def subprop(self): - return 4 + def subprop(self): pass desc = Descriptor(42) @@ -871,6 +918,84 @@ def test_autospec_on_bound_builtin_function(self): mocked.assert_called_once_with(4, 5, 6) + def test_autospec_getattr_partial_function(self): + # bpo-32153 : getattr returning partial functions without + # __name__ should not create AttributeError in create_autospec + class Foo: + + def __getattr__(self, attribute): + return partial(lambda name: name, attribute) + + proxy = Foo() + autospec = create_autospec(proxy) + self.assertFalse(hasattr(autospec, '__name__')) + + + def test_spec_inspect_signature(self): + + def myfunc(x, y): pass + + mock = create_autospec(myfunc) + mock(1, 2) + mock(x=1, y=2) + + self.assertEqual(inspect.signature(mock), inspect.signature(myfunc)) + self.assertEqual(mock.mock_calls, [call(1, 2), call(x=1, y=2)]) + self.assertRaises(TypeError, mock, 1) + + + def test_spec_inspect_signature_annotations(self): + + def foo(a: int, b: int=10, *, c:int) -> int: + return a + b + c + + self.assertEqual(foo(1, 2 , c=3), 6) + mock = create_autospec(foo) + mock(1, 2, c=3) + mock(1, c=3) + + self.assertEqual(inspect.signature(mock), inspect.signature(foo)) + self.assertEqual(mock.mock_calls, [call(1, 2, c=3), call(1, c=3)]) + self.assertRaises(TypeError, mock, 1) + self.assertRaises(TypeError, mock, 1, 2, 3, c=4) + + + def test_spec_function_no_name(self): + func = lambda: 'nope' + mock = create_autospec(func) + self.assertEqual(mock.__name__, 'funcopy') + + + def test_spec_function_assert_has_calls(self): + def f(a): pass + mock = create_autospec(f) + mock(1) + mock.assert_has_calls([call(1)]) + with self.assertRaises(AssertionError): + mock.assert_has_calls([call(2)]) + + + def test_spec_function_assert_any_call(self): + def f(a): pass + mock = create_autospec(f) + mock(1) + mock.assert_any_call(1) + with self.assertRaises(AssertionError): + mock.assert_any_call(2) + + + def test_spec_function_reset_mock(self): + def f(a): pass + rv = Mock() + mock = create_autospec(f, return_value=rv) + mock(1)(2) + self.assertEqual(mock.mock_calls, [call(1)]) + self.assertEqual(rv.mock_calls, [call(2)]) + mock.reset_mock() + self.assertEqual(mock.mock_calls, []) + self.assertEqual(rv.mock_calls, []) + + class TestCallList(unittest.TestCase): def test_args_list_contains_call_list(self): @@ -942,5 +1067,40 @@ def test_propertymock_returnvalue(self): self.assertNotIsInstance(returned, PropertyMock) +class TestCallablePredicate(unittest.TestCase): + + def test_type(self): + for obj in [str, bytes, int, list, tuple, SomeClass]: + self.assertTrue(_callable(obj)) + + def test_call_magic_method(self): + class Callable: + def __call__(self): pass + instance = Callable() + self.assertTrue(_callable(instance)) + + def test_staticmethod(self): + class WithStaticMethod: + @staticmethod + def staticfunc(): pass + self.assertTrue(_callable(WithStaticMethod.staticfunc)) + + def test_non_callable_staticmethod(self): + class BadStaticMethod: + not_callable = staticmethod(None) + self.assertFalse(_callable(BadStaticMethod.not_callable)) + + def test_classmethod(self): + class WithClassMethod: + @classmethod + def classfunc(cls): pass + self.assertTrue(_callable(WithClassMethod.classfunc)) + + def test_non_callable_classmethod(self): + class BadClassMethod: + not_callable = classmethod(None) + self.assertFalse(_callable(BadClassMethod.not_callable)) + + if __name__ == '__main__': unittest.main() diff --git a/Lib/unittest/test/testmock/testmagicmethods.py b/Lib/unittest/test/testmock/testmagicmethods.py index 37623dcebc..76b3a560de 100644 --- a/Lib/unittest/test/testmock/testmagicmethods.py +++ b/Lib/unittest/test/testmock/testmagicmethods.py @@ -1,6 +1,9 @@ +import asyncio +import math import unittest +import os import sys -from unittest.mock import Mock, MagicMock, _magics +from unittest.mock import AsyncMock, Mock, MagicMock, _magics @@ -268,6 +271,31 @@ def test_magic_mock_equality(self): self.assertEqual(mock == mock, True) self.assertEqual(mock != mock, False) + def test_asyncmock_defaults(self): + mock = AsyncMock() + self.assertEqual(int(mock), 1) + self.assertEqual(complex(mock), 1j) + self.assertEqual(float(mock), 1.0) + self.assertNotIn(object(), mock) + self.assertEqual(len(mock), 0) + self.assertEqual(list(mock), []) + self.assertEqual(hash(mock), object.__hash__(mock)) + self.assertEqual(str(mock), object.__str__(mock)) + self.assertTrue(bool(mock)) + self.assertEqual(round(mock), mock.__round__()) + self.assertEqual(math.trunc(mock), mock.__trunc__()) + self.assertEqual(math.floor(mock), mock.__floor__()) + self.assertEqual(math.ceil(mock), mock.__ceil__()) + self.assertTrue(asyncio.iscoroutinefunction(mock.__aexit__)) + self.assertTrue(asyncio.iscoroutinefunction(mock.__aenter__)) + self.assertIsInstance(mock.__aenter__, AsyncMock) + self.assertIsInstance(mock.__aexit__, AsyncMock) + + # in Python 3 oct and hex use __index__ + # so these tests are for __index__ in py3k + self.assertEqual(oct(mock), '0o1') + self.assertEqual(hex(mock), '0x1') + # how to test __sizeof__ ? def test_magicmock_defaults(self): mock = MagicMock() @@ -280,6 +308,14 @@ def test_magicmock_defaults(self): self.assertEqual(hash(mock), object.__hash__(mock)) self.assertEqual(str(mock), object.__str__(mock)) self.assertTrue(bool(mock)) + self.assertEqual(round(mock), mock.__round__()) + self.assertEqual(math.trunc(mock), mock.__trunc__()) + self.assertEqual(math.floor(mock), mock.__floor__()) + self.assertEqual(math.ceil(mock), mock.__ceil__()) + self.assertTrue(asyncio.iscoroutinefunction(mock.__aexit__)) + self.assertTrue(asyncio.iscoroutinefunction(mock.__aenter__)) + self.assertIsInstance(mock.__aenter__, AsyncMock) + self.assertIsInstance(mock.__aexit__, AsyncMock) # in Python 3 oct and hex use __index__ # so these tests are for __index__ in py3k @@ -288,10 +324,18 @@ def test_magicmock_defaults(self): # how to test __sizeof__ ? + def test_magic_methods_fspath(self): + mock = MagicMock() + expected_path = mock.__fspath__() + mock.reset_mock() + + self.assertEqual(os.fspath(mock), expected_path) + mock.__fspath__.assert_called_once() + + def test_magic_methods_and_spec(self): class Iterable(object): - def __iter__(self): - pass + def __iter__(self): pass mock = Mock(spec=Iterable) self.assertRaises(AttributeError, lambda: mock.__iter__) @@ -315,8 +359,7 @@ def set_int(): def test_magic_methods_and_spec_set(self): class Iterable(object): - def __iter__(self): - pass + def __iter__(self): pass mock = Mock(spec_set=Iterable) self.assertRaises(AttributeError, lambda: mock.__iter__) diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py index b64c8663d2..1cde45e9ae 100644 --- a/Lib/unittest/test/testmock/testmock.py +++ b/Lib/unittest/test/testmock/testmock.py @@ -1,4 +1,5 @@ import copy +import re import sys import tempfile @@ -8,7 +9,7 @@ from unittest.mock import ( call, DEFAULT, patch, sentinel, MagicMock, Mock, NonCallableMock, - NonCallableMagicMock, _CallList, + NonCallableMagicMock, AsyncMock, _Call, _CallList, create_autospec ) @@ -27,16 +28,16 @@ def next(self): class Something(object): - def meth(self, a, b, c, d=None): - pass + def meth(self, a, b, c, d=None): pass @classmethod - def cmeth(cls, a, b, c, d=None): - pass + def cmeth(cls, a, b, c, d=None): pass @staticmethod - def smeth(a, b, c, d=None): - pass + def smeth(a, b, c, d=None): pass + + +def something(a): pass class MockTest(unittest.TestCase): @@ -82,6 +83,21 @@ def test_return_value_in_constructor(self): "return value in constructor not honoured") + def test_change_return_value_via_delegate(self): + def f(): pass + mock = create_autospec(f) + mock.mock.return_value = 1 + self.assertEqual(mock(), 1) + + + def test_change_side_effect_via_delegate(self): + def f(): pass + mock = create_autospec(f) + mock.mock.side_effect = TypeError() + with self.assertRaises(TypeError): + mock() + + def test_repr(self): mock = Mock(name='foo') self.assertIn('foo', repr(mock)) @@ -160,8 +176,7 @@ def test_autospec_side_effect(self): results = [1, 2, 3] def effect(): return results.pop() - def f(): - pass + def f(): pass mock = create_autospec(f) mock.side_effect = [1, 2, 3] @@ -176,28 +191,12 @@ def f(): def test_autospec_side_effect_exception(self): # Test for issue 23661 - def f(): - pass + def f(): pass mock = create_autospec(f) mock.side_effect = ValueError('Bazinga!') self.assertRaisesRegex(ValueError, 'Bazinga!', mock) - @unittest.skipUnless('java' in sys.platform, - 'This test only applies to Jython') - def test_java_exception_side_effect(self): - import java - mock = Mock(side_effect=java.lang.RuntimeException("Boom!")) - - # can't use assertRaises with java exceptions - try: - mock(1, 2, fish=3) - except java.lang.RuntimeException: - pass - else: - self.fail('java exception not raised') - mock.assert_called_with(1,2, fish=3) - def test_reset_mock(self): parent = Mock() @@ -266,6 +265,10 @@ def test_call(self): self.assertEqual(mock.call_count, 1, "call_count incoreect") self.assertEqual(mock.call_args, ((sentinel.Arg,), {}), "call_args not set") + self.assertEqual(mock.call_args.args, (sentinel.Arg,), + "call_args not set") + self.assertEqual(mock.call_args.kwargs, {}, + "call_args not set") self.assertEqual(mock.call_args_list, [((sentinel.Arg,), {})], "call_args_list not initialised correctly") @@ -299,6 +302,8 @@ def test_call_args_comparison(self): ]) self.assertEqual(mock.call_args, ((sentinel.Arg,), {"kw": sentinel.Kwarg})) + self.assertEqual(mock.call_args.args, (sentinel.Arg,)) + self.assertEqual(mock.call_args.kwargs, {"kw": sentinel.Kwarg}) # Comparing call_args to a long sequence should not raise # an exception. See issue 24857. @@ -348,8 +353,7 @@ def test_assert_called_with_any(self): def test_assert_called_with_function_spec(self): - def f(a, b, c, d=None): - pass + def f(a, b, c, d=None): pass mock = Mock(spec=f) @@ -384,6 +388,14 @@ def _check(mock): _check(mock) + def test_assert_called_exception_message(self): + msg = "Expected '{0}' to have been called" + with self.assertRaisesRegex(AssertionError, msg.format('mock')): + Mock().assert_called() + with self.assertRaisesRegex(AssertionError, msg.format('test_name')): + Mock(name="test_name").assert_called() + + def test_assert_called_once_with(self): mock = Mock() mock() @@ -407,10 +419,17 @@ def test_assert_called_once_with(self): lambda: mock.assert_called_once_with('bob', 'bar', baz=2) ) + def test_assert_called_once_with_call_list(self): + m = Mock() + m(1) + m(2) + self.assertRaisesRegex(AssertionError, + re.escape("Calls: [call(1), call(2)]"), + lambda: m.assert_called_once_with(2)) + def test_assert_called_once_with_function_spec(self): - def f(a, b, c, d=None): - pass + def f(a, b, c, d=None): pass mock = Mock(spec=f) @@ -514,8 +533,7 @@ def test_from_spec(self): class Something(object): x = 3 __something__ = None - def y(self): - pass + def y(self): pass def test_attributes(mock): # should work @@ -549,6 +567,16 @@ def test_wraps_calls(self): real.assert_called_with(1, 2, fish=3) + def test_wraps_prevents_automatic_creation_of_mocks(self): + class Real(object): + pass + + real = Real() + mock = Mock(wraps=real) + + self.assertRaises(AttributeError, lambda: mock.new_attr()) + + def test_wraps_call_with_nondefault_return_value(self): real = Mock() @@ -575,6 +603,110 @@ class Real(object): self.assertEqual(result, Real.attribute.frog()) + def test_customize_wrapped_object_with_side_effect_iterable_with_default(self): + class Real(object): + def method(self): + return sentinel.ORIGINAL_VALUE + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = [sentinel.VALUE1, DEFAULT] + + self.assertEqual(mock.method(), sentinel.VALUE1) + self.assertEqual(mock.method(), sentinel.ORIGINAL_VALUE) + self.assertRaises(StopIteration, mock.method) + + + def test_customize_wrapped_object_with_side_effect_iterable(self): + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = [sentinel.VALUE1, sentinel.VALUE2] + + self.assertEqual(mock.method(), sentinel.VALUE1) + self.assertEqual(mock.method(), sentinel.VALUE2) + self.assertRaises(StopIteration, mock.method) + + + def test_customize_wrapped_object_with_side_effect_exception(self): + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = RuntimeError + + self.assertRaises(RuntimeError, mock.method) + + + def test_customize_wrapped_object_with_side_effect_function(self): + class Real(object): + def method(self): pass + def side_effect(): + return sentinel.VALUE + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = side_effect + + self.assertEqual(mock.method(), sentinel.VALUE) + + + def test_customize_wrapped_object_with_return_value(self): + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.return_value = sentinel.VALUE + + self.assertEqual(mock.method(), sentinel.VALUE) + + + def test_customize_wrapped_object_with_return_value_and_side_effect(self): + # side_effect should always take precedence over return_value. + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = [sentinel.VALUE1, sentinel.VALUE2] + mock.method.return_value = sentinel.WRONG_VALUE + + self.assertEqual(mock.method(), sentinel.VALUE1) + self.assertEqual(mock.method(), sentinel.VALUE2) + self.assertRaises(StopIteration, mock.method) + + + def test_customize_wrapped_object_with_return_value_and_side_effect2(self): + # side_effect can return DEFAULT to default to return_value + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = lambda: DEFAULT + mock.method.return_value = sentinel.VALUE + + self.assertEqual(mock.method(), sentinel.VALUE) + + + def test_customize_wrapped_object_with_return_value_and_side_effect_default(self): + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = [sentinel.VALUE1, DEFAULT] + mock.method.return_value = sentinel.RETURN + + self.assertEqual(mock.method(), sentinel.VALUE1) + self.assertEqual(mock.method(), sentinel.RETURN) + self.assertRaises(StopIteration, mock.method) + + def test_exceptional_side_effect(self): mock = Mock(side_effect=AttributeError) self.assertRaises(AttributeError, mock) @@ -593,7 +725,7 @@ def test_baseexceptional_side_effect(self): def test_assert_called_with_message(self): mock = Mock() - self.assertRaisesRegex(AssertionError, 'Not called', + self.assertRaisesRegex(AssertionError, 'not called', mock.assert_called_with) @@ -642,6 +774,26 @@ class X(object): self.assertIsInstance(mock, X) + def test_spec_class_no_object_base(self): + class X: + pass + + mock = Mock(spec=X) + self.assertIsInstance(mock, X) + + mock = Mock(spec=X()) + self.assertIsInstance(mock, X) + + self.assertIs(mock.__class__, X) + self.assertEqual(Mock().__class__.__name__, 'Mock') + + mock = Mock(spec_set=X) + self.assertIsInstance(mock, X) + + mock = Mock(spec_set=X()) + self.assertIsInstance(mock, X) + + def test_setting_attribute_with_spec_set(self): class X(object): y = 3 @@ -690,6 +842,7 @@ def test(): def test_setting_call(self): mock = Mock() def __call__(self, a): + self._increment_mock_call(a) return self._mock_call(a) type(mock).__call__ = __call__ @@ -748,6 +901,15 @@ def test_filter_dir(self): patcher.stop() + def test_dir_does_not_include_deleted_attributes(self): + mock = Mock() + mock.child.return_value = 1 + + self.assertIn('child', dir(mock)) + del mock.child + self.assertNotIn('child', dir(mock)) + + def test_configure_mock(self): mock = Mock(foo='bar') self.assertEqual(mock.foo, 'bar') @@ -771,25 +933,20 @@ def test_configure_mock(self): def assertRaisesWithMsg(self, exception, message, func, *args, **kwargs): # needed because assertRaisesRegex doesn't work easily with newlines - try: + with self.assertRaises(exception) as context: func(*args, **kwargs) - except: - instance = sys.exc_info()[1] - self.assertIsInstance(instance, exception) - else: - self.fail('Exception %r not raised' % (exception,)) - - msg = str(instance) + msg = str(context.exception) self.assertEqual(msg, message) def test_assert_called_with_failure_message(self): mock = NonCallableMock() + actual = 'not called.' expected = "mock(1, '2', 3, bar='foo')" - message = 'Expected call: %s\nNot called' + message = 'expected call not found.\nExpected: %s\nActual: %s' self.assertRaisesWithMsg( - AssertionError, message % (expected,), + AssertionError, message % (expected, actual), mock.assert_called_with, 1, '2', 3, bar='foo' ) @@ -802,7 +959,7 @@ def test_assert_called_with_failure_message(self): for meth in asserters: actual = "foo(1, '2', 3, foo='foo')" expected = "foo(1, '2', 3, bar='foo')" - message = 'Expected call: %s\nActual call: %s' + message = 'expected call not found.\nExpected: %s\nActual: %s' self.assertRaisesWithMsg( AssertionError, message % (expected, actual), meth, 1, '2', 3, bar='foo' @@ -812,7 +969,7 @@ def test_assert_called_with_failure_message(self): for meth in asserters: actual = "foo(1, '2', 3, foo='foo')" expected = "foo(bar='foo')" - message = 'Expected call: %s\nActual call: %s' + message = 'expected call not found.\nExpected: %s\nActual: %s' self.assertRaisesWithMsg( AssertionError, message % (expected, actual), meth, bar='foo' @@ -822,7 +979,7 @@ def test_assert_called_with_failure_message(self): for meth in asserters: actual = "foo(1, '2', 3, foo='foo')" expected = "foo(1, 2, 3)" - message = 'Expected call: %s\nActual call: %s' + message = 'expected call not found.\nExpected: %s\nActual: %s' self.assertRaisesWithMsg( AssertionError, message % (expected, actual), meth, 1, 2, 3 @@ -832,7 +989,7 @@ def test_assert_called_with_failure_message(self): for meth in asserters: actual = "foo(1, '2', 3, foo='foo')" expected = "foo()" - message = 'Expected call: %s\nActual call: %s' + message = 'expected call not found.\nExpected: %s\nActual: %s' self.assertRaisesWithMsg( AssertionError, message % (expected, actual), meth ) @@ -916,6 +1073,69 @@ def test_mock_calls(self): call().__int__().call_list()) + def test_child_mock_call_equal(self): + m = Mock() + result = m() + result.wibble() + # parent looks like this: + self.assertEqual(m.mock_calls, [call(), call().wibble()]) + # but child should look like this: + self.assertEqual(result.mock_calls, [call.wibble()]) + + + def test_mock_call_not_equal_leaf(self): + m = Mock() + m.foo().something() + self.assertNotEqual(m.mock_calls[1], call.foo().different()) + self.assertEqual(m.mock_calls[0], call.foo()) + + + def test_mock_call_not_equal_non_leaf(self): + m = Mock() + m.foo().bar() + self.assertNotEqual(m.mock_calls[1], call.baz().bar()) + self.assertNotEqual(m.mock_calls[0], call.baz()) + + + def test_mock_call_not_equal_non_leaf_params_different(self): + m = Mock() + m.foo(x=1).bar() + # This isn't ideal, but there's no way to fix it without breaking backwards compatibility: + self.assertEqual(m.mock_calls[1], call.foo(x=2).bar()) + + + def test_mock_call_not_equal_non_leaf_attr(self): + m = Mock() + m.foo.bar() + self.assertNotEqual(m.mock_calls[0], call.baz.bar()) + + + def test_mock_call_not_equal_non_leaf_call_versus_attr(self): + m = Mock() + m.foo.bar() + self.assertNotEqual(m.mock_calls[0], call.foo().bar()) + + + def test_mock_call_repr(self): + m = Mock() + m.foo().bar().baz.bob() + self.assertEqual(repr(m.mock_calls[0]), 'call.foo()') + self.assertEqual(repr(m.mock_calls[1]), 'call.foo().bar()') + self.assertEqual(repr(m.mock_calls[2]), 'call.foo().bar().baz.bob()') + + + def test_mock_call_repr_loop(self): + m = Mock() + m.foo = m + repr(m.foo()) + self.assertRegex(repr(m.foo()), r"") + + + def test_mock_calls_contains(self): + m = Mock() + self.assertFalse([call()] in m.mock_calls) + + def test_subclassing(self): class Subclass(Mock): pass @@ -974,9 +1194,8 @@ def test_call_args_two_tuple(self): mock(2, b=4) self.assertEqual(len(mock.call_args), 2) - args, kwargs = mock.call_args - self.assertEqual(args, (2,)) - self.assertEqual(kwargs, dict(b=4)) + self.assertEqual(mock.call_args.args, (2,)) + self.assertEqual(mock.call_args.kwargs, dict(b=4)) expected_list = [((1,), dict(a=3)), ((2,), dict(b=4))] for expected, call_args in zip(expected_list, mock.call_args_list): @@ -1129,9 +1348,56 @@ def test_assert_has_calls(self): ) + def test_assert_has_calls_nested_spec(self): + class Something: + + def __init__(self): pass + def meth(self, a, b, c, d=None): pass + + class Foo: + + def __init__(self, a): pass + def meth1(self, a, b): pass + + mock_class = create_autospec(Something) + + for m in [mock_class, mock_class()]: + m.meth(1, 2, 3, d=1) + m.assert_has_calls([call.meth(1, 2, 3, d=1)]) + m.assert_has_calls([call.meth(1, 2, 3, 1)]) + + mock_class.reset_mock() + + for m in [mock_class, mock_class()]: + self.assertRaises(AssertionError, m.assert_has_calls, [call.Foo()]) + m.Foo(1).meth1(1, 2) + m.assert_has_calls([call.Foo(1), call.Foo(1).meth1(1, 2)]) + m.Foo.assert_has_calls([call(1), call().meth1(1, 2)]) + + mock_class.reset_mock() + + invalid_calls = [call.meth(1), + call.non_existent(1), + call.Foo().non_existent(1), + call.Foo().meth(1, 2, 3, 4)] + + for kall in invalid_calls: + self.assertRaises(AssertionError, + mock_class.assert_has_calls, + [kall] + ) + + + def test_assert_has_calls_nested_without_spec(self): + m = MagicMock() + m().foo().bar().baz() + m.one().two().three() + calls = call.one().two().three().call_list() + m.assert_has_calls(calls) + + def test_assert_has_calls_with_function_spec(self): - def f(a, b, c, d=None): - pass + def f(a, b, c, d=None): pass mock = Mock(spec=f) @@ -1161,6 +1427,32 @@ def f(a, b, c, d=None): mock.assert_has_calls(calls[:-1]) mock.assert_has_calls(calls[:-1], any_order=True) + def test_assert_has_calls_not_matching_spec_error(self): + def f(x=None): pass + + mock = Mock(spec=f) + mock(1) + + with self.assertRaisesRegex( + AssertionError, + '^{}$'.format( + re.escape('Calls not found.\n' + 'Expected: [call()]\n' + 'Actual: [call(1)]'))) as cm: + mock.assert_has_calls([call()]) + self.assertIsNone(cm.exception.__cause__) + + + with self.assertRaisesRegex( + AssertionError, + '^{}$'.format( + re.escape( + 'Error processing expected calls.\n' + "Errors: [None, TypeError('too many positional arguments')]\n" + "Expected: [call(), call(1, 2)]\n" + 'Actual: [call(1)]'))) as cm: + mock.assert_has_calls([call(), call(1, 2)]) + self.assertIsInstance(cm.exception.__cause__, TypeError) def test_assert_any_call(self): mock = Mock() @@ -1189,8 +1481,7 @@ def test_assert_any_call(self): def test_assert_any_call_with_function_spec(self): - def f(a, b, c, d=None): - pass + def f(a, b, c, d=None): pass mock = Mock(spec=f) @@ -1209,8 +1500,7 @@ def f(a, b, c, d=None): def test_mock_calls_create_autospec(self): - def f(a, b): - pass + def f(a, b): pass obj = Iter() obj.f = f @@ -1231,12 +1521,28 @@ def test_create_autospec_with_name(self): m = mock.create_autospec(object(), name='sweet_func') self.assertIn('sweet_func', repr(m)) + #Issue23078 + def test_create_autospec_classmethod_and_staticmethod(self): + class TestClass: + @classmethod + def class_method(cls): pass + + @staticmethod + def static_method(): pass + for method in ('class_method', 'static_method'): + with self.subTest(method=method): + mock_method = mock.create_autospec(getattr(TestClass, method)) + mock_method() + mock_method.assert_called_once_with() + self.assertRaises(TypeError, mock_method, 'extra_arg') + #Issue21238 def test_mock_unsafe(self): m = Mock() - with self.assertRaises(AttributeError): + msg = "Attributes cannot start with 'assert' or 'assret'" + with self.assertRaisesRegex(AttributeError, msg): m.assert_foo_call() - with self.assertRaises(AttributeError): + with self.assertRaisesRegex(AttributeError, msg): m.assret_foo_call() m = Mock(unsafe=True) m.assert_foo_call() @@ -1250,6 +1556,13 @@ def test_assert_not_called(self): with self.assertRaises(AssertionError): m.hello.assert_not_called() + def test_assert_not_called_message(self): + m = Mock() + m(1, 2) + self.assertRaisesRegex(AssertionError, + re.escape("Calls: [call(1, 2)]"), + m.assert_not_called) + def test_assert_called(self): m = Mock() with self.assertRaises(AssertionError): @@ -1271,11 +1584,25 @@ def test_assert_called_once(self): with self.assertRaises(AssertionError): m.hello.assert_called_once() - #Issue21256 printout of keyword args should be in deterministic order - def test_sorted_call_signature(self): + def test_assert_called_once_message(self): + m = Mock() + m(1, 2) + m(3) + self.assertRaisesRegex(AssertionError, + re.escape("Calls: [call(1, 2), call(3)]"), + m.assert_called_once) + + def test_assert_called_once_message_not_called(self): + m = Mock() + with self.assertRaises(AssertionError) as e: + m.assert_called_once() + self.assertNotIn("Calls:", str(e.exception)) + + #Issue37212 printout of keyword args now preserves the original order + def test_ordered_call_signature(self): m = Mock() m.hello(name='hello', daddy='hero') - text = "call(daddy='hero', name='hello')" + text = "call(name='hello', daddy='hero')" self.assertEqual(repr(m.hello.call_args), text) #Issue21270 overrides tuple methods for mock.call objects @@ -1377,7 +1704,8 @@ def test_mock_add_spec_magic_methods(self): def test_adding_child_mock(self): - for Klass in NonCallableMock, Mock, MagicMock, NonCallableMagicMock: + for Klass in (NonCallableMock, Mock, MagicMock, NonCallableMagicMock, + AsyncMock): mock = Klass() mock.foo = Mock() @@ -1450,6 +1778,29 @@ def test_mock_open_reuse_issue_21750(self): f2_data = f2.read() self.assertEqual(f1_data, f2_data) + def test_mock_open_dunder_iter_issue(self): + # Test dunder_iter method generates the expected result and + # consumes the iterator. + mocked_open = mock.mock_open(read_data='Remarkable\nNorwegian Blue') + f1 = mocked_open('a-name') + lines = [line for line in f1] + self.assertEqual(lines[0], 'Remarkable\n') + self.assertEqual(lines[1], 'Norwegian Blue') + self.assertEqual(list(f1), []) + + def test_mock_open_using_next(self): + mocked_open = mock.mock_open(read_data='1st line\n2nd line\n3rd line') + f1 = mocked_open('a-name') + line1 = next(f1) + line2 = f1.__next__() + lines = [line for line in f1] + self.assertEqual(line1, '1st line\n') + self.assertEqual(line2, '2nd line\n') + self.assertEqual(lines[0], '3rd line') + self.assertEqual(list(f1), []) + with self.assertRaises(StopIteration): + next(f1) + def test_mock_open_write(self): # Test exception in file writing write() mock_namedtemp = mock.mock_open(mock.MagicMock(name='JLV')) @@ -1543,6 +1894,55 @@ def test_attach_mock_return_value(self): self.assertEqual(m.mock_calls, call().foo().call_list()) + def test_attach_mock_patch_autospec(self): + parent = Mock() + + with mock.patch(f'{__name__}.something', autospec=True) as mock_func: + self.assertEqual(mock_func.mock._extract_mock_name(), 'something') + parent.attach_mock(mock_func, 'child') + parent.child(1) + something(2) + mock_func(3) + + parent_calls = [call.child(1), call.child(2), call.child(3)] + child_calls = [call(1), call(2), call(3)] + self.assertEqual(parent.mock_calls, parent_calls) + self.assertEqual(parent.child.mock_calls, child_calls) + self.assertEqual(something.mock_calls, child_calls) + self.assertEqual(mock_func.mock_calls, child_calls) + self.assertIn('mock.child', repr(parent.child.mock)) + self.assertEqual(mock_func.mock._extract_mock_name(), 'mock.child') + + + def test_attach_mock_patch_autospec_signature(self): + with mock.patch(f'{__name__}.Something.meth', autospec=True) as mocked: + manager = Mock() + manager.attach_mock(mocked, 'attach_meth') + obj = Something() + obj.meth(1, 2, 3, d=4) + manager.assert_has_calls([call.attach_meth(mock.ANY, 1, 2, 3, d=4)]) + obj.meth.assert_has_calls([call(mock.ANY, 1, 2, 3, d=4)]) + mocked.assert_has_calls([call(mock.ANY, 1, 2, 3, d=4)]) + + with mock.patch(f'{__name__}.something', autospec=True) as mocked: + manager = Mock() + manager.attach_mock(mocked, 'attach_func') + something(1) + manager.assert_has_calls([call.attach_func(1)]) + something.assert_has_calls([call(1)]) + mocked.assert_has_calls([call(1)]) + + with mock.patch(f'{__name__}.Something', autospec=True) as mocked: + manager = Mock() + manager.attach_mock(mocked, 'attach_obj') + obj = Something() + obj.meth(1, 2, 3, d=4) + manager.assert_has_calls([call.attach_obj(), + call.attach_obj().meth(1, 2, 3, d=4)]) + obj.meth.assert_has_calls([call(1, 2, 3, d=4)]) + mocked.assert_has_calls([call(), call().meth(1, 2, 3, d=4)]) + + def test_attribute_deletion(self): for mock in (Mock(), MagicMock(), NonCallableMagicMock(), NonCallableMock()): @@ -1556,6 +1956,43 @@ def test_attribute_deletion(self): self.assertRaises(AttributeError, getattr, mock, 'f') + def test_mock_does_not_raise_on_repeated_attribute_deletion(self): + # bpo-20239: Assigning and deleting twice an attribute raises. + for mock in (Mock(), MagicMock(), NonCallableMagicMock(), + NonCallableMock()): + mock.foo = 3 + self.assertTrue(hasattr(mock, 'foo')) + self.assertEqual(mock.foo, 3) + + del mock.foo + self.assertFalse(hasattr(mock, 'foo')) + + mock.foo = 4 + self.assertTrue(hasattr(mock, 'foo')) + self.assertEqual(mock.foo, 4) + + del mock.foo + self.assertFalse(hasattr(mock, 'foo')) + + + def test_mock_raises_when_deleting_nonexistent_attribute(self): + for mock in (Mock(), MagicMock(), NonCallableMagicMock(), + NonCallableMock()): + del mock.foo + with self.assertRaises(AttributeError): + del mock.foo + + + def test_reset_mock_does_not_raise_on_attr_deletion(self): + # bpo-31177: reset_mock should not raise AttributeError when attributes + # were deleted in a mock instance + mock = Mock() + mock.child = True + del mock.child + mock.reset_mock() + self.assertFalse(hasattr(mock, 'child')) + + def test_class_assignable(self): for mock in Mock(), MagicMock(): self.assertNotIsInstance(mock, int) @@ -1564,6 +2001,85 @@ def test_class_assignable(self): self.assertIsInstance(mock, int) mock.foo + def test_name_attribute_of_call(self): + # bpo-35357: _Call should not disclose any attributes whose names + # may clash with popular ones (such as ".name") + self.assertIsNotNone(call.name) + self.assertEqual(type(call.name), _Call) + self.assertEqual(type(call.name().name), _Call) + + def test_parent_attribute_of_call(self): + # bpo-35357: _Call should not disclose any attributes whose names + # may clash with popular ones (such as ".parent") + self.assertIsNotNone(call.parent) + self.assertEqual(type(call.parent), _Call) + self.assertEqual(type(call.parent().parent), _Call) + + + def test_parent_propagation_with_create_autospec(self): + + def foo(a, b): pass + + mock = Mock() + mock.child = create_autospec(foo) + mock.child(1, 2) + + self.assertRaises(TypeError, mock.child, 1) + self.assertEqual(mock.mock_calls, [call.child(1, 2)]) + self.assertIn('mock.child', repr(mock.child.mock)) + + def test_parent_propagation_with_autospec_attach_mock(self): + + def foo(a, b): pass + + parent = Mock() + parent.attach_mock(create_autospec(foo, name='bar'), 'child') + parent.child(1, 2) + + self.assertRaises(TypeError, parent.child, 1) + self.assertEqual(parent.child.mock_calls, [call.child(1, 2)]) + self.assertIn('mock.child', repr(parent.child.mock)) + + + def test_isinstance_under_settrace(self): + # bpo-36593 : __class__ is not set for a class that has __class__ + # property defined when it's used with sys.settrace(trace) set. + # Delete the module to force reimport with tracing function set + # restore the old reference later since there are other tests that are + # dependent on unittest.mock.patch. In testpatch.PatchTest + # test_patch_dict_test_prefix and test_patch_test_prefix not restoring + # causes the objects patched to go out of sync + + old_patch = unittest.mock.patch + + # Directly using __setattr__ on unittest.mock causes current imported + # reference to be updated. Use a lambda so that during cleanup the + # re-imported new reference is updated. + self.addCleanup(lambda patch: setattr(unittest.mock, 'patch', patch), + old_patch) + + with patch.dict('sys.modules'): + del sys.modules['unittest.mock'] + + # This trace will stop coverage being measured ;-) + def trace(frame, event, arg): # pragma: no cover + return trace + + self.addCleanup(sys.settrace, sys.gettrace()) + sys.settrace(trace) + + from unittest.mock import ( + Mock, MagicMock, NonCallableMock, NonCallableMagicMock + ) + + mocks = [ + Mock, MagicMock, NonCallableMock, NonCallableMagicMock, AsyncMock + ] + + for mock in mocks: + obj = mock(spec=Something) + self.assertIsInstance(obj, Something) + if __name__ == '__main__': unittest.main() diff --git a/Lib/unittest/test/testmock/testpatch.py b/Lib/unittest/test/testmock/testpatch.py index fe4ecefd44..e065a2c35f 100644 --- a/Lib/unittest/test/testmock/testpatch.py +++ b/Lib/unittest/test/testmock/testpatch.py @@ -9,6 +9,7 @@ from unittest.test.testmock import support from unittest.test.testmock.support import SomeClass, is_instance +from test.test_importlib.util import uncache from unittest.mock import ( NonCallableMock, CallableMixin, sentinel, MagicMock, Mock, NonCallableMagicMock, patch, _patch, @@ -42,23 +43,24 @@ def __delattr__(self, name): class Foo(object): - def __init__(self, a): - pass - def f(self, a): - pass - def g(self): - pass + def __init__(self, a): pass + def f(self, a): pass + def g(self): pass foo = 'bar' + @staticmethod + def static_method(): pass + + @classmethod + def class_method(cls): pass + class Bar(object): - def a(self): - pass + def a(self): pass foo_name = '%s.Foo' % __name__ -def function(a, b=Foo): - pass +def function(a, b=Foo): pass class Container(object): @@ -103,6 +105,10 @@ def test(): self.assertEqual(Something.attribute, sentinel.Original, "patch not restored") + def test_patchobject_with_string_as_target(self): + msg = "'Something' must be the actual object to be patched, not a str" + with self.assertRaisesRegex(TypeError, msg): + patch.object('Something', 'do_something') def test_patchobject_with_none(self): class Something(object): @@ -361,31 +367,19 @@ def test(): def test_patch_wont_create_by_default(self): - try: + with self.assertRaises(AttributeError): @patch('%s.frooble' % builtin_string, sentinel.Frooble) - def test(): - self.assertEqual(frooble, sentinel.Frooble) + def test(): pass test() - except AttributeError: - pass - else: - self.fail('Patching non existent attributes should fail') - self.assertRaises(NameError, lambda: frooble) def test_patchobject_wont_create_by_default(self): - try: + with self.assertRaises(AttributeError): @patch.object(SomeClass, 'ord', sentinel.Frooble) - def test(): - self.fail('Patching non existent attributes should fail') - + def test(): pass test() - except AttributeError: - pass - else: - self.fail('Patching non existent attributes should fail') self.assertFalse(hasattr(SomeClass, 'ord')) @@ -475,6 +469,9 @@ class Something(object): attribute = sentinel.Original class Foo(object): + + test_class_attr = 'whatever' + def test_method(other_self, mock_something): self.assertEqual(PTModule.something, mock_something, "unpatched") @@ -626,6 +623,13 @@ def test(): self.assertEqual(foo.values, original) + def test_patch_dict_as_context_manager(self): + foo = {'a': 'b'} + with patch.dict(foo, a='c') as patched: + self.assertEqual(patched, {'a': 'c'}) + self.assertEqual(foo, {'a': 'b'}) + + def test_name_preserved(self): foo = {} @@ -633,8 +637,7 @@ def test_name_preserved(self): @patch('%s.SomeClass' % __name__, object(), autospec=True) @patch.object(SomeClass, object()) @patch.dict(foo) - def some_name(): - pass + def some_name(): pass self.assertEqual(some_name.__name__, 'some_name') @@ -645,12 +648,9 @@ def test_patch_with_exception(self): @patch.dict(foo, {'a': 'b'}) def test(): raise NameError('Konrad') - try: + + with self.assertRaises(NameError): test() - except NameError: - pass - else: - self.fail('NameError not raised by test') self.assertEqual(foo, {}) @@ -663,47 +663,21 @@ def test(): test() - def test_patch_descriptor(self): - # would be some effort to fix this - we could special case the - # builtin descriptors: classmethod, property, staticmethod - return - class Nothing(object): - foo = None - - class Something(object): - foo = {} - - @patch.object(Nothing, 'foo', 2) - @classmethod - def klass(cls): - self.assertIs(cls, Something) - - @patch.object(Nothing, 'foo', 2) - @staticmethod - def static(arg): - return arg - - @patch.dict(foo) - @classmethod - def klass_dict(cls): - self.assertIs(cls, Something) - - @patch.dict(foo) - @staticmethod - def static_dict(arg): - return arg + def test_patch_dict_decorator_resolution(self): + # bpo-35512: Ensure that patch with a string target resolves to + # the new dictionary during function call + original = support.target.copy() - # these will raise exceptions if patching descriptors is broken - self.assertEqual(Something.static('f00'), 'f00') - Something.klass() - self.assertEqual(Something.static_dict('f00'), 'f00') - Something.klass_dict() + @patch.dict('unittest.test.testmock.support.target', {'bar': 'BAR'}) + def test(): + self.assertEqual(support.target, {'foo': 'BAZ', 'bar': 'BAR'}) - something = Something() - self.assertEqual(something.static('f00'), 'f00') - something.klass() - self.assertEqual(something.static_dict('f00'), 'f00') - something.klass_dict() + try: + support.target = {'foo': 'BAZ'} + test() + self.assertEqual(support.target, {'foo': 'BAZ'}) + finally: + support.target = original def test_patch_spec_set(self): @@ -754,10 +728,18 @@ def test_patch_start_stop(self): def test_stop_without_start(self): + # bpo-36366: calling stop without start will return None. + patcher = patch(foo_name, 'bar', 3) + self.assertIsNone(patcher.stop()) + + + def test_stop_idempotent(self): + # bpo-36366: calling stop on an already stopped patch will return None. patcher = patch(foo_name, 'bar', 3) - # calling stop without start used to produce a very obscure error - self.assertRaises(RuntimeError, patcher.stop) + patcher.start() + patcher.stop() + self.assertIsNone(patcher.stop()) def test_patchobject_start_stop(self): @@ -897,17 +879,13 @@ def test_patch_dict_keyword_args(self): def test_autospec(self): class Boo(object): - def __init__(self, a): - pass - def f(self, a): - pass - def g(self): - pass + def __init__(self, a): pass + def f(self, a): pass + def g(self): pass foo = 'bar' class Bar(object): - def a(self): - pass + def a(self): pass def _test(mock): mock(1) @@ -997,6 +975,18 @@ def test(mock_function): self.assertEqual(result, 3) + def test_autospec_staticmethod(self): + with patch('%s.Foo.static_method' % __name__, autospec=True) as method: + Foo.static_method() + method.assert_called_once_with() + + + def test_autospec_classmethod(self): + with patch('%s.Foo.class_method' % __name__, autospec=True) as method: + Foo.class_method() + method.assert_called_once_with() + + def test_autospec_with_new(self): patcher = patch('%s.function' % __name__, new=3, autospec=True) self.assertRaises(TypeError, patcher.start) @@ -1266,7 +1256,6 @@ def test(f, foo): def test_patch_multiple_create_mocks_different_order(self): - # bug revealed by Jython! original_f = Foo.f original_g = Foo.g @@ -1443,20 +1432,17 @@ def test_nested_patch_failure(self): @patch.object(Foo, 'g', 1) @patch.object(Foo, 'missing', 1) @patch.object(Foo, 'f', 1) - def thing1(): - pass + def thing1(): pass @patch.object(Foo, 'missing', 1) @patch.object(Foo, 'g', 1) @patch.object(Foo, 'f', 1) - def thing2(): - pass + def thing2(): pass @patch.object(Foo, 'g', 1) @patch.object(Foo, 'f', 1) @patch.object(Foo, 'missing', 1) - def thing3(): - pass + def thing3(): pass for func in thing1, thing2, thing3: self.assertRaises(AttributeError, func) @@ -1475,20 +1461,17 @@ def crasher(): @patch.object(Foo, 'g', 1) @patch.object(Foo, 'foo', new_callable=crasher) @patch.object(Foo, 'f', 1) - def thing1(): - pass + def thing1(): pass @patch.object(Foo, 'foo', new_callable=crasher) @patch.object(Foo, 'g', 1) @patch.object(Foo, 'f', 1) - def thing2(): - pass + def thing2(): pass @patch.object(Foo, 'g', 1) @patch.object(Foo, 'f', 1) @patch.object(Foo, 'foo', new_callable=crasher) - def thing3(): - pass + def thing3(): pass for func in thing1, thing2, thing3: self.assertRaises(NameError, func) @@ -1514,8 +1497,7 @@ def test_patch_multiple_failure(self): patcher.additional_patchers = additionals @patcher - def func(): - pass + def func(): pass self.assertRaises(AttributeError, func) self.assertEqual(Foo.f, original_f) @@ -1543,8 +1525,7 @@ def crasher(): patcher.additional_patchers = additionals @patcher - def func(): - pass + def func(): pass self.assertRaises(NameError, func) self.assertEqual(Foo.f, original_f) @@ -1660,22 +1641,21 @@ def test_mock_calls_with_patch(self): def test_patch_imports_lazily(self): - sys.modules.pop('squizz', None) - p1 = patch('squizz.squozz') self.assertRaises(ImportError, p1.start) - squizz = Mock() - squizz.squozz = 6 - sys.modules['squizz'] = squizz - p1 = patch('squizz.squozz') - squizz.squozz = 3 - p1.start() - p1.stop() - self.assertEqual(squizz.squozz, 3) + with uncache('squizz'): + squizz = Mock() + sys.modules['squizz'] = squizz + squizz.squozz = 6 + p1 = patch('squizz.squozz') + squizz.squozz = 3 + p1.start() + p1.stop() + self.assertEqual(squizz.squozz, 3) - def test_patch_propogrates_exc_on_exit(self): + def test_patch_propagates_exc_on_exit(self): class holder: exc_info = None, None, None @@ -1696,12 +1676,17 @@ def with_custom_patch(target): def test(mock): raise RuntimeError - self.assertRaises(RuntimeError, test) + with uncache('squizz'): + squizz = Mock() + sys.modules['squizz'] = squizz + + self.assertRaises(RuntimeError, test) + self.assertIs(holder.exc_info[0], RuntimeError) self.assertIsNotNone(holder.exc_info[1], - 'exception value not propgated') + 'exception value not propagated') self.assertIsNotNone(holder.exc_info[2], - 'exception traceback not propgated') + 'exception traceback not propagated') def test_create_and_specs(self): @@ -1849,5 +1834,36 @@ def foo(*a, x=0): self.assertEqual(foo(), 1) self.assertEqual(foo(), 0) + def test_dotted_but_module_not_loaded(self): + # This exercises the AttributeError branch of _dot_lookup. + + # make sure it's there + import unittest.test.testmock.support + # now make sure it's not: + with patch.dict('sys.modules'): + del sys.modules['unittest.test.testmock.support'] + del sys.modules['unittest.test.testmock'] + del sys.modules['unittest.test'] + del sys.modules['unittest'] + + # now make sure we can patch based on a dotted path: + @patch('unittest.test.testmock.support.X') + def test(mock): + pass + test() + + + def test_invalid_target(self): + with self.assertRaises(TypeError): + patch('') + + + def test_cant_set_kwargs_when_passing_a_mock(self): + @patch('unittest.test.testmock.support.X', new=object(), x=1) + def test(): pass + with self.assertRaises(TypeError): + test() + + if __name__ == '__main__': unittest.main() diff --git a/Lib/unittest/test/testmock/testsealable.py b/Lib/unittest/test/testmock/testsealable.py index 0e72b32411..59f52338d4 100644 --- a/Lib/unittest/test/testmock/testsealable.py +++ b/Lib/unittest/test/testmock/testsealable.py @@ -3,15 +3,10 @@ class SampleObject: - def __init__(self): - self.attr_sample1 = 1 - self.attr_sample2 = 1 - def method_sample1(self): - pass + def method_sample1(self): pass - def method_sample2(self): - pass + def method_sample2(self): pass class TestSealable(unittest.TestCase): diff --git a/Lib/unittest/test/testmock/testwith.py b/Lib/unittest/test/testmock/testwith.py index a7bee73003..42ebf3898c 100644 --- a/Lib/unittest/test/testmock/testwith.py +++ b/Lib/unittest/test/testmock/testwith.py @@ -10,6 +10,8 @@ something_else = sentinel.SomethingElse +class SampleException(Exception): pass + class WithTest(unittest.TestCase): @@ -20,14 +22,10 @@ def test_with_statement(self): def test_with_statement_exception(self): - try: + with self.assertRaises(SampleException): with patch('%s.something' % __name__, sentinel.Something2): self.assertEqual(something, sentinel.Something2, "unpatched") - raise Exception('pow') - except Exception: - pass - else: - self.fail("patch swallowed exception") + raise SampleException() self.assertEqual(something, sentinel.Something) @@ -126,6 +124,19 @@ def test_dict_context_manager(self): self.assertEqual(foo, {}) + def test_double_patch_instance_method(self): + class C: + def f(self): pass + + c = C() + + with patch.object(c, 'f', autospec=True) as patch1: + with patch.object(c, 'f', autospec=True) as patch2: + c.f() + self.assertEqual(patch2.call_count, 1) + self.assertEqual(patch1.call_count, 0) + c.f() + self.assertEqual(patch1.call_count, 1) class TestMockOpen(unittest.TestCase): @@ -188,6 +199,7 @@ def test_read_data(self): def test_readline_data(self): # Check that readline will return all the lines from the fake file + # And that once fully consumed, readline will return an empty string. mock = mock_open(read_data='foo\nbar\nbaz\n') with patch('%s.open' % __name__, mock, create=True): h = open('bar') @@ -197,6 +209,7 @@ def test_readline_data(self): self.assertEqual(line1, 'foo\n') self.assertEqual(line2, 'bar\n') self.assertEqual(line3, 'baz\n') + self.assertEqual(h.readline(), '') # Check that we properly emulate a file that doesn't end in a newline mock = mock_open(read_data='foo') @@ -204,8 +217,36 @@ def test_readline_data(self): h = open('bar') result = h.readline() self.assertEqual(result, 'foo') + self.assertEqual(h.readline(), '') + def test_dunder_iter_data(self): + # Check that dunder_iter will return all the lines from the fake file. + mock = mock_open(read_data='foo\nbar\nbaz\n') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + lines = [l for l in h] + self.assertEqual(lines[0], 'foo\n') + self.assertEqual(lines[1], 'bar\n') + self.assertEqual(lines[2], 'baz\n') + self.assertEqual(h.readline(), '') + with self.assertRaises(StopIteration): + next(h) + + def test_next_data(self): + # Check that next will correctly return the next available + # line and plays well with the dunder_iter part. + mock = mock_open(read_data='foo\nbar\nbaz\n') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + line1 = next(h) + line2 = next(h) + lines = [l for l in h] + self.assertEqual(line1, 'foo\n') + self.assertEqual(line2, 'bar\n') + self.assertEqual(lines[0], 'baz\n') + self.assertEqual(h.readline(), '') + def test_readlines_data(self): # Test that emulating a file that ends in a newline character works mock = mock_open(read_data='foo\nbar\nbaz\n') @@ -257,7 +298,12 @@ def test_mock_open_read_with_argument(self): # for mocks returned by mock_open some_data = 'foo\nbar\nbaz' mock = mock_open(read_data=some_data) - self.assertEqual(mock().read(10), some_data) + self.assertEqual(mock().read(10), some_data[:10]) + self.assertEqual(mock().read(10), some_data[:10]) + + f = mock() + self.assertEqual(f.read(10), some_data[:10]) + self.assertEqual(f.read(10), some_data[10:]) def test_interleaved_reads(self): From acee5c73ed5431c1251f02acec9fe245288cec4a Mon Sep 17 00:00:00 2001 From: Padraic Fanning Date: Sun, 5 Sep 2021 22:06:53 -0400 Subject: [PATCH 3/4] Skip hanging test --- Lib/unittest/test/test_async_case.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py index d01864b693..fedd4707ac 100644 --- a/Lib/unittest/test/test_async_case.py +++ b/Lib/unittest/test/test_async_case.py @@ -190,6 +190,7 @@ async def on_async_cleanup(self, val): 'async_cleanup 2', 'sync_cleanup 1']) + @unittest.skip("TODO: RUSTPYTHON, hangs") def test_base_exception_from_async_method(self): events = [] class Test(unittest.IsolatedAsyncioTestCase): From 98a91abb8911f7dfbc2e7765f167b245d8a71bda Mon Sep 17 00:00:00 2001 From: Padraic Fanning Date: Sun, 5 Sep 2021 22:09:09 -0400 Subject: [PATCH 4/4] Mark erroring/failing tests --- Lib/unittest/test/test_runner.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py index 58ac98c2d0..5e47811710 100644 --- a/Lib/unittest/test/test_runner.py +++ b/Lib/unittest/test/test_runner.py @@ -403,6 +403,8 @@ class Module(object): self.assertEqual(str(e.exception), 'CleanUpExc') self.assertEqual(unittest.case._module_cleanups, []) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_addModuleCleanup_arg_errors(self): cleanups = [] def cleanup(*args, **kwargs): @@ -562,6 +564,8 @@ def tearDownClass(cls): 'tearDownModule', 'cleanup_good']) self.assertEqual(unittest.case._module_cleanups, []) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_addClassCleanup_arg_errors(self): cleanups = [] def cleanup(*args, **kwargs): @@ -584,6 +588,8 @@ def testNothing(self): self.assertEqual(cleanups, [((1, 2), {'function': 3, 'cls': 4})]) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_addCleanup_arg_errors(self): cleanups = [] def cleanup(*args, **kwargs):