Skip to content

Commit 9e86c60

Browse files
authored
unittest.mock: use ParamSpec in patch (#10325)
Fixes #10324
1 parent 7114aec commit 9e86c60

File tree

3 files changed

+34
-4
lines changed

3 files changed

+34
-4
lines changed

stdlib/unittest/mock.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ from collections.abc import Awaitable, Callable, Coroutine, Iterable, Mapping, S
33
from contextlib import _GeneratorContextManager
44
from types import TracebackType
55
from typing import Any, Generic, TypeVar, overload
6-
from typing_extensions import Final, Literal, Self, TypeAlias
6+
from typing_extensions import Final, Literal, ParamSpec, Self, TypeAlias
77

88
_T = TypeVar("_T")
99
_TT = TypeVar("_TT", bound=type[Any])
1010
_R = TypeVar("_R")
1111
_F = TypeVar("_F", bound=Callable[..., Any])
1212
_AF = TypeVar("_AF", bound=Callable[..., Coroutine[Any, Any, Any]])
13+
_P = ParamSpec("_P")
1314

1415
if sys.version_info >= (3, 8):
1516
__all__ = (
@@ -234,7 +235,7 @@ class _patch(Generic[_T]):
234235
@overload
235236
def __call__(self, func: _TT) -> _TT: ...
236237
@overload
237-
def __call__(self, func: Callable[..., _R]) -> Callable[..., _R]: ...
238+
def __call__(self, func: Callable[_P, _R]) -> Callable[_P, _R]: ...
238239
if sys.version_info >= (3, 8):
239240
def decoration_helper(
240241
self, patched: _patch[Any], args: Sequence[Any], keywargs: Any

test_cases/stdlib/check_unittest.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

33
import unittest
4+
from collections.abc import Callable
45
from datetime import datetime, timedelta
56
from decimal import Decimal
67
from fractions import Fraction
8+
from typing_extensions import assert_type
9+
from unittest.mock import Mock, patch
710

811
case = unittest.TestCase()
912

@@ -86,3 +89,29 @@ def __gt__(self, other: Bacon) -> bool:
8689
case.assertGreater(Spam(), Eggs()) # type: ignore
8790
case.assertGreater(Ham(), Bacon()) # type: ignore
8891
case.assertGreater(Bacon(), Ham()) # type: ignore
92+
93+
###
94+
# Tests for mock.patch
95+
###
96+
97+
98+
@patch("sys.exit", new=Mock())
99+
def f(i: int) -> str:
100+
return "asdf"
101+
102+
103+
assert_type(f(1), str)
104+
f("a") # type: ignore
105+
106+
107+
@patch("sys.exit", new=Mock())
108+
class TestXYZ(unittest.TestCase):
109+
attr: int = 5
110+
111+
@staticmethod
112+
def method() -> int:
113+
return 123
114+
115+
116+
assert_type(TestXYZ.attr, int)
117+
assert_type(TestXYZ.method, Callable[[], int])

tests/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,12 @@ def testcase_dir_from_package_name(package_name: str) -> Path:
111111

112112

113113
def get_all_testcase_directories() -> list[PackageInfo]:
114-
testcase_directories = [PackageInfo("stdlib", Path("test_cases"))]
114+
testcase_directories: list[PackageInfo] = []
115115
for package_name in os.listdir("stubs"):
116116
potential_testcase_dir = testcase_dir_from_package_name(package_name)
117117
if potential_testcase_dir.is_dir():
118118
testcase_directories.append(PackageInfo(package_name, potential_testcase_dir))
119-
return sorted(testcase_directories)
119+
return [PackageInfo("stdlib", Path("test_cases"))] + sorted(testcase_directories)
120120

121121

122122
# ====================================================================

0 commit comments

Comments
 (0)