From 8dc56f0ba6109536bc7795ddfcd02d789f14a287 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Fri, 18 Feb 2022 19:42:18 +0000 Subject: [PATCH 1/2] stubtest: error if a function is async at runtime but not in the stub (and vice versa) --- mypy/stubtest.py | 41 +++++++++++++++++++++++++++--- mypy/test/teststubtest.py | 53 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 4 deletions(-) diff --git a/mypy/stubtest.py b/mypy/stubtest.py index e67992ace5f8..1ed9f422a749 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -676,6 +676,18 @@ def _verify_signature( yield 'runtime does not have **kwargs argument "{}"'.format(stub.varkw.variable.name) +def _verify_coroutine( + stub: nodes.FuncItem, runtime: Any, *, runtime_is_coroutine: bool +) -> Optional[str]: + if stub.is_coroutine: + if not runtime_is_coroutine: + return "is a coroutine function in the stub, but not at runtime" + else: + if runtime_is_coroutine: + return "is a coroutine function at runtime, but not in the stub" + return None + + @verify.register(nodes.FuncItem) def verify_funcitem( stub: nodes.FuncItem, runtime: MaybeMissing[Any], object_path: List[str] @@ -693,19 +705,40 @@ def verify_funcitem( yield Error(object_path, "is inconsistent, " + message, stub, runtime) signature = safe_inspect_signature(runtime) + runtime_is_coroutine = inspect.iscoroutinefunction(runtime) + + if signature: + stub_sig = Signature.from_funcitem(stub) + runtime_sig = Signature.from_inspect_signature(signature) + runtime_sig_desc = f'{"async " if runtime_is_coroutine else ""}def {signature}' + else: + runtime_sig_desc = None + + coroutine_mismatch_error = _verify_coroutine( + stub, + runtime, + runtime_is_coroutine=runtime_is_coroutine + ) + + if coroutine_mismatch_error is not None: + yield Error( + object_path, + coroutine_mismatch_error, + stub, + runtime, + runtime_desc=runtime_sig_desc + ) + if not signature: return - stub_sig = Signature.from_funcitem(stub) - runtime_sig = Signature.from_inspect_signature(signature) - for message in _verify_signature(stub_sig, runtime_sig, function_name=stub.name): yield Error( object_path, "is inconsistent, " + message, stub, runtime, - runtime_desc="def " + str(signature), + runtime_desc=runtime_sig_desc, ) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 2852299548ed..78ae82b058cd 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -28,6 +28,34 @@ def use_tmp_dir() -> Iterator[None]: TEST_MODULE_NAME = "test_module" + +stubtest_typing_stub = """ +Any = object() + +class _SpecialForm: + def __getitem__(self, typeargs: Any) -> object: ... + +Callable: _SpecialForm = ... +Generic: _SpecialForm = ... + +class TypeVar: + def __init__(self, name, covariant: bool = ..., contravariant: bool = ...) -> None: ... + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_K = TypeVar("_K") +_V = TypeVar("_V") +_S = TypeVar("_S", contravariant=True) +_R = TypeVar("_R", covariant=True) + +class Coroutine(Generic[_T_co, _S, _R]): ... +class Iterable(Generic[_T_co]): ... +class Mapping(Generic[_K, _V]): ... +class Sequence(Iterable[_T_co]): ... +class Tuple(Sequence[_T_co]): ... +def overload(func: _T) -> _T: ... +""" + stubtest_builtins_stub = """ from typing import Generic, Mapping, Sequence, TypeVar, overload @@ -66,6 +94,8 @@ def run_stubtest( with use_tmp_dir(): with open("builtins.pyi", "w") as f: f.write(stubtest_builtins_stub) + with open("typing.pyi", "w") as f: + f.write(stubtest_typing_stub) with open("{}.pyi".format(TEST_MODULE_NAME), "w") as f: f.write(stub) with open("{}.py".format(TEST_MODULE_NAME), "w") as f: @@ -172,6 +202,29 @@ class X: error="X.mistyped_var", ) + @collect_cases + def test_coroutines(self) -> Iterator[Case]: + yield Case( + stub="async def foo() -> int: ...", + runtime="def foo(): return 5", + error="foo", + ) + yield Case( + stub="def bar() -> int: ...", + runtime="async def bar(): return 5", + error="bar", + ) + yield Case( + stub="def baz() -> int: ...", + runtime="def baz(): return 5", + error=None, + ) + yield Case( + stub="async def bingo() -> int: ...", + runtime="async def bingo(): return 5", + error=None, + ) + @collect_cases def test_arg_name(self) -> Iterator[Case]: yield Case( From cbdc9ba3b318d2c9f26a15deda9c7e373e0c545b Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sat, 19 Feb 2022 00:51:21 +0000 Subject: [PATCH 2/2] Improve error messages --- mypy/stubtest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 1ed9f422a749..546ea96dd9a0 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -681,10 +681,10 @@ def _verify_coroutine( ) -> Optional[str]: if stub.is_coroutine: if not runtime_is_coroutine: - return "is a coroutine function in the stub, but not at runtime" + return 'is an "async def" function in the stub, but not at runtime' else: if runtime_is_coroutine: - return "is a coroutine function at runtime, but not in the stub" + return 'is an "async def" function at runtime, but not in the stub' return None