diff --git a/.github/workflows/mypy_primer.yml b/.github/workflows/mypy_primer.yml index ee868484751e..532e77a0cacb 100644 --- a/.github/workflows/mypy_primer.yml +++ b/.github/workflows/mypy_primer.yml @@ -67,6 +67,7 @@ jobs: --debug \ --additional-flags="--debug-serialize" \ --output concise \ + --show-speed-regression \ | tee diff_${{ matrix.shard-index }}.txt ) || [ $? -eq 1 ] - if: ${{ matrix.shard-index == 0 }} diff --git a/mypy/checker.py b/mypy/checker.py index 7d0b41c516e1..19bf24327dc8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -455,7 +455,7 @@ def check_first_pass(self) -> None: Deferred functions will be processed by check_second_pass(). """ self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional): + with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self): self.errors.set_file( self.path, self.tree.fullname, scope=self.tscope, options=self.options ) @@ -496,7 +496,7 @@ def check_second_pass( This goes through deferred nodes, returning True if there were any. """ self.recurse_into_functions = True - with state.strict_optional_set(self.options.strict_optional): + with state.strict_optional_set(self.options.strict_optional), state.type_checker_set(self): if not todo and not self.deferred_nodes: return False self.errors.set_file( diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 1a76372d4731..b2e443c82e80 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -96,6 +96,7 @@ def __init__( is_self: bool = False, rvalue: Expression | None = None, suppress_errors: bool = False, + preserve_type_var_ids: bool = False, ) -> None: self.is_lvalue = is_lvalue self.is_super = is_super @@ -112,6 +113,10 @@ def __init__( assert is_lvalue self.rvalue = rvalue self.suppress_errors = suppress_errors + # This attribute is only used to preserve old protocol member access logic. + # It is needed to avoid infinite recursion in cases involving self-referential + # generic methods, see find_member() for details. Do not use for other purposes! + self.preserve_type_var_ids = preserve_type_var_ids def named_type(self, name: str) -> Instance: return self.chk.named_type(name) @@ -142,6 +147,7 @@ def copy_modified( no_deferral=self.no_deferral, rvalue=self.rvalue, suppress_errors=self.suppress_errors, + preserve_type_var_ids=self.preserve_type_var_ids, ) if self_type is not None: mx.self_type = self_type @@ -231,8 +237,6 @@ def analyze_member_access( def _analyze_member_access( name: str, typ: Type, mx: MemberContext, override_info: TypeInfo | None = None ) -> Type: - # TODO: This and following functions share some logic with subtypes.find_member; - # consider refactoring. typ = get_proper_type(typ) if isinstance(typ, Instance): return analyze_instance_member_access(name, typ, mx, override_info) @@ -355,7 +359,8 @@ def analyze_instance_member_access( return AnyType(TypeOfAny.special_form) assert isinstance(method.type, Overloaded) signature = method.type - signature = freshen_all_functions_type_vars(signature) + if not mx.preserve_type_var_ids: + signature = freshen_all_functions_type_vars(signature) if not method.is_static: signature = check_self_arg( signature, mx.self_type, method.is_class, mx.context, name, mx.msg @@ -928,7 +933,8 @@ def analyze_var( def expand_without_binding( typ: Type, var: Var, itype: Instance, original_itype: Instance, mx: MemberContext ) -> Type: - typ = freshen_all_functions_type_vars(typ) + if not mx.preserve_type_var_ids: + typ = freshen_all_functions_type_vars(typ) typ = expand_self_type_if_needed(typ, mx, var, original_itype) expanded = expand_type_by_instance(typ, itype) freeze_all_type_vars(expanded) @@ -938,7 +944,8 @@ def expand_without_binding( def expand_and_bind_callable( functype: FunctionLike, var: Var, itype: Instance, name: str, mx: MemberContext ) -> Type: - functype = freshen_all_functions_type_vars(functype) + if not mx.preserve_type_var_ids: + functype = freshen_all_functions_type_vars(functype) typ = get_proper_type(expand_self_type(var, functype, mx.original_type)) assert isinstance(typ, FunctionLike) typ = check_self_arg(typ, mx.self_type, var.is_classmethod, mx.context, name, mx.msg) @@ -1033,10 +1040,12 @@ def f(self: S) -> T: ... return functype else: selfarg = get_proper_type(item.arg_types[0]) - # This level of erasure matches the one in checker.check_func_def(), - # better keep these two checks consistent. - if subtypes.is_subtype( + # This matches similar special-casing in bind_self(), see more details there. + self_callable = name == "__call__" and isinstance(selfarg, CallableType) + if self_callable or subtypes.is_subtype( dispatched_arg_type, + # This level of erasure matches the one in checker.check_func_def(), + # better keep these two checks consistent. erase_typevars(erase_to_bound(selfarg)), # This is to work around the fact that erased ParamSpec and TypeVarTuple # callables are not always compatible with non-erased ones both ways. @@ -1197,15 +1206,10 @@ def analyze_class_attribute_access( is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or ( isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_class ) - is_staticmethod = (is_decorated and cast(Decorator, node.node).func.is_static) or ( - isinstance(node.node, SYMBOL_FUNCBASE_TYPES) and node.node.is_static - ) t = get_proper_type(t) if isinstance(t, FunctionLike) and is_classmethod: t = check_self_arg(t, mx.self_type, False, mx.context, name, mx.msg) - result = add_class_tvars( - t, isuper, is_classmethod, is_staticmethod, mx.self_type, original_vars=original_vars - ) + result = add_class_tvars(t, isuper, is_classmethod, mx, original_vars=original_vars) # __set__ is not called on class objects. if not mx.is_lvalue: result = analyze_descriptor_access(result, mx) @@ -1337,8 +1341,7 @@ def add_class_tvars( t: ProperType, isuper: Instance | None, is_classmethod: bool, - is_staticmethod: bool, - original_type: Type, + mx: MemberContext, original_vars: Sequence[TypeVarLikeType] | None = None, ) -> Type: """Instantiate type variables during analyze_class_attribute_access, @@ -1356,9 +1359,6 @@ class B(A[str]): pass isuper: Current instance mapped to the superclass where method was defined, this is usually done by map_instance_to_supertype() is_classmethod: True if this method is decorated with @classmethod - is_staticmethod: True if this method is decorated with @staticmethod - original_type: The value of the type B in the expression B.foo() or the corresponding - component in case of a union (this is used to bind the self-types) original_vars: Type variables of the class callable on which the method was accessed Returns: Expanded method type with added type variables (when needed). @@ -1379,11 +1379,11 @@ class B(A[str]): pass # (i.e. appear in the return type of the class object on which the method was accessed). if isinstance(t, CallableType): tvars = original_vars if original_vars is not None else [] - t = freshen_all_functions_type_vars(t) + if not mx.preserve_type_var_ids: + t = freshen_all_functions_type_vars(t) if is_classmethod: - t = bind_self(t, original_type, is_classmethod=True) - if is_classmethod or is_staticmethod: - assert isuper is not None + t = bind_self(t, mx.self_type, is_classmethod=True) + if isuper is not None: t = expand_type_by_instance(t, isuper) freeze_all_type_vars(t) return t.copy_modified(variables=list(tvars) + list(t.variables)) @@ -1392,14 +1392,7 @@ class B(A[str]): pass [ cast( CallableType, - add_class_tvars( - item, - isuper, - is_classmethod, - is_staticmethod, - original_type, - original_vars=original_vars, - ), + add_class_tvars(item, isuper, is_classmethod, mx, original_vars=original_vars), ) for item in t.items ] diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 031f86e7dfff..f17d3ecfcd83 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -4,7 +4,6 @@ from typing import Final, TypeVar, cast, overload from mypy.nodes import ARG_STAR, FakeInfo, Var -from mypy.state import state from mypy.types import ( ANY_STRATEGY, AnyType, @@ -544,6 +543,8 @@ def remove_trivial(types: Iterable[Type]) -> list[Type]: * Remove everything else if there is an `object` * Remove strict duplicate types """ + from mypy.state import state + removed_none = False new_types = [] all_types = set() diff --git a/mypy/messages.py b/mypy/messages.py index 2e07d7f63498..d18e9917a095 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -2220,8 +2220,13 @@ def report_protocol_problems( exp = get_proper_type(exp) got = get_proper_type(got) setter_suffix = " setter type" if is_lvalue else "" - if not isinstance(exp, (CallableType, Overloaded)) or not isinstance( - got, (CallableType, Overloaded) + if ( + not isinstance(exp, (CallableType, Overloaded)) + or not isinstance(got, (CallableType, Overloaded)) + # If expected type is a type object, it means it is a nested class. + # Showing constructor signature in errors would be confusing in this case, + # since we don't check the signature, only subclassing of type objects. + or exp.is_type_obj() ): self.note( "{}: expected{} {}, got {}".format( diff --git a/mypy/plugin.py b/mypy/plugin.py index 39841d5b907a..de075866d613 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -119,14 +119,13 @@ class C: pass from __future__ import annotations from abc import abstractmethod -from typing import Any, Callable, NamedTuple, TypeVar +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar from mypy_extensions import mypyc_attr, trait from mypy.errorcodes import ErrorCode from mypy.lookup import lookup_fully_qualified from mypy.message_registry import ErrorMessage -from mypy.messages import MessageBuilder from mypy.nodes import ( ArgKind, CallExpr, @@ -138,7 +137,6 @@ class C: pass TypeInfo, ) from mypy.options import Options -from mypy.tvar_scope import TypeVarLikeScope from mypy.types import ( CallableType, FunctionLike, @@ -149,6 +147,10 @@ class C: pass UnboundType, ) +if TYPE_CHECKING: + from mypy.messages import MessageBuilder + from mypy.tvar_scope import TypeVarLikeScope + @trait class TypeAnalyzerPluginInterface: diff --git a/mypy/state.py b/mypy/state.py index a3055bf6b208..41b8b75be127 100644 --- a/mypy/state.py +++ b/mypy/state.py @@ -4,16 +4,19 @@ from contextlib import contextmanager from typing import Final +from mypy.checker_shared import TypeCheckerSharedApi + # These are global mutable state. Don't add anything here unless there's a very # good reason. -class StrictOptionalState: +class SubtypeState: # Wrap this in a class since it's faster that using a module-level attribute. - def __init__(self, strict_optional: bool) -> None: - # Value varies by file being processed + def __init__(self, strict_optional: bool, type_checker: TypeCheckerSharedApi | None) -> None: + # Values vary by file being processed self.strict_optional = strict_optional + self.type_checker = type_checker @contextmanager def strict_optional_set(self, value: bool) -> Iterator[None]: @@ -24,6 +27,15 @@ def strict_optional_set(self, value: bool) -> Iterator[None]: finally: self.strict_optional = saved + @contextmanager + def type_checker_set(self, value: TypeCheckerSharedApi) -> Iterator[None]: + saved = self.type_checker + self.type_checker = value + try: + yield + finally: + self.type_checker = saved + -state: Final = StrictOptionalState(strict_optional=True) +state: Final = SubtypeState(strict_optional=True, type_checker=None) find_occurrences: tuple[str, str] | None = None diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 71b8b0ba59f5..226c39bb2933 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -26,6 +26,7 @@ COVARIANT, INVARIANT, VARIANCE_NOT_READY, + Context, Decorator, FuncBase, OverloadedFuncDef, @@ -717,8 +718,7 @@ def visit_callable_type(self, left: CallableType) -> bool: elif isinstance(right, Instance): if right.type.is_protocol and "__call__" in right.type.protocol_members: # OK, a callable can implement a protocol with a `__call__` member. - # TODO: we should probably explicitly exclude self-types in this case. - call = find_member("__call__", right, left, is_operator=True) + call = find_member("__call__", right, right, is_operator=True) assert call is not None if self._is_subtype(left, call): if len(right.type.protocol_members) == 1: @@ -954,7 +954,7 @@ def visit_overloaded(self, left: Overloaded) -> bool: if isinstance(right, Instance): if right.type.is_protocol and "__call__" in right.type.protocol_members: # same as for CallableType - call = find_member("__call__", right, left, is_operator=True) + call = find_member("__call__", right, right, is_operator=True) assert call is not None if self._is_subtype(left, call): if len(right.type.protocol_members) == 1: @@ -1261,14 +1261,87 @@ def find_member( is_operator: bool = False, class_obj: bool = False, is_lvalue: bool = False, +) -> Type | None: + type_checker = state.type_checker + if type_checker is None: + # Unfortunately, there are many scenarios where someone calls is_subtype() before + # type checking phase. In this case we fallback to old (incomplete) logic. + # TODO: reduce number of such cases (e.g. semanal_typeargs, post-semanal plugins). + return find_member_simple( + name, itype, subtype, is_operator=is_operator, class_obj=class_obj, is_lvalue=is_lvalue + ) + + # We don't use ATTR_DEFINED error code below (since missing attributes can cause various + # other error codes), instead we perform quick node lookup with all the fallbacks. + info = itype.type + sym = info.get(name) + node = sym.node if sym else None + if not node: + name_not_found = True + if ( + name not in ["__getattr__", "__setattr__", "__getattribute__"] + and not is_operator + and not class_obj + and itype.extra_attrs is None # skip ModuleType.__getattr__ + ): + for method_name in ("__getattribute__", "__getattr__"): + method = info.get_method(method_name) + if method and method.info.fullname != "builtins.object": + name_not_found = False + break + if name_not_found: + if info.fallback_to_any or class_obj and info.meta_fallback_to_any: + return AnyType(TypeOfAny.special_form) + if itype.extra_attrs and name in itype.extra_attrs.attrs: + return itype.extra_attrs.attrs[name] + return None + + from mypy.checkmember import ( + MemberContext, + analyze_class_attribute_access, + analyze_instance_member_access, + ) + + mx = MemberContext( + is_lvalue=is_lvalue, + is_super=False, + is_operator=is_operator, + original_type=itype, + self_type=subtype, + context=Context(), # all errors are filtered, but this is a required argument + chk=type_checker, + suppress_errors=True, + # This is needed to avoid infinite recursion in situations involving protocols like + # class P(Protocol[T]): + # def combine(self, other: P[S]) -> P[Tuple[T, S]]: ... + # Normally we call freshen_all_functions_type_vars() during attribute access, + # to avoid type variable id collisions, but for protocols this means we can't + # use the assumption stack, that will grow indefinitely. + # TODO: find a cleaner solution that doesn't involve massive perf impact. + preserve_type_var_ids=True, + ) + with type_checker.msg.filter_errors(filter_deprecated=True): + if class_obj: + fallback = itype.type.metaclass_type or mx.named_type("builtins.type") + return analyze_class_attribute_access(itype, name, mx, mcs_fallback=fallback) + else: + return analyze_instance_member_access(name, itype, mx, info) + + +def find_member_simple( + name: str, + itype: Instance, + subtype: Type, + *, + is_operator: bool = False, + class_obj: bool = False, + is_lvalue: bool = False, ) -> Type | None: """Find the type of member by 'name' in 'itype's TypeInfo. Find the member type after applying type arguments from 'itype', and binding 'self' to 'subtype'. Return None if member was not found. """ - # TODO: this code shares some logic with checkmember.analyze_member_access, - # consider refactoring. info = itype.type method = info.get_method(name) if method: diff --git a/mypy/types.py b/mypy/types.py index 41a958ae93cc..a922f64a47a8 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -32,7 +32,6 @@ SymbolNode, ) from mypy.options import Options -from mypy.state import state from mypy.util import IdMapper T = TypeVar("T") @@ -2979,6 +2978,8 @@ def accept(self, visitor: TypeVisitor[T]) -> T: def relevant_items(self) -> list[Type]: """Removes NoneTypes from Unions when strict Optional checking is off.""" + from mypy.state import state + if state.strict_optional: return self.items else: diff --git a/test-data/unit/check-python312.test b/test-data/unit/check-python312.test index 2f3d5e08dab3..54864b24ea40 100644 --- a/test-data/unit/check-python312.test +++ b/test-data/unit/check-python312.test @@ -246,6 +246,7 @@ class Invariant[T]: inv1: Invariant[float] = Invariant[int]([1]) # E: Incompatible types in assignment (expression has type "Invariant[int]", variable has type "Invariant[float]") inv2: Invariant[int] = Invariant[float]([1]) # E: Incompatible types in assignment (expression has type "Invariant[float]", variable has type "Invariant[int]") [builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi] [case testPEP695InferVarianceCalculateOnDemand] class Covariant[T]: @@ -1635,8 +1636,8 @@ class M[T: (int, str)](NamedTuple): c: M[int] d: M[str] e: M[bool] # E: Value of type variable "T" of "M" cannot be "bool" - [builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi] [case testPEP695GenericTypedDict] from typing import TypedDict diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 47c8a71ba0e3..5d6706c35308 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2780,7 +2780,7 @@ class TD(TypedDict): reveal_type(TD.__iter__) # N: Revealed type is "def (typing._TypedDict) -> typing.Iterator[builtins.str]" reveal_type(TD.__annotations__) # N: Revealed type is "typing.Mapping[builtins.str, builtins.object]" -reveal_type(TD.values) # N: Revealed type is "def (self: typing.Mapping[T`1, T_co`2]) -> typing.Iterable[T_co`2]" +reveal_type(TD.values) # N: Revealed type is "def (self: typing.Mapping[builtins.str, builtins.object]) -> typing.Iterable[builtins.object]" [builtins fixtures/dict-full.pyi] [typing fixtures/typing-typeddict.pyi]