From fae9016daf0165f1a119076c0891d03b1225bd49 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 4 Jun 2018 09:56:53 -0700 Subject: [PATCH 01/21] Add more robust support for detecting partially overlapping types This pull request adds more robust support for detecting partially overlapping types. Specifically, it detects overlaps with... 1. TypedDicts 2. Tuples 3. Unions 4. Typevars 5. Generic types containing variations of the above --- mypy/checker.py | 129 ++++++-- mypy/meet.py | 211 ++++++++++--- mypy/messages.py | 6 + mypy/subtypes.py | 1 + mypy/types.py | 3 + test-data/unit/check-overloading.test | 414 +++++++++++++++++++++++++- 6 files changed, 693 insertions(+), 71 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index c95742285338..01b0459e7b1c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from typing import ( - Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator + Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable ) from mypy.errors import Errors, report_internal_error @@ -30,7 +30,7 @@ Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType, Instance, NoneTyp, strip_type, TypeType, TypeOfAny, UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, - true_only, false_only, function_type, is_named_instance, union_items, + true_only, false_only, function_type, is_named_instance, union_items, TypeQuery ) from mypy.sametypes import is_same_type, is_same_types from mypy.messages import MessageBuilder, make_inferred_type_note @@ -52,7 +52,7 @@ from mypy.join import join_types from mypy.treetransform import TransformVisitor from mypy.binder import ConditionalTypeBinder, get_declaration -from mypy.meet import is_overlapping_types, is_partially_overlapping_types +from mypy.meet import is_overlapping_erased_types, is_partially_overlapping from mypy.options import Options from mypy.plugin import Plugin, CheckerPluginInterface from mypy.sharedparse import BINARY_MAGIC_METHODS @@ -468,6 +468,9 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: if is_unsafe_overlapping_overload_signatures(sig1, sig2): self.msg.overloaded_signatures_overlap( i + 1, i + j + 2, item.func) + elif is_unsafe_partially_overlapping_overload_signatures(sig1, sig2): + self.msg.overloaded_signatures_partial_overlap( + i + 1, i + j + 2, item.func) if impl_type is not None: assert defn.impl is not None @@ -3063,7 +3066,7 @@ def find_isinstance_check(self, node: Expression else: optional_type, comp_type = second_type, first_type optional_expr = node.operands[1] - if is_overlapping_types(optional_type, comp_type): + if is_overlapping_erased_types(optional_type, comp_type): return {optional_expr: remove_optional(optional_type)}, {} elif node.operators in [['in'], ['not in']]: expr = node.operands[0] @@ -3074,7 +3077,7 @@ def find_isinstance_check(self, node: Expression right_type.type.fullname() != 'builtins.object')) if (right_type and right_ok and is_optional(left_type) and literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and - is_overlapping_types(left_type, right_type)): + is_overlapping_erased_types(left_type, right_type)): if node.operators == ['in']: return {expr: remove_optional(left_type)}, {} if node.operators == ['not in']: @@ -3414,7 +3417,7 @@ def conditional_type_map(expr: Expression, and is_proper_subtype(current_type, proposed_type)): # Expression is always of one of the types in proposed_type_ranges return {}, None - elif not is_overlapping_types(current_type, proposed_type): + elif not is_overlapping_erased_types(current_type, proposed_type): # Expression is never of any type in proposed_type_ranges return None, {} else: @@ -3630,7 +3633,7 @@ def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool: def is_unsafe_overlapping_overload_signatures(signature: CallableType, other: CallableType) -> bool: - """Check if two overloaded function signatures may be unsafely overlapping. + """Check if two overloaded signatures are unsafely overlapping, ignoring partial overlaps. We consider two functions 's' and 't' to be unsafely overlapping both if of the following are true: @@ -3638,29 +3641,111 @@ def is_unsafe_overlapping_overload_signatures(signature: CallableType, 1. s's parameters are all more precise or partially overlapping with t's 2. s's return type is NOT a subtype of t's. + This function will perform a modified version of the above two checks: + we do not check for partial overlaps. This lets us vary our error messages + depending on the severity of the overlap. + + See 'is_unsafe_partially_overlapping_overload_signatures' for the full checks. + Assumes that 'signature' appears earlier in the list of overload alternatives then 'other' and that their argument counts are overlapping. """ - # TODO: Handle partially overlapping parameter types - # - # For example, the signatures "f(x: Union[A, B]) -> int" and "f(x: Union[B, C]) -> str" - # is unsafe: the parameter types are partially overlapping. - # - # To fix this, we need to either modify meet.is_overlapping_types or add a new - # function and use "is_more_precise(...) or is_partially_overlapping(...)" for the is_compat - # checks. - # - # (We already have a rudimentary implementation of 'is_partially_overlapping', but it only - # attempts to handle the obvious cases -- see its docstring for more info.) + return is_callable_compatible(signature, other, + is_compat=is_more_precise, + is_compat_return=lambda l, r: not is_subtype(l, r), + ignore_return=False, + check_args_covariantly=True, + allow_partial_overlap=True) + + +def is_unsafe_partially_overlapping_overload_signatures(signature: CallableType, + other: CallableType) -> bool: + """Check if two overloaded signatures are unsafely overlapping, ignoring partial overlaps. + + We consider two functions 's' and 't' to be unsafely overlapping both if + of the following are true: + + 1. s's parameters are all more precise or partially overlapping with t's + 2. s's return type is NOT a subtype of t's. + + Assumes that 'signature' appears earlier in the list of overload + alternatives then 'other' and that their argument counts are overlapping. + """ def is_more_precise_or_partially_overlapping(t: Type, s: Type) -> bool: - return is_more_precise(t, s) or is_partially_overlapping_types(t, s) + return is_more_precise(t, s) or is_partially_overlapping(t, s) - return is_callable_compatible(signature, other, + # Try detaching callables from the containing class so we can try unifying + # free type variables against each other. + # + # This lets us identify cases where the two signatures use completely + # incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars + # test case. + signature = detach_callable(signature) + other = detach_callable(other) + + # Note: We repeat this check twice in both directions due to a slight + # asymmetry in 'is_callable_compatible'. When checking for partial overlaps, + # we attempt to unify 'signature' and 'other' both against each other. + # + # If 'signature' cannot be unified with 'other', we end early. However, + # if 'other' cannot be modified with 'signature', the function continues + # using the older version of 'other'. + # + # This discrepancy is unfortunately difficult to get rid of, so we repeat the + # checks twice in both directions for now. + return (is_callable_compatible(signature, other, is_compat=is_more_precise_or_partially_overlapping, is_compat_return=lambda l, r: not is_subtype(l, r), + ignore_return=False, check_args_covariantly=True, - allow_partial_overlap=True) + allow_partial_overlap=True) or + is_callable_compatible(other, signature, + is_compat=is_more_precise_or_partially_overlapping, + is_compat_return=lambda l, r: not is_subtype(r, l), + ignore_return=False, + check_args_covariantly=False, + allow_partial_overlap=True)) + + +def detach_callable(typ: CallableType) -> CallableType: + """Ensures that the callable's type variables are 'detached' and independent of the context + + A callable normally keeps track of the type variables it uses within its 'variables' field. + However, if the callable is from a method and that method is using a class type variable, + the callable will not keep track of that type variable since it belongs to the class. + + This function will traverse the callable and find all used type vars and add them to the + variables field if it isn't already present. + + The caller can then unify on all type variables whether or not the callable is originally + from a class or not.""" + type_vars = typ.accept(TypeVarExtractor()) + new_variables = [] + for var in type_vars: + new_variables.append(TypeVarDef( + name=var.name, + fullname=var.fullname, + id=var.id, + values=var.values, + upper_bound=var.upper_bound, + variance=var.variance, + )) + return typ.copy_modified(variables=new_variables) + + +class TypeVarExtractor(TypeQuery[Set[TypeVarType]]): + def __init__(self) -> None: + super().__init__(self._merge) + + def _merge(self, iter: Iterable[Set[TypeVarType]]) -> Set[TypeVarType]: + out = set() + for item in iter: + out.update(item) + return out + + def visit_type_var(self, t: TypeVarType) -> Set[TypeVarType]: + return {t} def overload_can_never_match(signature: CallableType, other: CallableType) -> bool: @@ -3719,7 +3804,7 @@ def is_unsafe_overlapping_operator_signatures(signature: Type, other: Type) -> b for i in range(min_args): t1 = signature.arg_types[i] t2 = other.arg_types[i] - if not is_overlapping_types(t1, t2): + if not is_overlapping_erased_types(t1, t2): return False # All arguments types for the smallest common argument count are # overlapping => the signature is overlapping. The overlapping is diff --git a/mypy/meet.py b/mypy/meet.py index 32772f4801c1..2abdc26819fb 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -5,9 +5,13 @@ from mypy.types import ( Type, AnyType, TypeVisitor, UnboundType, NoneTyp, TypeVarType, Instance, CallableType, TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, - UninhabitedType, TypeType, TypeOfAny + UninhabitedType, TypeType, TypeOfAny, Overloaded, is_named_instance, ) -from mypy.subtypes import is_equivalent, is_subtype, is_protocol_implementation +from mypy.subtypes import ( + is_equivalent, is_subtype, is_protocol_implementation, is_callable_compatible, +) +from mypy.sametypes import is_same_type +from mypy.maptype import map_instance_to_supertype from mypy import experiments @@ -32,7 +36,7 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: if isinstance(declared, UnionType): return UnionType.make_simplified_union([narrow_declared_type(x, narrowed) for x in declared.relevant_items()]) - elif not is_overlapping_types(declared, narrowed, use_promotions=True): + elif not is_overlapping_erased_types(declared, narrowed, use_promotions=True): if experiments.STRICT_OPTIONAL: return UninhabitedType() else: @@ -49,41 +53,175 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: return narrowed -def is_partially_overlapping_types(t: Type, s: Type) -> bool: - """Returns 'true' if the two types are partially, but not completely, overlapping. +def get_possible_variants(typ: Type) -> List[Type]: + if isinstance(typ, TypeVarType): + if len(typ.values) > 0: + return typ.values + else: + return [typ.upper_bound] + elif isinstance(typ, UnionType): + return typ.items + [typ] + elif isinstance(typ, Overloaded): + # Note: doing 'return typ.items() + [typ]' makes mypy + # infer a too-specific return type of List[CallableType] + out = [] # type: List[Type] + out.extend(typ.items()) + out.append(typ) + return out + else: + return [typ] - NOTE: This function is only a partial implementation. - It exists mostly so that overloads correctly handle partial - overlaps for the more obvious cases. - """ - # Are unions partially overlapping? - if isinstance(t, UnionType) and isinstance(s, UnionType): - t_set = set(t.items) - s_set = set(s.items) - num_same = len(t_set.intersection(s_set)) - num_diff = len(t_set.symmetric_difference(s_set)) - return num_same > 0 and num_diff > 0 - - # Are tuples partially overlapping? - tup_overlap = is_overlapping_tuples(t, s, use_promotions=True) - if tup_overlap is not None and tup_overlap: - return tup_overlap - - def is_object(t: Type) -> bool: - return isinstance(t, Instance) and t.type.fullname() == 'builtins.object' +def is_partially_overlapping(left: Type, right: Type) -> bool: + # We should never encounter these types + illegal_types = (UnboundType, PartialType, ErasedType, DeletedType) + if isinstance(left, illegal_types) or isinstance(right, illegal_types): + raise AssertionError( + "Encountered unexpected types: left={} right={}".format(type(left), type(right))) - # Is either 't' or 's' an unrestricted TypeVar? - if isinstance(t, TypeVarType) and is_object(t.upper_bound) and len(t.values) == 0: + # 'Any' may or may not be partially overlapping with the other type + if isinstance(left, AnyType) or isinstance(right, AnyType): return True - if isinstance(s, TypeVarType) and is_object(s.upper_bound) and len(s.values) == 0: + # There are a number of different types that can have "variants" or are "union-like". + # For example: + # + # - Unions + # - TypeVars with value restrictions + # - Overloads + # + # We extract the component variants into a list. Types with a single variant are + # stored in a singleton list. + # + # The logic to check whether any of these types are overlapping are essentially the + # same: we obtain the list of possible variants and make sure there exists + # items in both the intersection and the difference. + + left_possible = get_possible_variants(left) + right_possible = get_possible_variants(right) + + # However, TypeVars get special treatment. It's sufficient to check to see + # if at least one overlap exists. + # + # The TypeVar checks must come before any single-variant types. + if isinstance(left, TypeVarType) or isinstance(right, TypeVarType): + for l in left_possible: + for r in right_possible: + if is_overlapping_types(l, r): + return True + return False + + # Next, we handle single-variant types that may be inherently partially overlapping: + # + # - TypedDicts + # - Tuples + # + # If we cannot identify a partial overlap and end early, we degrade these two types + # into their (Instance) fallbacks. + + if isinstance(left, TypedDictType) and isinstance(right, TypedDictType): + # If one is a subtype of the other, no partial overlap + if is_subtype(left, right) or is_subtype(right, left): + return False + + # All required keys in left are present and overlapping with something in right + for key in left.required_keys: + if key not in right.items: + return False + if not is_overlapping_types(left.items[key], right.items[key]): + return False + + # Repeat check in the other direction + for key in right.required_keys: + if key not in left.items: + return False + if not is_overlapping_types(left.items[key], right.items[key]): + return False + + # The presence of any additional optional keys does not affect whether the two + # TypedDicts are partially overlapping: the dicts would be overlapping if the + # keys happened to be missing. return True + elif isinstance(left, TypedDictType): + left = left.fallback + elif isinstance(right, TypedDictType): + right = right.fallback + + if is_tuple(left) and is_tuple(right): + left = adjust_tuple(left, right) or left + right = adjust_tuple(right, left) or right + assert isinstance(left, TupleType) + assert isinstance(right, TupleType) + if len(left.items) != len(right.items): + return False + return all(is_overlapping_types(l, r) for l, r in zip(left.items, right.items)) + elif isinstance(left, TupleType): + left = left.fallback + elif isinstance(right, TupleType): + right = right.fallback + + # Next, we handle single-variant types that are not inherently partially overlapping, + # but do require custom logic to inspect. + + if isinstance(left, TypeType) and isinstance(right, TypeType): + return is_partially_overlapping(left.item, right.item) + elif isinstance(left, TypeType) or isinstance(right, TypeType): + # TODO: Can Callable[[...], T] and Type[T] be partially overlapping? + return False + if isinstance(left, CallableType) and isinstance(right, CallableType): + return is_callable_compatible(left, right, + is_compat=is_overlapping_types, + ignore_pos_arg_names=True, + allow_partial_overlap=True) + elif isinstance(left, CallableType): + left = left.fallback + elif isinstance(right, CallableType): + right = right.fallback + + # Next, we check if left and right are instances + + if isinstance(left, Instance) and isinstance(right, Instance): + # Two unrelated types cannot be partially overlapping: they're disjoint. + # We don't need to handle promotions because promotable types either + # are overlapping or are not -- they can't be partially overlapping. + if left.type.has_base(right.type.fullname()): + left = map_instance_to_supertype(left, right.type) + elif right.type.has_base(left.type.fullname()): + right = map_instance_to_supertype(right, left.type) + else: + return False + + if len(left.args) == len(right.args): + for left_arg, right_arg in zip(left.args, right.args): + if is_partially_overlapping(left_arg, right_arg): + return True + return False + + # We handle all remaining types here: in particular, types like + # UnionType, Overloaded, NoneTyp, and UninhabitedType. + + found_same = False + found_diff = False + for a in left_possible: + for b in right_possible: + if is_same_type(a, b): + found_same = True + break + else: + found_diff = True + + # Return early if possible + if found_same and found_diff: + return True return False -def is_overlapping_types(t: Type, s: Type, use_promotions: bool = False) -> bool: +def is_overlapping_types(t: Type, s: Type) -> bool: + return is_subtype(t, s) or is_subtype(s, t) or is_partially_overlapping(t, s) + + +def is_overlapping_erased_types(t: Type, s: Type, use_promotions: bool = False) -> bool: """Can a value of type t be a value of type s, or vice versa? Note that this effectively checks against erased types, since type @@ -130,10 +268,10 @@ class C(A, B): ... s = s.as_anonymous().fallback if isinstance(t, UnionType): - return any(is_overlapping_types(item, s) + return any(is_overlapping_erased_types(item, s) for item in t.relevant_items()) if isinstance(s, UnionType): - return any(is_overlapping_types(t, item) + return any(is_overlapping_erased_types(t, item) for item in s.relevant_items()) # We must check for TupleTypes before Instances, since Tuple[A, ...] @@ -150,9 +288,9 @@ class C(A, B): ... # Consider cases like int vs float to be overlapping where # there is only a type promotion relationship but not proper # subclassing. - if t.type._promote and is_overlapping_types(t.type._promote, s): + if t.type._promote and is_overlapping_erased_types(t.type._promote, s): return True - if s.type._promote and is_overlapping_types(s.type._promote, t): + if s.type._promote and is_overlapping_erased_types(s.type._promote, t): return True if t.type in s.type.mro or s.type in t.type.mro: return True @@ -163,7 +301,7 @@ class C(A, B): ... return False if isinstance(t, TypeType) and isinstance(s, TypeType): # If both types are TypeType, compare their inner types. - return is_overlapping_types(t.item, s.item, use_promotions) + return is_overlapping_erased_types(t.item, s.item, use_promotions) elif isinstance(t, TypeType) or isinstance(s, TypeType): # If exactly only one of t or s is a TypeType, check if one of them # is an `object` or a `type` and otherwise assume no overlap. @@ -189,7 +327,7 @@ def is_overlapping_tuples(t: Type, s: Type, use_promotions: bool) -> Optional[bo if isinstance(t, TupleType) or isinstance(s, TupleType): if isinstance(t, TupleType) and isinstance(s, TupleType): if t.length() == s.length(): - if all(is_overlapping_types(ti, si, use_promotions) + if all(is_overlapping_erased_types(ti, si, use_promotions) for ti, si in zip(t.items, s.items)): return True # TupleType and non-tuples do not overlap @@ -206,6 +344,11 @@ def adjust_tuple(left: Type, r: Type) -> Optional[TupleType]: return None +def is_tuple(typ: Type) -> bool: + return (isinstance(typ, TupleType) + or (isinstance(typ, Instance) and typ.type.fullname() == 'builtins.tuple')) + + class TypeMeetVisitor(TypeVisitor[Type]): def __init__(self, s: Type) -> None: self.s = s diff --git a/mypy/messages.py b/mypy/messages.py index a9276bde6ec4..31c85dcad967 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -958,6 +958,12 @@ def overloaded_signatures_overlap(self, index1: int, index2: int, context: Conte self.fail('Overloaded function signatures {} and {} overlap with ' 'incompatible return types'.format(index1, index2), context) + def overloaded_signatures_partial_overlap(self, index1: int, index2: int, + context: Context) -> None: + self.fail('Overloaded function signatures {} and {} '.format(index1, index2) + + 'are partially overlapping: the two signatures may return ' + + 'incompatible types given certain calls', context) + def overloaded_signature_will_never_match(self, index1: int, index2: int, context: Context) -> None: self.fail( diff --git a/mypy/subtypes.py b/mypy/subtypes.py index d38cdf67b810..ec8c6cbbce75 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -920,6 +920,7 @@ def unify_generic_callable(type: CallableType, target: CallableType, c = mypy.constraints.infer_constraints( type.ret_type, target.ret_type, return_constraint_direction) constraints.extend(c) + type_var_ids = [tvar.id for tvar in type.variables] inferred_vars = mypy.solve.solve_constraints(type_var_ids, constraints) if None in inferred_vars: diff --git a/mypy/types.py b/mypy/types.py index f9d5d3c23a34..632ba8862ea5 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1187,6 +1187,9 @@ def deserialize(cls, data: JsonDict) -> 'TypedDictType': set(data['required_keys']), Instance.deserialize(data['fallback'])) + def has_optional_keys(self) -> bool: + return any(key not in self.required_keys for key in self.items) + def is_anonymous(self) -> bool: return self.fallback.type.fullname() == 'typing.Mapping' diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 0d3a6e444300..1d049ec0fe67 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -334,8 +334,8 @@ def bar(x: Union[T, C]) -> Union[T, int]: [builtins fixtures/isinstancelist.pyi] -[case testTypeCheckOverloadImplementationTypeVarDifferingUsage] -from typing import overload, Union, List, TypeVar +[case testTypeCheckOverloadImplementationTypeVarDifferingUsage1] +from typing import overload, Union, List, TypeVar, Generic T = TypeVar('T') @@ -348,6 +348,50 @@ def foo(t: Union[List[T], T]) -> T: return t[0] else: return t + +class Wrapper(Generic[T]): + @overload + def foo(self, t: List[T]) -> T: ... + @overload + def foo(self, t: T) -> T: ... + def foo(self, t: Union[List[T], T]) -> T: + if isinstance(t, list): + return t[0] + else: + return t +[builtins fixtures/isinstancelist.pyi] + +[case testTypeCheckOverloadImplementationTypeVarDifferingUsage2] +from typing import overload, Union, List, TypeVar, Generic + +T = TypeVar('T') + +# Note: this is unsafe when T = object +@overload +def foo(t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def foo(t: T, s: T) -> str: ... +def foo(t, s): pass + +# TODO: Why are we getting a different error message here? +# Shouldn't we be getting the same error message? +class Wrapper(Generic[T]): + @overload + def foo(self, t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def foo(self, t: T, s: T) -> str: ... + def foo(self, t, s): pass + +class Dummy(Generic[T]): pass + +# Same root issue: why does the additional constraint bound T <: T +# cause the constraint solver to not infer T = object like it did in the +# first example? +@overload +def bar(d: Dummy[T], t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def bar(d: Dummy[T], t: T, s: T) -> str: ... +def bar(d: Dummy[T], t, s): pass [builtins fixtures/isinstancelist.pyi] [case testTypeCheckOverloadedFunctionBody] @@ -1677,7 +1721,7 @@ def r(x: Any) -> Any:... @overload def g(x: A) -> A: ... @overload -def g(x: Tuple[A1, ...]) -> B: ... # E: Overloaded function signatures 2 and 3 overlap with incompatible return types +def g(x: Tuple[A1, ...]) -> B: ... # E: Overloaded function signatures 2 and 3 are partially overlapping: the two signatures may return incompatible types given certain calls @overload def g(x: Tuple[A, A]) -> C: ... @overload @@ -1864,7 +1908,7 @@ def foo(x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and def foo(x: T, y: T) -> int: ... def foo(x): ... -# TODO: We should allow this; T can't be bound to two distinct types +# What if 'T' is 'object'? @overload def bar(x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types @overload @@ -1872,17 +1916,17 @@ def bar(x: T, y: T) -> int: ... def bar(x, y): ... class Wrapper(Generic[T]): - # TODO: This should be an error + # TODO: Why do these have different error messages? @overload - def foo(self, x: None, y: None) -> str: ... + def foo(self, x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload - def foo(self, x: T, y: None) -> str: ... + def foo(self, x: T, y: None) -> int: ... def foo(self, x): ... @overload - def bar(self, x: None, y: int) -> str: ... + def bar(self, x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload - def bar(self, x: T, y: T) -> str: ... + def bar(self, x: T, y: T) -> int: ... def bar(self, x, y): ... [case testOverloadFlagsPossibleMatches] @@ -2382,7 +2426,7 @@ class C: ... class D: ... @overload -def f(x: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def f(x: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload def f(x: Union[B, C]) -> str: ... def f(x): ... @@ -2390,26 +2434,366 @@ def f(x): ... @overload def g(x: Union[A, B]) -> int: ... @overload -def g(x: Union[C, D]) -> str: ... +def g(x: Union[B, C]) -> int: ... +def g(x): ... + +@overload +def h(x: Union[A, B]) -> int: ... +@overload +def h(x: Union[C, D]) -> str: ... +def h(x): ... + +@overload +def i(x: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def i(x: Union[A, B, C]) -> str: ... +def i(x): ... + +[case testOverloadWithPartiallyOverlappingUnionsNested] +from typing import overload, Union, List + +class A: ... +class B: ... +class C: ... +class D: ... + +@overload +def f(x: List[Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: List[Union[B, C]]) -> str: ... +def f(x): ... + +@overload +def g(x: List[Union[A, B]]) -> int: ... +@overload +def g(x: List[Union[B, C]]) -> int: ... def g(x): ... @overload -def h(x: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def h(x: List[Union[A, B]]) -> int: ... @overload -def h(x: Union[A, B, C]) -> str: ... +def h(x: List[Union[C, D]]) -> str: ... def h(x): ... +@overload +def i(x: List[Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def i(x: List[Union[A, B, C]]) -> str: ... +def i(x): ... + +[builtins fixtures/list.pyi] + [case testOverloadPartialOverlapWithUnrestrictedTypeVar] from typing import TypeVar, overload T = TypeVar('T') @overload -def f(x: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def f(x: int) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload def f(x: T) -> T: ... def f(x): ... +@overload +def g(x: int) -> int: ... +@overload +def g(x: T) -> T: ... +def g(x): ... + +[case testOverloadPartialOverlapWithUnrestrictedTypeVarNested] +from typing import TypeVar, overload, List + +T = TypeVar('T') + +@overload +def f1(x: List[int]) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f1(x: List[T]) -> T: ... +def f1(x): ... + +@overload +def f2(x: List[int]) -> List[str]: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f2(x: List[T]) -> List[T]: ... +def f2(x): ... + +@overload +def g1(x: List[int]) -> int: ... +@overload +def g1(x: List[T]) -> T: ... +def g1(x): ... + +@overload +def g2(x: List[int]) -> List[int]: ... +@overload +def g2(x: List[T]) -> List[T]: ... +def g2(x): ... + +[builtins fixtures/list.pyi] + +[case testOverloadPartialOverlapWithUnrestrictedTypeVarInClass] +from typing import TypeVar, overload, Generic + +T = TypeVar('T') + +class Wrapper(Generic[T]): + @overload + def f(self, x: int) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def f(self, x: T) -> T: ... + def f(self, x): ... + + # TODO: This shouldn't trigger an error message. + # Related to testTypeCheckOverloadImplementationTypeVarDifferingUsage2? + @overload + def g(self, x: int) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def g(self, x: T) -> T: ... + def g(self, x): ... + +[case testOverloadPartialOverlapWithUnrestrictedTypeVarInClassNested] +from typing import TypeVar, overload, Generic, List + +T = TypeVar('T') + +class Wrapper(Generic[T]): + @overload + def f1(self, x: List[int]) -> str: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def f1(self, x: List[T]) -> T: ... + def f1(self, x): ... + + @overload + def f2(self, x: List[int]) -> List[str]: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def f2(self, x: List[T]) -> List[T]: ... + def f2(self, x): ... + + # TODO: This shouldn't trigger an error message. + # Related to testTypeCheckOverloadImplementationTypeVarDifferingUsage2? + @overload + def g1(self, x: List[int]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def g1(self, x: List[T]) -> T: ... + def g1(self, x): ... + + @overload + def g2(self, x: List[int]) -> List[int]: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls + @overload + def g2(self, x: List[T]) -> List[T]: ... + def g2(self, x): ... + +[builtins fixtures/list.pyi] + +[case testOverloadTypedDictDifferentRequiredKeysMeansDictsAreDisjoint] +from typing import overload +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int, 'y': int}) +B = TypedDict('B', {'x': int, 'y': str}) + +@overload +def f(x: A) -> int: ... +@overload +def f(x: B) -> str: ... +def f(x): pass +[builtins fixtures/dict.pyi] + +[case testOverloadedTypedDictPartiallyOverlappingRequiredKeys] +from typing import overload, Union +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int, 'y': Union[int, str]}) +B = TypedDict('B', {'x': int, 'y': Union[str, float]}) + +@overload +def f(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: B) -> str: ... +def f(x): pass + +@overload +def g(x: A) -> int: ... +@overload +def g(x: B) -> object: ... +def g(x): pass +[builtins fixtures/dict.pyi] + +[case testOverloadedTypedDictFullyNonTotalDictsAreAlwaysPartiallyOverlapping] +from typing import overload +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int, 'y': str}, total=False) +B = TypedDict('B', {'a': bool}, total=False) +C = TypedDict('C', {'x': str, 'y': int}, total=False) + +@overload +def f(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: B) -> str: ... +def f(x): pass + +@overload +def g(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def g(x: C) -> str: ... +def g(x): pass +[builtins fixtures/dict.pyi] + +[case testOverloadedTotalAndNonTotalTypedDictsCanPartiallyOverlap] +from typing import overload, Union +from mypy_extensions import TypedDict + +A = TypedDict('A', {'x': int, 'y': str}) +B = TypedDict('B', {'x': Union[int, str], 'y': str, 'z': int}, total=False) + +@overload +def f1(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f1(x: B) -> str: ... +def f1(x): pass + +@overload +def f2(x: B) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f2(x: A) -> str: ... +def f2(x): pass + +[builtins fixtures/dict.pyi] + +[case testOverloadedTypedDictsWithSomeOptionalKeysArePartiallyOverlapping] +from typing import overload, Union +from mypy_extensions import TypedDict + +class A(TypedDict): + x: int + y: int + +class B(TypedDict, total=False): + z: str + +class C(TypedDict, total=False): + z: int + +@overload +def f(x: B) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: C) -> str: ... +def f(x): pass + +[builtins fixtures/dict.pyi] + +[case testOverloadedPartiallyOverlappingInheritedTypes1] +from typing import overload, List, Union, TypeVar, Generic + +class A: pass +class B: pass +class C: pass + +T = TypeVar('T') + +class ListSubclass(List[T]): pass +class Unrelated(Generic[T]): pass + +@overload +def f(x: List[Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: ListSubclass[Union[B, C]]) -> str: ... +def f(x): pass + +@overload +def g(x: List[Union[A, B]]) -> int: ... +@overload +def g(x: Unrelated[Union[B, C]]) -> str: ... +def g(x): pass + +[builtins fixtures/list.pyi] + +[case testOverloadedPartiallyOverlappingInheritedTypes2] +from typing import overload, List, Union + +class A: pass +class B: pass +class C: pass + +class ListSubclass(List[Union[B, C]]): pass + +@overload +def f(x: List[Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: ListSubclass) -> str: ... +def f(x): pass + +[builtins fixtures/list.pyi] + +[case testOverloadedPartiallyOverlappingInheritedTypes3] +from typing import overload, Union, Dict, TypeVar + +class A: pass +class B: pass +class C: pass + +S = TypeVar('S') + +class DictSubclass(Dict[str, S]): pass + +@overload +def f(x: Dict[str, Union[A, B]]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: DictSubclass[Union[B, C]]) -> str: ... +def f(x): pass + +[builtins fixtures/dict.pyi] + +[case testOverloadedPartiallyOverlappingTypeVarsAndUnion] +from typing import overload, TypeVar, Union + +class A: pass +class B: pass +class C: pass + +S = TypeVar('S', A, B) + +@overload +def f(x: S) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: Union[B, C]) -> str: ... +def f(x): pass + +@overload +def g(x: Union[B, C]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def g(x: S) -> str: ... +def g(x): pass + +[case testOverloadPartiallyOverlappingTypeVarsIdentical] +from typing import overload, TypeVar, Union + +T = TypeVar('T') + +class A: pass +class B: pass +class C: pass + +@overload +def f(x: T, y: T, z: Union[A, B]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: T, y: T, z: Union[B, C]) -> str: ... +def f(x, y, z): pass + +[case testOverloadedPartiallyOverlappingCallables] +from typing import overload, Union, Callable + +class A: pass +class B: pass +class C: pass + +@overload +def f(x: Callable[[Union[A, B]], int]) -> int: ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls +@overload +def f(x: Callable[[Union[B, C]], int]) -> str: ... +def f(x): pass + [case testOverloadNotConfusedForProperty] from typing import overload @@ -3302,7 +3686,7 @@ T = TypeVar('T') class FakeAttribute(Generic[T]): @overload - def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types + def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... # E: Overloaded function signatures 1 and 2 are partially overlapping: the two signatures may return incompatible types given certain calls @overload def dummy(self, instance: T, owner: Type[T]) -> int: ... def dummy(self, instance: Optional[T], owner: Type[T]) -> Union['FakeAttribute[T]', int]: ... From 1f17affd64a35a8cc6daa0070d2285a2a27b8361 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Fri, 6 Jul 2018 01:04:28 -0700 Subject: [PATCH 02/21] WIP commit This commit is a WIP which I am pushing now for testing purposes. I will rebase/push another commit with more polishing/more exposition a bit later. --- mypy/checker.py | 6 +- mypy/meet.py | 204 +++++++++++++----------------------------------- 2 files changed, 58 insertions(+), 152 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 01b0459e7b1c..879dbd888079 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -52,7 +52,7 @@ from mypy.join import join_types from mypy.treetransform import TransformVisitor from mypy.binder import ConditionalTypeBinder, get_declaration -from mypy.meet import is_overlapping_erased_types, is_partially_overlapping +from mypy.meet import is_overlapping_erased_types, is_overlapping_types from mypy.options import Options from mypy.plugin import Plugin, CheckerPluginInterface from mypy.sharedparse import BINARY_MAGIC_METHODS @@ -3417,7 +3417,7 @@ def conditional_type_map(expr: Expression, and is_proper_subtype(current_type, proposed_type)): # Expression is always of one of the types in proposed_type_ranges return {}, None - elif not is_overlapping_erased_types(current_type, proposed_type): + elif not is_overlapping_types(current_type, proposed_type): # Expression is never of any type in proposed_type_ranges return None, {} else: @@ -3673,7 +3673,7 @@ def is_unsafe_partially_overlapping_overload_signatures(signature: CallableType, alternatives then 'other' and that their argument counts are overlapping. """ def is_more_precise_or_partially_overlapping(t: Type, s: Type) -> bool: - return is_more_precise(t, s) or is_partially_overlapping(t, s) + return is_more_precise(t, s) or is_overlapping_types(t, s) # Try detaching callables from the containing class so we can try unifying # free type variables against each other. diff --git a/mypy/meet.py b/mypy/meet.py index 2abdc26819fb..d46eeabc2e39 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -10,7 +10,7 @@ from mypy.subtypes import ( is_equivalent, is_subtype, is_protocol_implementation, is_callable_compatible, ) -from mypy.sametypes import is_same_type +from mypy.erasetype import erase_type from mypy.maptype import map_instance_to_supertype from mypy import experiments @@ -36,7 +36,7 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: if isinstance(declared, UnionType): return UnionType.make_simplified_union([narrow_declared_type(x, narrowed) for x in declared.relevant_items()]) - elif not is_overlapping_erased_types(declared, narrowed, use_promotions=True): + elif not is_overlapping_types(declared, narrowed): if experiments.STRICT_OPTIONAL: return UninhabitedType() else: @@ -60,29 +60,38 @@ def get_possible_variants(typ: Type) -> List[Type]: else: return [typ.upper_bound] elif isinstance(typ, UnionType): - return typ.items + [typ] + return typ.items elif isinstance(typ, Overloaded): - # Note: doing 'return typ.items() + [typ]' makes mypy + # Note: doing 'return typ.items()' makes mypy # infer a too-specific return type of List[CallableType] out = [] # type: List[Type] out.extend(typ.items()) - out.append(typ) return out else: return [typ] -def is_partially_overlapping(left: Type, right: Type) -> bool: - # We should never encounter these types +def is_overlapping_types(left: Type, right: Type) -> bool: + """Can a value of type 'left' also be of type 'right' or vice-versa?""" + # We should never encounter these types, but if we do, we handle + # them in the same way we handle 'Any'. + illegal_types = (UnboundType, PartialType, ErasedType, DeletedType) if isinstance(left, illegal_types) or isinstance(right, illegal_types): - raise AssertionError( - "Encountered unexpected types: left={} right={}".format(type(left), type(right))) + # TODO: Replace this with an 'assert False'. + return True - # 'Any' may or may not be partially overlapping with the other type + # 'Any' may or may not be overlapping with the other type if isinstance(left, AnyType) or isinstance(right, AnyType): return True + # We check for complete overlaps first as a general-purpose failsafe. + # If this check fails, we start checking to see if there exists a + # *partial* overlap between types. + + if is_subtype(left, right) or is_subtype(right, left): + return True + # There are a number of different types that can have "variants" or are "union-like". # For example: # @@ -95,15 +104,17 @@ def is_partially_overlapping(left: Type, right: Type) -> bool: # # The logic to check whether any of these types are overlapping are essentially the # same: we obtain the list of possible variants and make sure there exists - # items in both the intersection and the difference. + # items in the intersection. left_possible = get_possible_variants(left) right_possible = get_possible_variants(right) - # However, TypeVars get special treatment. It's sufficient to check to see - # if at least one overlap exists. + # We start by checking TypeVars first: this is because in some of the checks + # below, it's convenient to just return early in certain cases. # - # The TypeVar checks must come before any single-variant types. + # If we were to defer checking TypeVars to down below, that would end up + # causing issues since the TypeVars would never have the opportunity to + # try binding to the relevant types. if isinstance(left, TypeVarType) or isinstance(right, TypeVarType): for l in left_possible: for r in right_possible: @@ -111,6 +122,14 @@ def is_partially_overlapping(left: Type, right: Type) -> bool: return True return False + # Now that we've finished handling TypeVars, we're free to end early + # if one one of the types is None and we're running in strict-optional + # mode. (We must perform this check after the TypeVar checks because + # a TypeVar could be bound to None, for example.) + if experiments.STRICT_OPTIONAL: + if isinstance(left, NoneTyp) != isinstance(right, NoneTyp): + return False + # Next, we handle single-variant types that may be inherently partially overlapping: # # - TypedDicts @@ -120,10 +139,6 @@ def is_partially_overlapping(left: Type, right: Type) -> bool: # into their (Instance) fallbacks. if isinstance(left, TypedDictType) and isinstance(right, TypedDictType): - # If one is a subtype of the other, no partial overlap - if is_subtype(left, right) or is_subtype(right, left): - return False - # All required keys in left are present and overlapping with something in right for key in left.required_keys: if key not in right.items: @@ -164,7 +179,7 @@ def is_partially_overlapping(left: Type, right: Type) -> bool: # but do require custom logic to inspect. if isinstance(left, TypeType) and isinstance(right, TypeType): - return is_partially_overlapping(left.item, right.item) + return is_overlapping_types(left.item, right.item) elif isinstance(left, TypeType) or isinstance(right, TypeType): # TODO: Can Callable[[...], T] and Type[T] be partially overlapping? return False @@ -182,6 +197,11 @@ def is_partially_overlapping(left: Type, right: Type) -> bool: # Next, we check if left and right are instances if isinstance(left, Instance) and isinstance(right, Instance): + if left.type.is_protocol and is_protocol_implementation(right, left): + return True + if right.type.is_protocol and is_protocol_implementation(left, right): + return True + # Two unrelated types cannot be partially overlapping: they're disjoint. # We don't need to handle promotions because promotable types either # are overlapping or are not -- they can't be partially overlapping. @@ -194,146 +214,32 @@ def is_partially_overlapping(left: Type, right: Type) -> bool: if len(left.args) == len(right.args): for left_arg, right_arg in zip(left.args, right.args): - if is_partially_overlapping(left_arg, right_arg): + if is_overlapping_types(left_arg, right_arg): return True + return False # We handle all remaining types here: in particular, types like # UnionType, Overloaded, NoneTyp, and UninhabitedType. + # + # We deliberately skip singleton variant types to avoid + # infinitely recursing. - found_same = False - found_diff = False - for a in left_possible: - for b in right_possible: - if is_same_type(a, b): - found_same = True - break - else: - found_diff = True - - # Return early if possible - if found_same and found_diff: - return True - return False - - -def is_overlapping_types(t: Type, s: Type) -> bool: - return is_subtype(t, s) or is_subtype(s, t) or is_partially_overlapping(t, s) - - -def is_overlapping_erased_types(t: Type, s: Type, use_promotions: bool = False) -> bool: - """Can a value of type t be a value of type s, or vice versa? - - Note that this effectively checks against erased types, since type - variables are erased at runtime and the overlapping check is based - on runtime behavior. The exception is protocol types, it is not safe, - but convenient and is an opt-in behavior. - - If use_promotions is True, also consider type promotions (int and - float would only be overlapping if it's True). - - This does not consider multiple inheritance. For example, A and B in - the following example are not considered overlapping, even though - via C they can be overlapping: + if len(left_possible) >= 1 or len(right_possible) >= 1: + for a in left_possible: + for b in right_possible: + if is_overlapping_types(a, b): + return True - class A: ... - class B: ... - class C(A, B): ... + # We ought to have handled every case by now: we conclude the + # two types are not overlapping, either completely or partially. - The rationale is that this case is usually very unlikely as multiple - inheritance is rare. Also, we can't reliably determine whether - multiple inheritance actually occurs somewhere in a program, due to - stub files hiding implementation details, dynamic loading etc. + return False - TODO: Don't consider callables always overlapping. - TODO: Don't consider type variables with values always overlapping. - """ - # Any overlaps with everything - if isinstance(t, AnyType) or isinstance(s, AnyType): - return True - # object overlaps with everything - if (isinstance(t, Instance) and t.type.fullname() == 'builtins.object' or - isinstance(s, Instance) and s.type.fullname() == 'builtins.object'): - return True - # Since we are effectively working with the erased types, we only - # need to handle occurrences of TypeVarType at the top level. - if isinstance(t, TypeVarType): - t = t.erase_to_union_or_bound() - if isinstance(s, TypeVarType): - s = s.erase_to_union_or_bound() - if isinstance(t, TypedDictType): - t = t.as_anonymous().fallback - if isinstance(s, TypedDictType): - s = s.as_anonymous().fallback - - if isinstance(t, UnionType): - return any(is_overlapping_erased_types(item, s) - for item in t.relevant_items()) - if isinstance(s, UnionType): - return any(is_overlapping_erased_types(t, item) - for item in s.relevant_items()) - - # We must check for TupleTypes before Instances, since Tuple[A, ...] - # is an Instance - tup_overlap = is_overlapping_tuples(t, s, use_promotions) - if tup_overlap is not None: - return tup_overlap - - if isinstance(t, Instance): - if isinstance(s, Instance): - # Consider two classes non-disjoint if one is included in the mro - # of another. - if use_promotions: - # Consider cases like int vs float to be overlapping where - # there is only a type promotion relationship but not proper - # subclassing. - if t.type._promote and is_overlapping_erased_types(t.type._promote, s): - return True - if s.type._promote and is_overlapping_erased_types(s.type._promote, t): - return True - if t.type in s.type.mro or s.type in t.type.mro: - return True - if t.type.is_protocol and is_protocol_implementation(s, t): - return True - if s.type.is_protocol and is_protocol_implementation(t, s): - return True - return False - if isinstance(t, TypeType) and isinstance(s, TypeType): - # If both types are TypeType, compare their inner types. - return is_overlapping_erased_types(t.item, s.item, use_promotions) - elif isinstance(t, TypeType) or isinstance(s, TypeType): - # If exactly only one of t or s is a TypeType, check if one of them - # is an `object` or a `type` and otherwise assume no overlap. - one = t if isinstance(t, TypeType) else s - other = s if isinstance(t, TypeType) else t - if isinstance(other, Instance): - return other.type.fullname() in {'builtins.object', 'builtins.type'} - else: - return isinstance(other, CallableType) and is_subtype(other, one) - if experiments.STRICT_OPTIONAL: - if isinstance(t, NoneTyp) != isinstance(s, NoneTyp): - # NoneTyp does not overlap with other non-Union types under strict Optional checking - return False - # We conservatively assume that non-instance, non-union, non-TupleType and non-TypeType types - # can overlap any other types. - return True - - -def is_overlapping_tuples(t: Type, s: Type, use_promotions: bool) -> Optional[bool]: - """Part of is_overlapping_types(), for tuples only""" - t = adjust_tuple(t, s) or t - s = adjust_tuple(s, t) or s - if isinstance(t, TupleType) or isinstance(s, TupleType): - if isinstance(t, TupleType) and isinstance(s, TupleType): - if t.length() == s.length(): - if all(is_overlapping_erased_types(ti, si, use_promotions) - for ti, si in zip(t.items, s.items)): - return True - # TupleType and non-tuples do not overlap - return False - # No tuples are involved here - return None +def is_overlapping_erased_types(left: Type, right: Type) -> bool: + """The same as 'is_overlapping_erased_types', except the types are erased first.""" + return is_overlapping_types(erase_type(left), erase_type(right)) def adjust_tuple(left: Type, r: Type) -> Optional[TupleType]: From 4cdeab7ccaac565e3475660ba99bf1737f243487 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Tue, 17 Jul 2018 13:49:57 -0700 Subject: [PATCH 03/21] Refactor and add some comments --- mypy/checker.py | 10 ++--- mypy/meet.py | 116 ++++++++++++++++++++++++++++++------------------ 2 files changed, 78 insertions(+), 48 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index a01de18bbe73..19bcc2731210 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3638,7 +3638,7 @@ def is_unsafe_overlapping_overload_signatures(signature: CallableType, other: CallableType) -> bool: """Check if two overloaded signatures are unsafely overlapping, ignoring partial overlaps. - We consider two functions 's' and 't' to be unsafely overlapping both if + We consider two functions 's' and 't' to be unsafely overlapping if both of the following are true: 1. s's parameters are all more precise or partially overlapping with t's @@ -3666,7 +3666,7 @@ def is_unsafe_partially_overlapping_overload_signatures(signature: CallableType, other: CallableType) -> bool: """Check if two overloaded signatures are unsafely overlapping, ignoring partial overlaps. - We consider two functions 's' and 't' to be unsafely overlapping both if + We consider two functions 's' and 't' to be unsafely overlapping if both of the following are true: 1. s's parameters are all more precise or partially overlapping with t's @@ -3678,8 +3678,8 @@ def is_unsafe_partially_overlapping_overload_signatures(signature: CallableType, def is_more_precise_or_partially_overlapping(t: Type, s: Type) -> bool: return is_more_precise(t, s) or is_overlapping_types(t, s) - # Try detaching callables from the containing class so we can try unifying - # free type variables against each other. + # Try detaching callables from the containing class so that all TypeVars + # are treated as being free. # # This lets us identify cases where the two signatures use completely # incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars @@ -3712,7 +3712,7 @@ def is_more_precise_or_partially_overlapping(t: Type, s: Type) -> bool: def detach_callable(typ: CallableType) -> CallableType: - """Ensures that the callable's type variables are 'detached' and independent of the context + """Ensures that the callable's type variables are 'detached' and independent of the context. A callable normally keeps track of the type variables it uses within its 'variables' field. However, if the callable is from a method and that method is using a class type variable, diff --git a/mypy/meet.py b/mypy/meet.py index a68c64c38ff4..c0fb784596c4 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -5,7 +5,7 @@ from mypy.types import ( Type, AnyType, TypeVisitor, UnboundType, NoneTyp, TypeVarType, Instance, CallableType, TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, - UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, is_named_instance, + UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, ) from mypy.subtypes import ( is_equivalent, is_subtype, is_protocol_implementation, is_callable_compatible, @@ -54,6 +54,30 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: def get_possible_variants(typ: Type) -> List[Type]: + """This function takes any "Union-like" type and returns a list of the available "options". + + Specifically, there are currently exactly three different types that can have + "variants" or are "union-like": + + - Unions + - TypeVars with value restrictions + - Overloads + + This function will return a list of each "option" present in those types. + + If this function receives any other type, we return a list containing just that + original type. (E.g. pretend the type was contained within a singleton union). + + The only exception is regular TypeVars: we return a list containing that TypeVar's + upper bound. + + This function is useful primarily when checking to see if two types are overlapping: + the algorithm to check if two unions are overlapping is fundamentally the same as + the algorithm for checking if two overloads are overlapping. + + Normalizing both kinds of types in the same way lets us reuse the same algorithm + for both. + """ if isinstance(typ, TypeVarType): if len(typ.values) > 0: return typ.values @@ -73,12 +97,13 @@ def get_possible_variants(typ: Type) -> List[Type]: def is_overlapping_types(left: Type, right: Type) -> bool: """Can a value of type 'left' also be of type 'right' or vice-versa?""" + # We should never encounter these types, but if we do, we handle # them in the same way we handle 'Any'. - illegal_types = (UnboundType, PartialType, ErasedType, DeletedType) if isinstance(left, illegal_types) or isinstance(right, illegal_types): - # TODO: Replace this with an 'assert False'. + # TODO: Replace this with an 'assert False' once we are confident we + # never accidentally generate these types. return True # 'Any' may or may not be overlapping with the other type @@ -92,19 +117,12 @@ def is_overlapping_types(left: Type, right: Type) -> bool: if is_subtype(left, right) or is_subtype(right, left): return True - # There are a number of different types that can have "variants" or are "union-like". - # For example: - # - # - Unions - # - TypeVars with value restrictions - # - Overloads + # See the docstring for 'get_possible_variants' for more info on what the + # following lines are doing. # - # We extract the component variants into a list. Types with a single variant are - # stored in a singleton list. - # - # The logic to check whether any of these types are overlapping are essentially the - # same: we obtain the list of possible variants and make sure there exists - # items in the intersection. + # Note that we use 'left_possible' and 'right_possible' in two different + # locations: immediately after to handle TypeVars, and near the end of + # 'is_overlapping_types' to handle types like Unions or Overloads. left_possible = get_possible_variants(left) right_possible = get_possible_variants(right) @@ -136,46 +154,23 @@ def is_overlapping_types(left: Type, right: Type) -> bool: # - Tuples # # If we cannot identify a partial overlap and end early, we degrade these two types - # into their (Instance) fallbacks. + # into their 'Instance' fallbacks. if isinstance(left, TypedDictType) and isinstance(right, TypedDictType): - # All required keys in left are present and overlapping with something in right - for key in left.required_keys: - if key not in right.items: - return False - if not is_overlapping_types(left.items[key], right.items[key]): - return False - - # Repeat check in the other direction - for key in right.required_keys: - if key not in left.items: - return False - if not is_overlapping_types(left.items[key], right.items[key]): - return False - - # The presence of any additional optional keys does not affect whether the two - # TypedDicts are partially overlapping: the dicts would be overlapping if the - # keys happened to be missing. - return True + return are_typed_dicts_overlapping(left, right) elif isinstance(left, TypedDictType): left = left.fallback elif isinstance(right, TypedDictType): right = right.fallback if is_tuple(left) and is_tuple(right): - left = adjust_tuple(left, right) or left - right = adjust_tuple(right, left) or right - assert isinstance(left, TupleType) - assert isinstance(right, TupleType) - if len(left.items) != len(right.items): - return False - return all(is_overlapping_types(l, r) for l, r in zip(left.items, right.items)) + return are_tuples_overlapping(left, right) elif isinstance(left, TupleType): left = left.fallback elif isinstance(right, TupleType): right = right.fallback - # Next, we handle single-variant types that are not inherently partially overlapping, + # Next, we handle single-variant types that cannot be inherently partially overlapping, # but do require custom logic to inspect. if isinstance(left, TypeType) and isinstance(right, TypeType): @@ -222,7 +217,7 @@ def is_overlapping_types(left: Type, right: Type) -> bool: # We handle all remaining types here: in particular, types like # UnionType, Overloaded, NoneTyp, and UninhabitedType. # - # We deliberately skip singleton variant types to avoid + # We deliberately skip comparing two singleton variant types to avoid # infinitely recursing. if len(left_possible) >= 1 or len(right_possible) >= 1: @@ -242,6 +237,41 @@ def is_overlapping_erased_types(left: Type, right: Type) -> bool: return is_overlapping_types(erase_type(left), erase_type(right)) +def are_typed_dicts_overlapping(left: TypedDictType, right: TypedDictType) -> bool: + """Returns 'true' if left and right are overlapping TypeDictTypes.""" + # All required keys in left are present and overlapping with something in right + for key in left.required_keys: + if key not in right.items: + return False + if not is_overlapping_types(left.items[key], right.items[key]): + return False + + # Repeat check in the other direction + for key in right.required_keys: + if key not in left.items: + return False + if not is_overlapping_types(left.items[key], right.items[key]): + return False + + # The presence of any additional optional keys does not affect whether the two + # TypedDicts are partially overlapping: the dicts would be overlapping if the + # keys happened to be missing. + return True + + +def are_tuples_overlapping(left: Type, right: Type) -> bool: + """Returns true if left and right are overlapping tuples. + + Precondition: is_tuple(left) and is_tuple(right) are both true.""" + left = adjust_tuple(left, right) or left + right = adjust_tuple(right, left) or right + assert isinstance(left, TupleType) + assert isinstance(right, TupleType) + if len(left.items) != len(right.items): + return False + return all(is_overlapping_types(l, r) for l, r in zip(left.items, right.items)) + + def adjust_tuple(left: Type, r: Type) -> Optional[TupleType]: """Find out if `left` is a Tuple[A, ...], and adjust its length to `right`""" if isinstance(left, Instance) and left.type.fullname() == 'builtins.tuple': From 813ddf7920f5dfe31fcc5226988af6c59394c67a Mon Sep 17 00:00:00 2001 From: Janus Troelsen Date: Wed, 18 Jul 2018 17:51:56 +0200 Subject: [PATCH 04/21] Add distutils and encodings as stdlib modules (#5372) Fixes https://github.com/python/mypy/issues/5351 --- mypy/moduleinfo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mypy/moduleinfo.py b/mypy/moduleinfo.py index ec6a8fdfde23..3d3dab20f317 100644 --- a/mypy/moduleinfo.py +++ b/mypy/moduleinfo.py @@ -270,9 +270,11 @@ 'decimal', 'difflib', 'dis', + 'distutils', 'doctest', 'dummy_threading', 'email', + 'encodings', 'fcntl', 'filecmp', 'fileinput', From 9bb04773bdf26297403a140ce61658e8b8ff429d Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 18 Jul 2018 13:08:54 -0700 Subject: [PATCH 05/21] Sync typeshed (#5373) This is mostly so we can get https://github.com/python/typeshed/pull/2338, which unblocks https://github.com/python/mypy/pull/5370. --- typeshed | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/typeshed b/typeshed index f582b53ff7de..574807d9eac3 160000 --- a/typeshed +++ b/typeshed @@ -1 +1 @@ -Subproject commit f582b53ff7de199bb3faabee02542efeb7b01690 +Subproject commit 574807d9eac34797c9bb776458e29a680b569c91 From 2e4ab7cbb3ec2de376e6b18c1cb187360f4c8679 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 18 Jul 2018 13:24:54 -0700 Subject: [PATCH 06/21] Enable --no-silence-site-packages in eval tests (#5370) An unintended side-effect of silencing errors in site packages and typeshed by default is that our tests no longer flag changes to mypy that end up breaking typeshed in some way. (For example, see https://github.com/python/mypy/pulls/5280 which *really* should not be passing, at least as of time of writing.) This pull request will add the `--no-silence-site-packages` flag to any test suites that directly or indirectly test typeshed. Specifically, I believe this PR ought to add the flag to any tests triggered by `testsamples.py`, `testselfcheck.py`, and `testpythoneval.py`. --- mypy/test/helpers.py | 3 ++- mypy/test/testpythoneval.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mypy/test/helpers.py b/mypy/test/helpers.py index 6e891d54bf1b..2dafc500d6e0 100644 --- a/mypy/test/helpers.py +++ b/mypy/test/helpers.py @@ -32,7 +32,8 @@ def run_mypy(args: List[str]) -> None: __tracebackhide__ = True outval, errval, status = api.run(args + ['--show-traceback', - '--no-site-packages']) + '--no-site-packages', + '--no-silence-site-packages']) if status != 0: sys.stdout.write(outval) sys.stderr.write(errval) diff --git a/mypy/test/testpythoneval.py b/mypy/test/testpythoneval.py index d2b497717b6e..bc69155cf8f7 100644 --- a/mypy/test/testpythoneval.py +++ b/mypy/test/testpythoneval.py @@ -50,7 +50,12 @@ def test_python_evaluation(testcase: DataDrivenTestCase, cache_dir: str) -> None """ assert testcase.old_cwd is not None, "test was not properly set up" # TODO: Enable strict optional for these tests - mypy_cmdline = ['--show-traceback', '--no-site-packages', '--no-strict-optional'] + mypy_cmdline = [ + '--show-traceback', + '--no-site-packages', + '--no-strict-optional', + '--no-silence-site-packages', + ] py2 = testcase.name.lower().endswith('python2') if py2: mypy_cmdline.append('--py2') From ee86385ab306921c4a1b5ca82e3f3ef5eb3c3865 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 18 Jul 2018 15:32:24 -0700 Subject: [PATCH 07/21] WIP: Try switching to using relevant_items() --- mypy/meet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/meet.py b/mypy/meet.py index c0fb784596c4..341fb7718bb7 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -84,7 +84,7 @@ def get_possible_variants(typ: Type) -> List[Type]: else: return [typ.upper_bound] elif isinstance(typ, UnionType): - return typ.items + return typ.relevant_items() elif isinstance(typ, Overloaded): # Note: doing 'return typ.items()' makes mypy # infer a too-specific return type of List[CallableType] From 4cad67285b1b7a1e7de77d0abde89c64eb3ffc59 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Tue, 17 Jul 2018 18:45:19 -0700 Subject: [PATCH 08/21] WIP: try modifying interaction with overloads, classes, and invariance --- mypy/checker.py | 81 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 13 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 19bcc2731210..938dbfc6990e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3653,13 +3653,24 @@ def is_unsafe_overlapping_overload_signatures(signature: CallableType, Assumes that 'signature' appears earlier in the list of overload alternatives then 'other' and that their argument counts are overlapping. """ + #if "foo" in signature.name or "bar" in signature.name or "chain_call" in signature.name: + # print("in first") - return is_callable_compatible(signature, other, + signature = detach_callable(signature) + other = detach_callable(other) + + return (is_callable_compatible(signature, other, is_compat=is_more_precise, is_compat_return=lambda l, r: not is_subtype(l, r), ignore_return=False, check_args_covariantly=True, - allow_partial_overlap=True) + allow_partial_overlap=True) or + is_callable_compatible(other, signature, + is_compat=is_more_precise, + is_compat_return=lambda l, r: not is_subtype(r, l), + ignore_return=False, + check_args_covariantly=False, + allow_partial_overlap=True)) def is_unsafe_partially_overlapping_overload_signatures(signature: CallableType, @@ -3675,6 +3686,9 @@ def is_unsafe_partially_overlapping_overload_signatures(signature: CallableType, Assumes that 'signature' appears earlier in the list of overload alternatives then 'other' and that their argument counts are overlapping. """ + #if "foo" in signature.name or "bar" in signature.name or "chain_call" in signature.name: + # print("in second") + def is_more_precise_or_partially_overlapping(t: Type, s: Type) -> bool: return is_more_precise(t, s) or is_overlapping_types(t, s) @@ -3723,9 +3737,35 @@ def detach_callable(typ: CallableType) -> CallableType: The caller can then unify on all type variables whether or not the callable is originally from a class or not.""" - type_vars = typ.accept(TypeVarExtractor()) + type_list = typ.arg_types + [typ.ret_type] + old_type_list = list(type_list) + + appear_map = {} # type: Dict[str, List[int]] + for i, inner_type in enumerate(type_list): + typevars_available = inner_type.accept(TypeVarExtractor()) + for var in typevars_available: + if var.fullname not in appear_map: + appear_map[var.fullname] = [] + appear_map[var.fullname].append(i) + + from mypy.erasetype import erase_type + + used_type_var_names = set() + for var_name, appearances in appear_map.items(): + '''if len(appearances) == 1: + entry = appearances[0] + type_list[entry] = erase_type(type_list[entry]) + else: + used_type_var_names.add(var_name)''' + + used_type_var_names.add(var_name) + + all_type_vars = typ.accept(TypeVarExtractor()) new_variables = [] - for var in type_vars: + for var in set(all_type_vars): + if var.fullname not in used_type_var_names: + continue + #new_variables.append(var) new_variables.append(TypeVarDef( name=var.name, fullname=var.fullname, @@ -3734,21 +3774,36 @@ def detach_callable(typ: CallableType) -> CallableType: upper_bound=var.upper_bound, variance=var.variance, )) - return typ.copy_modified(variables=new_variables) - - -class TypeVarExtractor(TypeQuery[Set[TypeVarType]]): + out = typ.copy_modified( + variables=new_variables, + arg_types=type_list[:-1], + ret_type=type_list[-1], + ) + ''' + print(typ.name) + print(' before:', typ) + print(' after: ', out) + print(' type list (old):', old_type_list) + print(' type list (new):', type_list) + print(' old_vars:', typ.variables) + print(' new_vars:', out.variables) + print(' appear_map:', appear_map) + #''' + return out + + +class TypeVarExtractor(TypeQuery[List[TypeVarType]]): def __init__(self) -> None: super().__init__(self._merge) - def _merge(self, iter: Iterable[Set[TypeVarType]]) -> Set[TypeVarType]: - out = set() + def _merge(self, iter: Iterable[List[TypeVarType]]) -> List[TypeVarType]: + out = [] for item in iter: - out.update(item) + out.extend(item) return out - def visit_type_var(self, t: TypeVarType) -> Set[TypeVarType]: - return {t} + def visit_type_var(self, t: TypeVarType) -> List[TypeVarType]: + return [t] def overload_can_never_match(signature: CallableType, other: CallableType) -> bool: From b55402320600aed31108b528e6949e4383c8d38b Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Thu, 19 Jul 2018 17:57:17 -0700 Subject: [PATCH 09/21] WIP: Allow overlapping type checks to disable promotions --- mypy/checker.py | 38 ++-- mypy/meet.py | 51 ++++-- mypy/subtypes.py | 200 +++++++++++++-------- mypy/typestate.py | 46 ++--- test-data/unit/check-isinstance.test | 50 ++++++ test-data/unit/check-overloading.test | 53 +++++- test-data/unit/fixtures/isinstancelist.pyi | 1 + 7 files changed, 310 insertions(+), 129 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 938dbfc6990e..05cdceb6396a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -495,12 +495,12 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # Is the overload alternative's arguments subtypes of the implementation's? if not is_callable_compatible(impl, sig1, - is_compat=is_subtype, + is_compat=is_subtype_no_promote, ignore_return=True): self.msg.overloaded_signatures_arg_specific(i + 1, defn.impl) # Is the overload alternative's return type a subtype of the implementation's? - if not is_subtype(sig1.ret_type, impl.ret_type): + if not is_subtype_no_promote(sig1.ret_type, impl.ret_type): self.msg.overloaded_signatures_ret_specific(i + 1, defn.impl) # Here's the scoop about generators and coroutines. @@ -3653,21 +3653,21 @@ def is_unsafe_overlapping_overload_signatures(signature: CallableType, Assumes that 'signature' appears earlier in the list of overload alternatives then 'other' and that their argument counts are overlapping. """ - #if "foo" in signature.name or "bar" in signature.name or "chain_call" in signature.name: + # if "foo" in signature.name or "bar" in signature.name or "chain_call" in signature.name: # print("in first") signature = detach_callable(signature) other = detach_callable(other) return (is_callable_compatible(signature, other, - is_compat=is_more_precise, - is_compat_return=lambda l, r: not is_subtype(l, r), + is_compat=is_more_precise_no_promote, + is_compat_return=lambda l, r: not is_subtype_no_promote(l, r), ignore_return=False, check_args_covariantly=True, allow_partial_overlap=True) or is_callable_compatible(other, signature, - is_compat=is_more_precise, - is_compat_return=lambda l, r: not is_subtype(r, l), + is_compat=is_more_precise_no_promote, + is_compat_return=lambda l, r: not is_subtype_no_promote(r, l), ignore_return=False, check_args_covariantly=False, allow_partial_overlap=True)) @@ -3686,11 +3686,11 @@ def is_unsafe_partially_overlapping_overload_signatures(signature: CallableType, Assumes that 'signature' appears earlier in the list of overload alternatives then 'other' and that their argument counts are overlapping. """ - #if "foo" in signature.name or "bar" in signature.name or "chain_call" in signature.name: + # if "foo" in signature.name or "bar" in signature.name or "chain_call" in signature.name: # print("in second") def is_more_precise_or_partially_overlapping(t: Type, s: Type) -> bool: - return is_more_precise(t, s) or is_overlapping_types(t, s) + return is_more_precise_no_promote(t, s) or is_overlapping_types_no_promote(t, s) # Try detaching callables from the containing class so that all TypeVars # are treated as being free. @@ -3713,13 +3713,13 @@ def is_more_precise_or_partially_overlapping(t: Type, s: Type) -> bool: # checks twice in both directions for now. return (is_callable_compatible(signature, other, is_compat=is_more_precise_or_partially_overlapping, - is_compat_return=lambda l, r: not is_subtype(l, r), + is_compat_return=lambda l, r: not is_subtype_no_promote(l, r), ignore_return=False, check_args_covariantly=True, allow_partial_overlap=True) or is_callable_compatible(other, signature, is_compat=is_more_precise_or_partially_overlapping, - is_compat_return=lambda l, r: not is_subtype(r, l), + is_compat_return=lambda l, r: not is_subtype_no_promote(r, l), ignore_return=False, check_args_covariantly=False, allow_partial_overlap=True)) @@ -3738,7 +3738,7 @@ def detach_callable(typ: CallableType) -> CallableType: The caller can then unify on all type variables whether or not the callable is originally from a class or not.""" type_list = typ.arg_types + [typ.ret_type] - old_type_list = list(type_list) + # old_type_list = list(type_list) appear_map = {} # type: Dict[str, List[int]] for i, inner_type in enumerate(type_list): @@ -3765,7 +3765,7 @@ def detach_callable(typ: CallableType) -> CallableType: for var in set(all_type_vars): if var.fullname not in used_type_var_names: continue - #new_variables.append(var) + # new_variables.append(var) new_variables.append(TypeVarDef( name=var.name, fullname=var.fullname, @@ -4077,3 +4077,15 @@ def is_static(func: Union[FuncBase, Decorator]) -> bool: elif isinstance(func, FuncBase): return func.is_static assert False, "Unexpected func type: {}".format(type(func)) + + +def is_subtype_no_promote(left: Type, right: Type) -> bool: + return is_subtype(left, right, ignore_promotions=True) + + +def is_more_precise_no_promote(left: Type, right: Type) -> bool: + return is_more_precise(left, right, ignore_promotions=True) + + +def is_overlapping_types_no_promote(left: Type, right: Type) -> bool: + return is_overlapping_types(left, right, ignore_promotions=True) diff --git a/mypy/meet.py b/mypy/meet.py index 341fb7718bb7..d81a0f98133d 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -95,9 +95,15 @@ def get_possible_variants(typ: Type) -> List[Type]: return [typ] -def is_overlapping_types(left: Type, right: Type) -> bool: +def is_overlapping_types(left: Type, right: Type, ignore_promotions: bool = False) -> bool: """Can a value of type 'left' also be of type 'right' or vice-versa?""" + def _is_overlapping_types(left: Type, right: Type) -> bool: + '''Encode the kind of overlapping check to perform. + + This function mostly exists so we don't have to repeat keyword arguments everywhere.''' + return is_overlapping_types(left, right, ignore_promotions=ignore_promotions) + # We should never encounter these types, but if we do, we handle # them in the same way we handle 'Any'. illegal_types = (UnboundType, PartialType, ErasedType, DeletedType) @@ -114,7 +120,8 @@ def is_overlapping_types(left: Type, right: Type) -> bool: # If this check fails, we start checking to see if there exists a # *partial* overlap between types. - if is_subtype(left, right) or is_subtype(right, left): + if (is_subtype(left, right, ignore_promotions=ignore_promotions) + or is_subtype(right, left, ignore_promotions=ignore_promotions)): return True # See the docstring for 'get_possible_variants' for more info on what the @@ -136,7 +143,7 @@ def is_overlapping_types(left: Type, right: Type) -> bool: if isinstance(left, TypeVarType) or isinstance(right, TypeVarType): for l in left_possible: for r in right_possible: - if is_overlapping_types(l, r): + if _is_overlapping_types(l, r): return True return False @@ -157,14 +164,14 @@ def is_overlapping_types(left: Type, right: Type) -> bool: # into their 'Instance' fallbacks. if isinstance(left, TypedDictType) and isinstance(right, TypedDictType): - return are_typed_dicts_overlapping(left, right) + return are_typed_dicts_overlapping(left, right, ignore_promotions=ignore_promotions) elif isinstance(left, TypedDictType): left = left.fallback elif isinstance(right, TypedDictType): right = right.fallback if is_tuple(left) and is_tuple(right): - return are_tuples_overlapping(left, right) + return are_tuples_overlapping(left, right, ignore_promotions=ignore_promotions) elif isinstance(left, TupleType): left = left.fallback elif isinstance(right, TupleType): @@ -174,14 +181,14 @@ def is_overlapping_types(left: Type, right: Type) -> bool: # but do require custom logic to inspect. if isinstance(left, TypeType) and isinstance(right, TypeType): - return is_overlapping_types(left.item, right.item) + return _is_overlapping_types(left.item, right.item) elif isinstance(left, TypeType) or isinstance(right, TypeType): # TODO: Can Callable[[...], T] and Type[T] be partially overlapping? return False if isinstance(left, CallableType) and isinstance(right, CallableType): return is_callable_compatible(left, right, - is_compat=is_overlapping_types, + is_compat=_is_overlapping_types, ignore_pos_arg_names=True, allow_partial_overlap=True) elif isinstance(left, CallableType): @@ -198,8 +205,9 @@ def is_overlapping_types(left: Type, right: Type) -> bool: return True # Two unrelated types cannot be partially overlapping: they're disjoint. - # We don't need to handle promotions because promotable types either - # are overlapping or are not -- they can't be partially overlapping. + # We don't need to handle promotions because they've already been handled + # by the calls to `is_subtype(...)` up above (and promotable types never + # have any generic arguments we need to recurse on). if left.type.has_base(right.type.fullname()): left = map_instance_to_supertype(left, right.type) elif right.type.has_base(left.type.fullname()): @@ -209,7 +217,7 @@ def is_overlapping_types(left: Type, right: Type) -> bool: if len(left.args) == len(right.args): for left_arg, right_arg in zip(left.args, right.args): - if is_overlapping_types(left_arg, right_arg): + if _is_overlapping_types(left_arg, right_arg): return True return False @@ -223,7 +231,7 @@ def is_overlapping_types(left: Type, right: Type) -> bool: if len(left_possible) >= 1 or len(right_possible) >= 1: for a in left_possible: for b in right_possible: - if is_overlapping_types(a, b): + if _is_overlapping_types(a, b): return True # We ought to have handled every case by now: we conclude the @@ -232,25 +240,30 @@ def is_overlapping_types(left: Type, right: Type) -> bool: return False -def is_overlapping_erased_types(left: Type, right: Type) -> bool: +def is_overlapping_erased_types(left: Type, right: Type, *, + ignore_promotions: bool = False) -> bool: """The same as 'is_overlapping_erased_types', except the types are erased first.""" - return is_overlapping_types(erase_type(left), erase_type(right)) + return is_overlapping_types(erase_type(left), erase_type(right), + ignore_promotions=ignore_promotions) -def are_typed_dicts_overlapping(left: TypedDictType, right: TypedDictType) -> bool: +def are_typed_dicts_overlapping(left: TypedDictType, right: TypedDictType, *, + ignore_promotions: bool = False) -> bool: """Returns 'true' if left and right are overlapping TypeDictTypes.""" # All required keys in left are present and overlapping with something in right for key in left.required_keys: if key not in right.items: return False - if not is_overlapping_types(left.items[key], right.items[key]): + if not is_overlapping_types(left.items[key], right.items[key], + ignore_promotions=ignore_promotions): return False # Repeat check in the other direction for key in right.required_keys: if key not in left.items: return False - if not is_overlapping_types(left.items[key], right.items[key]): + if not is_overlapping_types(left.items[key], right.items[key], + ignore_promotions=ignore_promotions): return False # The presence of any additional optional keys does not affect whether the two @@ -259,7 +272,8 @@ def are_typed_dicts_overlapping(left: TypedDictType, right: TypedDictType) -> bo return True -def are_tuples_overlapping(left: Type, right: Type) -> bool: +def are_tuples_overlapping(left: Type, right: Type, *, + ignore_promotions: bool = False) -> bool: """Returns true if left and right are overlapping tuples. Precondition: is_tuple(left) and is_tuple(right) are both true.""" @@ -269,7 +283,8 @@ def are_tuples_overlapping(left: Type, right: Type) -> bool: assert isinstance(right, TupleType) if len(left.items) != len(right.items): return False - return all(is_overlapping_types(l, r) for l, r in zip(left.items, right.items)) + return all(is_overlapping_types(l, r, ignore_promotions=ignore_promotions) + for l, r in zip(left.items, right.items)) def adjust_tuple(left: Type, r: Type) -> Optional[TupleType]: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index edb3d19cca0b..98f7cc19f99a 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -20,7 +20,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance from mypy.sametypes import is_same_type -from mypy.typestate import TypeState +from mypy.typestate import TypeState, SubtypeKind from mypy import experiments @@ -46,7 +46,8 @@ def check_type_parameter(lefta: Type, righta: Type, variance: int) -> bool: def is_subtype(left: Type, right: Type, type_parameter_checker: TypeParameterChecker = check_type_parameter, *, ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False) -> bool: + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> bool: """Is 'left' subtype of 'right'? Also consider Any to be a subtype of any type, and vice versa. This @@ -66,7 +67,9 @@ def is_subtype(left: Type, right: Type, # 'left' can be a subtype of the union 'right' is if it is a # subtype of one of the items making up the union. is_subtype_of_item = any(is_subtype(left, item, type_parameter_checker, - ignore_pos_arg_names=ignore_pos_arg_names) + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions) for item in right.items) # However, if 'left' is a type variable T, T might also have # an upper bound which is itself a union. This case will be @@ -81,7 +84,8 @@ def is_subtype(left: Type, right: Type, # otherwise, fall through return left.accept(SubtypeVisitor(right, type_parameter_checker, ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance)) + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions)) def is_subtype_ignoring_tvars(left: Type, right: Type) -> bool: @@ -106,11 +110,43 @@ class SubtypeVisitor(TypeVisitor[bool]): def __init__(self, right: Type, type_parameter_checker: TypeParameterChecker, *, ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False) -> None: + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> None: self.right = right self.check_type_parameter = type_parameter_checker self.ignore_pos_arg_names = ignore_pos_arg_names self.ignore_declared_variance = ignore_declared_variance + self.ignore_promotions = ignore_promotions + self._subtype_kind = SubtypeVisitor.build_subtype_kind( + type_parameter_checker=type_parameter_checker, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions) + + @staticmethod + def build_subtype_kind(*, + type_parameter_checker: TypeParameterChecker = check_type_parameter, + ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False, + ignore_promotions: bool = False) -> SubtypeKind: + return hash(('subtype', + type_parameter_checker, + ignore_pos_arg_names, + ignore_declared_variance, + ignore_promotions)) + + def _lookup_cache(self, left: Instance, right: Instance) -> bool: + return TypeState.is_cached_subtype_check(self._subtype_kind, left, right) + + def _record_cache(self, left: Instance, right: Instance) -> None: + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) + + def _is_subtype(self, left: Type, right: Type) -> bool: + return is_subtype(left, right, + type_parameter_checker=self.check_type_parameter, + ignore_pos_arg_names=self.ignore_pos_arg_names, + ignore_declared_variance=self.ignore_declared_variance, + ignore_promotions=self.ignore_promotions) # visit_x(left) means: is left (which is an instance of X) a subtype of # right? @@ -144,19 +180,17 @@ def visit_instance(self, left: Instance) -> bool: return True right = self.right if isinstance(right, TupleType) and right.fallback.type.is_enum: - return is_subtype(left, right.fallback) + return self._is_subtype(left, right.fallback) if isinstance(right, Instance): - if TypeState.is_cached_subtype_check(left, right): + if self._lookup_cache(left, right): return True # NOTE: left.type.mro may be None in quick mode if there # was an error somewhere. - if left.type.mro is not None: + if not self.ignore_promotions and left.type.mro is not None: for base in left.type.mro: # TODO: Also pass recursively ignore_declared_variance - if base._promote and is_subtype( - base._promote, self.right, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names): - TypeState.record_subtype_cache_entry(left, right) + if base._promote and self._is_subtype(base._promote, self.right): + self._record_cache(left, right) return True rname = right.type.fullname() # Always try a nominal check if possible, @@ -169,7 +203,7 @@ def visit_instance(self, left: Instance) -> bool: for lefta, righta, tvar in zip(t.args, right.args, right.type.defn.type_vars)) if nominal: - TypeState.record_subtype_cache_entry(left, right) + self._record_cache(left, right) return nominal if right.type.is_protocol and is_protocol_implementation(left, right): return True @@ -179,7 +213,7 @@ def visit_instance(self, left: Instance) -> bool: if isinstance(item, TupleType): item = item.fallback if is_named_instance(left, 'builtins.type'): - return is_subtype(TypeType(AnyType(TypeOfAny.special_form)), right) + return self._is_subtype(TypeType(AnyType(TypeOfAny.special_form)), right) if left.type.is_metaclass(): if isinstance(item, AnyType): return True @@ -189,7 +223,7 @@ def visit_instance(self, left: Instance) -> bool: # Special case: Instance can be a subtype of Callable. call = find_member('__call__', left, left) if call: - return is_subtype(call, right) + return self._is_subtype(call, right) return False else: return False @@ -198,27 +232,24 @@ def visit_type_var(self, left: TypeVarType) -> bool: right = self.right if isinstance(right, TypeVarType) and left.id == right.id: return True - if left.values and is_subtype(UnionType.make_simplified_union(left.values), right): + if left.values and self._is_subtype(UnionType.make_simplified_union(left.values), right): return True - return is_subtype(left.upper_bound, self.right) + return self._is_subtype(left.upper_bound, self.right) def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): return is_callable_compatible( left, right, - is_compat=is_subtype, + is_compat=self._is_subtype, ignore_pos_arg_names=self.ignore_pos_arg_names) elif isinstance(right, Overloaded): - return all(is_subtype(left, item, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names) - for item in right.items()) + return all(self._is_subtype(left, item) for item in right.items()) elif isinstance(right, Instance): - return is_subtype(left.fallback, right, - ignore_pos_arg_names=self.ignore_pos_arg_names) + return is_subtype(left.fallback, right) elif isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return left.is_type_obj() and is_subtype(left.ret_type, right.item) + return left.is_type_obj() and self._is_subtype(left.ret_type, right.item) else: return False @@ -236,17 +267,17 @@ def visit_tuple_type(self, left: TupleType) -> bool: iter_type = right.args[0] else: iter_type = AnyType(TypeOfAny.special_form) - return all(is_subtype(li, iter_type) for li in left.items) - elif is_subtype(left.fallback, right, self.check_type_parameter): + return all(self._is_subtype(li, iter_type) for li in left.items) + elif self._is_subtype(left.fallback, right): return True return False elif isinstance(right, TupleType): if len(left.items) != len(right.items): return False for l, r in zip(left.items, right.items): - if not is_subtype(l, r, self.check_type_parameter): + if not self._is_subtype(l, r): return False - if not is_subtype(left.fallback, right.fallback, self.check_type_parameter): + if not self._is_subtype(left.fallback, right.fallback): return False return True else: @@ -255,7 +286,7 @@ def visit_tuple_type(self, left: TupleType) -> bool: def visit_typeddict_type(self, left: TypedDictType) -> bool: right = self.right if isinstance(right, Instance): - return is_subtype(left.fallback, right, self.check_type_parameter) + return self._is_subtype(left.fallback, right) elif isinstance(right, TypedDictType): if not left.names_are_wider_than(right): return False @@ -281,11 +312,10 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: def visit_overloaded(self, left: Overloaded) -> bool: right = self.right if isinstance(right, Instance): - return is_subtype(left.fallback, right) + return self._is_subtype(left.fallback, right) elif isinstance(right, CallableType): for item in left.items(): - if is_subtype(item, right, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names): + if self._is_subtype(item, right): return True return False elif isinstance(right, Overloaded): @@ -298,8 +328,7 @@ def visit_overloaded(self, left: Overloaded) -> bool: found_match = False for left_index, left_item in enumerate(left.items()): - subtype_match = is_subtype(left_item, right_item, self.check_type_parameter, - ignore_pos_arg_names=self.ignore_pos_arg_names) + subtype_match = self._is_subtype(left_item, right_item)\ # Order matters: we need to make sure that the index of # this item is at least the index of the previous one. @@ -314,10 +343,10 @@ def visit_overloaded(self, left: Overloaded) -> bool: # If this one overlaps with the supertype in any way, but it wasn't # an exact match, then it's a potential error. if (is_callable_compatible(left_item, right_item, - is_compat=is_subtype, ignore_return=True, + is_compat=self._is_subtype, ignore_return=True, ignore_pos_arg_names=self.ignore_pos_arg_names) or is_callable_compatible(right_item, left_item, - is_compat=is_subtype, ignore_return=True, + is_compat=self._is_subtype, ignore_return=True, ignore_pos_arg_names=self.ignore_pos_arg_names)): # If this is an overload that's already been matched, there's no # problem. @@ -338,13 +367,12 @@ def visit_overloaded(self, left: Overloaded) -> bool: # All the items must have the same type object status, so # it's sufficient to query only (any) one of them. # This is unsound, we don't check all the __init__ signatures. - return left.is_type_obj() and is_subtype(left.items()[0], right) + return left.is_type_obj() and self._is_subtype(left.items()[0], right) else: return False def visit_union_type(self, left: UnionType) -> bool: - return all(is_subtype(item, self.right, self.check_type_parameter) - for item in left.items) + return all(self._is_subtype(item, self.right) for item in left.items) def visit_partial_type(self, left: PartialType) -> bool: # This is indeterminate as we don't really know the complete type yet. @@ -353,10 +381,10 @@ def visit_partial_type(self, left: PartialType) -> bool: def visit_type_type(self, left: TypeType) -> bool: right = self.right if isinstance(right, TypeType): - return is_subtype(left.item, right.item) + return self._is_subtype(left.item, right.item) if isinstance(right, CallableType): # This is unsound, we don't check the __init__ signature. - return is_subtype(left.item, right.ret_type) + return self._is_subtype(left.item, right.ret_type) if isinstance(right, Instance): if right.type.fullname() in ['builtins.object', 'builtins.type']: return True @@ -365,7 +393,7 @@ def visit_type_type(self, left: TypeType) -> bool: item = item.upper_bound if isinstance(item, Instance): metaclass = item.type.metaclass_type - return metaclass is not None and is_subtype(metaclass, right) + return metaclass is not None and self._is_subtype(metaclass, right) return False @@ -420,6 +448,8 @@ def f(self) -> A: ... return False if not proper_subtype: # Nominal check currently ignores arg names + # NOTE: If we ever change this, be sure to also change the call to + # SubtypeVisitor.build_subtype_kind(...) down below. is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True) else: is_compat = is_proper_subtype(subtype, supertype) @@ -441,10 +471,13 @@ def f(self) -> A: ... # This rule is copied from nominal check in checker.py if IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags: return False - if proper_subtype: - TypeState.record_proper_subtype_cache_entry(left, right) + + if not proper_subtype: + # Nominal check currently ignores arg names + subtype_kind = SubtypeVisitor.build_subtype_kind(ignore_pos_arg_names=True) else: - TypeState.record_subtype_cache_entry(left, right) + subtype_kind = ProperSubtypeVisitor.build_subtype_kind() + TypeState.record_subtype_cache_entry(subtype_kind, left, right) return True @@ -959,21 +992,38 @@ def restrict_subtype_away(t: Type, s: Type) -> Type: return t -def is_proper_subtype(left: Type, right: Type) -> bool: +def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: """Is left a proper subtype of right? For proper subtypes, there's no need to rely on compatibility due to Any types. Every usable type is a proper subtype of itself. """ if isinstance(right, UnionType) and not isinstance(left, UnionType): - return any([is_proper_subtype(left, item) + return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions) for item in right.items]) - return left.accept(ProperSubtypeVisitor(right)) + return left.accept(ProperSubtypeVisitor(right, ignore_promotions=ignore_promotions)) class ProperSubtypeVisitor(TypeVisitor[bool]): - def __init__(self, right: Type) -> None: + def __init__(self, right: Type, *, ignore_promotions: bool = False) -> None: self.right = right + self.ignore_promotions = ignore_promotions + self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind( + ignore_promotions=ignore_promotions, + ) + + @staticmethod + def build_subtype_kind(*, ignore_promotions: bool = False) -> SubtypeKind: + return hash(('subtype_proper', ignore_promotions)) + + def _lookup_cache(self, left: Instance, right: Instance) -> bool: + return TypeState.is_cached_subtype_check(self._subtype_kind, left, right) + + def _record_cache(self, left: Instance, right: Instance) -> None: + TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) + + def _is_proper_subtype(self, left: Type, right: Type) -> bool: + return is_proper_subtype(left, right, ignore_promotions=self.ignore_promotions) def visit_unbound_type(self, left: UnboundType) -> bool: # This can be called if there is a bad type annotation. The result probably @@ -1004,19 +1054,20 @@ def visit_deleted_type(self, left: DeletedType) -> bool: def visit_instance(self, left: Instance) -> bool: right = self.right if isinstance(right, Instance): - if TypeState.is_cached_proper_subtype_check(left, right): + if self._lookup_cache(left, right): return True - for base in left.type.mro: - if base._promote and is_proper_subtype(base._promote, right): - TypeState.record_proper_subtype_cache_entry(left, right) - return True + if not self.ignore_promotions: + for base in left.type.mro: + if base._promote and self._is_proper_subtype(base._promote, right): + self._record_cache(left, right) + return True if left.type.has_base(right.type.fullname()): def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: if variance == COVARIANT: - return is_proper_subtype(leftarg, rightarg) + return self._is_proper_subtype(leftarg, rightarg) elif variance == CONTRAVARIANT: - return is_proper_subtype(rightarg, leftarg) + return self._is_proper_subtype(rightarg, leftarg) else: return sametypes.is_same_type(leftarg, rightarg) # Map left type to corresponding right instances. @@ -1025,7 +1076,7 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in zip(left.args, right.args, right.type.defn.type_vars)) if nominal: - TypeState.record_proper_subtype_cache_entry(left, right) + self._record_cache(left, right) return nominal if (right.type.is_protocol and is_protocol_implementation(left, right, proper_subtype=True)): @@ -1034,29 +1085,30 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool: if isinstance(right, CallableType): call = find_member('__call__', left, left) if call: - return is_proper_subtype(call, right) + return self._is_proper_subtype(call, right) return False return False def visit_type_var(self, left: TypeVarType) -> bool: if isinstance(self.right, TypeVarType) and left.id == self.right.id: return True - if left.values and is_subtype(UnionType.make_simplified_union(left.values), self.right): + if left.values and is_subtype(UnionType.make_simplified_union(left.values), self.right, + ignore_promotions=self.ignore_promotions): return True - return is_proper_subtype(left.upper_bound, self.right) + return self._is_proper_subtype(left.upper_bound, self.right) def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): - return is_callable_compatible(left, right, is_compat=is_proper_subtype) + return is_callable_compatible(left, right, is_compat=self._is_proper_subtype) elif isinstance(right, Overloaded): - return all(is_proper_subtype(left, item) + return all(self._is_proper_subtype(left, item) for item in right.items()) elif isinstance(right, Instance): - return is_proper_subtype(left.fallback, right) + return self._is_proper_subtype(left.fallback, right) elif isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return left.is_type_obj() and is_proper_subtype(left.ret_type, right.item) + return left.is_type_obj() and self._is_proper_subtype(left.ret_type, right.item) return False def visit_tuple_type(self, left: TupleType) -> bool: @@ -1074,15 +1126,15 @@ def visit_tuple_type(self, left: TupleType) -> bool: # TODO: We shouldn't need this special case. This is currently needed # for isinstance(x, tuple), though it's unclear why. return True - return all(is_proper_subtype(li, iter_type) for li in left.items) - return is_proper_subtype(left.fallback, right) + return all(self._is_proper_subtype(li, iter_type) for li in left.items) + return self._is_proper_subtype(left.fallback, right) elif isinstance(right, TupleType): if len(left.items) != len(right.items): return False for l, r in zip(left.items, right.items): - if not is_proper_subtype(l, r): + if not self._is_proper_subtype(l, r): return False - return is_proper_subtype(left.fallback, right.fallback) + return self._is_proper_subtype(left.fallback, right.fallback) return False def visit_typeddict_type(self, left: TypedDictType) -> bool: @@ -1095,14 +1147,14 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: if name not in left.items: return False return True - return is_proper_subtype(left.fallback, right) + return self._is_proper_subtype(left.fallback, right) def visit_overloaded(self, left: Overloaded) -> bool: # TODO: What's the right thing to do here? return False def visit_union_type(self, left: UnionType) -> bool: - return all([is_proper_subtype(item, self.right) for item in left.items]) + return all([self._is_proper_subtype(item, self.right) for item in left.items]) def visit_partial_type(self, left: PartialType) -> bool: # TODO: What's the right thing to do here? @@ -1113,10 +1165,10 @@ def visit_type_type(self, left: TypeType) -> bool: right = self.right if isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. - return is_proper_subtype(left.item, right.item) + return self._is_proper_subtype(left.item, right.item) if isinstance(right, CallableType): # This is also unsound because of __init__. - return right.is_type_obj() and is_proper_subtype(left.item, right.ret_type) + return right.is_type_obj() and self._is_proper_subtype(left.item, right.ret_type) if isinstance(right, Instance): if right.type.fullname() == 'builtins.type': # TODO: Strictly speaking, the type builtins.type is considered equivalent to @@ -1129,7 +1181,7 @@ def visit_type_type(self, left: TypeType) -> bool: return False -def is_more_precise(left: Type, right: Type) -> bool: +def is_more_precise(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool: """Check if left is a more precise type than right. A left is a proper subtype of right, left is also more precise than @@ -1139,4 +1191,4 @@ def is_more_precise(left: Type, right: Type) -> bool: # TODO Should List[int] be more precise than List[Any]? if isinstance(right, AnyType): return True - return is_proper_subtype(left, right) + return is_proper_subtype(left, right, ignore_promotions=ignore_promotions) diff --git a/mypy/typestate.py b/mypy/typestate.py index 337aac21d714..27a3c98a1ccc 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -3,7 +3,7 @@ and potentially other mutable TypeInfo state. This module contains mutable global state. """ -from typing import Dict, Set, Tuple, Optional +from typing import Any, Dict, Set, Tuple, Optional MYPY = False if MYPY: @@ -12,6 +12,17 @@ from mypy.types import Instance from mypy.server.trigger import make_trigger +# Represents that the 'left' instance is a subtype of the 'right' instance +SubtypeRelationship = Tuple[Instance, Instance] + +# A hash encoding the specific conditions under which we performed the subtype check. +# (e.g. did we want a proper subtype? A regular subtype while ignoring variance?) +SubtypeKind = int + +# A cache that keeps track of whether the given TypeInfo is a part of a particular +# subtype relationship +SubtypeCache = Dict[TypeInfo, Dict[SubtypeKind, Set[SubtypeRelationship]]] + class TypeState: """This class provides subtype caching to improve performance of subtype checks. @@ -23,13 +34,11 @@ class TypeState: The protocol dependencies however are only stored here, and shouldn't be deleted unless not needed any more (e.g. during daemon shutdown). """ - # 'caches' and 'caches_proper' are subtype caches, implemented as sets of pairs - # of (subtype, supertype), where supertypes are instances of given TypeInfo. + # '_subtype_caches' keeps track of (subtype, supertype) pairs where supertypes are + # instances of the given TypeInfo. The cache also keeps track of the specific + # *kind* of subtyping relationship, which we represent as an arbitrary hashable tuple. # We need the caches, since subtype checks for structural types are very slow. - # _subtype_caches_proper is for caching proper subtype checks (i.e. not assuming that - # Any is consistent with every type). - _subtype_caches = {} # type: ClassVar[Dict[TypeInfo, Set[Tuple[Instance, Instance]]]] - _subtype_caches_proper = {} # type: ClassVar[Dict[TypeInfo, Set[Tuple[Instance, Instance]]]] + _subtype_caches = {} # type: ClassVar[SubtypeCache] # This contains protocol dependencies generated after running a full build, # or after an update. These dependencies are special because: @@ -70,13 +79,11 @@ class TypeState: def reset_all_subtype_caches(cls) -> None: """Completely reset all known subtype caches.""" cls._subtype_caches = {} - cls._subtype_caches_proper = {} @classmethod def reset_subtype_caches_for(cls, info: TypeInfo) -> None: """Reset subtype caches (if any) for a given supertype TypeInfo.""" - cls._subtype_caches.setdefault(info, set()).clear() - cls._subtype_caches_proper.setdefault(info, set()).clear() + cls._subtype_caches.setdefault(info, dict()).clear() @classmethod def reset_all_subtype_caches_for(cls, info: TypeInfo) -> None: @@ -85,20 +92,15 @@ def reset_all_subtype_caches_for(cls, info: TypeInfo) -> None: cls.reset_subtype_caches_for(item) @classmethod - def is_cached_subtype_check(cls, left: Instance, right: Instance) -> bool: - return (left, right) in cls._subtype_caches.setdefault(right.type, set()) - - @classmethod - def is_cached_proper_subtype_check(cls, left: Instance, right: Instance) -> bool: - return (left, right) in cls._subtype_caches_proper.setdefault(right.type, set()) - - @classmethod - def record_subtype_cache_entry(cls, left: Instance, right: Instance) -> None: - cls._subtype_caches.setdefault(right.type, set()).add((left, right)) + def is_cached_subtype_check(cls, kind: SubtypeKind, left: Instance, right: Instance) -> bool: + subtype_kinds = cls._subtype_caches.setdefault(right.type, dict()) + return (left, right) in subtype_kinds.setdefault(kind, set()) @classmethod - def record_proper_subtype_cache_entry(cls, left: Instance, right: Instance) -> None: - cls._subtype_caches_proper.setdefault(right.type, set()).add((left, right)) + def record_subtype_cache_entry(cls, kind: SubtypeKind, + left: Instance, right: Instance) -> None: + subtype_kinds = cls._subtype_caches.setdefault(right.type, dict()) + subtype_kinds.setdefault(kind, set()).add((left, right)) @classmethod def reset_protocol_deps(cls) -> None: diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index 6d76dcdd4605..a78fa8d52d58 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -2005,3 +2005,53 @@ def f(x: Union[A, str]) -> None: if isinstance(x, A): x.method_only_in_a() [builtins fixtures/isinstance.pyi] + +[case testIsInstanceInitialNoneCheckSkipsImpossibleCasesNoStrictOptional] +# flags: --strict-optional +from typing import Optional, Union + +class A: pass + +def foo1(x: Union[A, str, None]) -> None: + if x is None: + reveal_type(x) # E: Revealed type is 'None' + elif isinstance(x, A): + reveal_type(x) # E: Revealed type is '__main__.A' + else: + reveal_type(x) # E: Revealed type is 'builtins.str' + +def foo2(x: Optional[str]) -> None: + if x is None: + reveal_type(x) # E: Revealed type is 'None' + elif isinstance(x, A): + reveal_type(x) + else: + reveal_type(x) # E: Revealed type is 'builtins.str' +[builtins fixtures/isinstance.pyi] + +[case testIsInstanceInitialNoneCheckSkipsImpossibleCasesInNoStrictOptional] +# flags: --no-strict-optional +from typing import Optional, Union + +class A: pass + +def foo1(x: Union[A, str, None]) -> None: + if x is None: + # Since None is a subtype of all types in no-strict-optional, + # we can't really narrow the type here + reveal_type(x) # E: Revealed type is 'Union[__main__.A, builtins.str, None]' + elif isinstance(x, A): + # Note that Union[None, A] == A in no-strict-optional + reveal_type(x) # E: Revealed type is '__main__.A' + else: + reveal_type(x) # E: Revealed type is 'builtins.str' + +def foo2(x: Optional[str]) -> None: + if x is None: + reveal_type(x) # E: Revealed type is 'Union[builtins.str, None]' + elif isinstance(x, A): + # Mypy should, however, be able to skip impossible cases + reveal_type(x) + else: + reveal_type(x) # E: Revealed type is 'Union[builtins.str, None]' +[builtins fixtures/isinstance.pyi] diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 4bcbaf626aa8..e0d8862625e3 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -1482,14 +1482,25 @@ reveal_type(f(z='', x=a, y=1)) # E: Revealed type is 'Any' [case testOverloadWithOverlappingItemsAndAnyArgument5] from typing import overload, Any, Union +class A: pass +class B(A): pass + @overload -def f(x: int) -> int: ... +def f(x: B) -> B: ... @overload -def f(x: Union[int, float]) -> float: ... +def f(x: Union[A, B]) -> A: ... def f(x): pass +# Note: overloads ignore promotions so we treat 'int' and 'float' as distinct types +@overload +def g(x: int) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def g(x: Union[int, float]) -> float: ... +def g(x): pass + a: Any reveal_type(f(a)) # E: Revealed type is 'Any' +reveal_type(g(a)) # E: Revealed type is 'Any' [case testOverloadWithOverlappingItemsAndAnyArgument6] from typing import overload, Any @@ -4484,3 +4495,41 @@ def g(x: str) -> int: ... [builtins fixtures/list.pyi] [typing fixtures/typing-full.pyi] [out] + +[case testOverloadsIgnorePromotions] +from typing import overload, List, Union, _promote + +class Parent: pass +class Child(Parent): pass + +children: List[Child] +parents: List[Parent] + +@overload +def f(x: Child) -> List[Child]: pass # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def f(x: Parent) -> List[Parent]: pass +def f(x: Union[Child, Parent]) -> Union[List[Child], List[Parent]]: + if isinstance(x, Child): + reveal_type(x) # E: Revealed type is '__main__.Child' + return children + else: + reveal_type(x) # E: Revealed type is '__main__.Parent' + return parents + +ints: List[int] +floats: List[float] + +@overload +def g(x: int) -> List[int]: pass +@overload +def g(x: float) -> List[float]: pass +def g(x: Union[int, float]) -> Union[List[int], List[float]]: + if isinstance(x, int): + reveal_type(x) # E: Revealed type is 'builtins.int' + return ints + else: + reveal_type(x) # E: Revealed type is 'builtins.float' + return floats + +[builtins fixtures/isinstancelist.pyi] diff --git a/test-data/unit/fixtures/isinstancelist.pyi b/test-data/unit/fixtures/isinstancelist.pyi index 99aca1befe39..1831411319ef 100644 --- a/test-data/unit/fixtures/isinstancelist.pyi +++ b/test-data/unit/fixtures/isinstancelist.pyi @@ -14,6 +14,7 @@ def issubclass(x: object, t: Union[type, Tuple]) -> bool: pass class int: def __add__(self, x: int) -> int: pass +class float: pass class bool(int): pass class str: def __add__(self, x: str) -> str: pass From 18305f2c247f7d40ebc74e9683219e28d16efef0 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Fri, 20 Jul 2018 11:05:02 -0700 Subject: [PATCH 10/21] WIP: Make checks less sensitive to 'Any' weirdness --- mypy/meet.py | 6 ++-- test-data/unit/check-unreachable-code.test | 40 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/mypy/meet.py b/mypy/meet.py index d81a0f98133d..113529630a19 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -8,7 +8,7 @@ UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, ) from mypy.subtypes import ( - is_equivalent, is_subtype, is_protocol_implementation, is_callable_compatible, + is_equivalent, is_subtype, is_protocol_implementation, is_callable_compatible, is_proper_subtype ) from mypy.erasetype import erase_type from mypy.maptype import map_instance_to_supertype @@ -120,8 +120,8 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: # If this check fails, we start checking to see if there exists a # *partial* overlap between types. - if (is_subtype(left, right, ignore_promotions=ignore_promotions) - or is_subtype(right, left, ignore_promotions=ignore_promotions)): + if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions) + or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)): return True # See the docstring for 'get_possible_variants' for more info on what the diff --git a/test-data/unit/check-unreachable-code.test b/test-data/unit/check-unreachable-code.test index b86154302f8c..988038264d54 100644 --- a/test-data/unit/check-unreachable-code.test +++ b/test-data/unit/check-unreachable-code.test @@ -564,3 +564,43 @@ if typing.TYPE_CHECKING: reveal_type(x) # E: Revealed type is '__main__.B' [builtins fixtures/isinstancelist.pyi] + +[case testUnreachableWhenSuperclassIsAny] +# flags: --strict-optional +from typing import Any + +# This can happen if we're importing a class from a missing module +Parent: Any +class Child(Parent): + def foo(self) -> int: + reveal_type(self) # E: Revealed type is '__main__.Child' + if self is None: + reveal_type(self) + return None + reveal_type(self) # E: Revealed type is '__main__.Child' + return 3 + + def bar(self) -> int: + self = super(Child, self).something() + reveal_type(self) # E: Revealed type is '__main__.Child' + if self is None: + reveal_type(self) + return None + reveal_type(self) # E: Revealed type is '__main__.Child' + return 3 +[builtins fixtures/isinstance.pyi] + +[case testUnreachableWhenSuperclassIsAnyNoStrictOptional] +# flags: --no-strict-optional +from typing import Any + +Parent: Any +class Child(Parent): + def foo(self) -> int: + reveal_type(self) # E: Revealed type is '__main__.Child' + if self is None: + reveal_type(self) # E: Revealed type is '__main__.Child' + return None + reveal_type(self) # E: Revealed type is '__main__.Child' + return 3 +[builtins fixtures/isinstance.pyi] From bb6fbc5a5653413ba8f843263edee08a7e5f8704 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Fri, 20 Jul 2018 13:31:31 -0700 Subject: [PATCH 11/21] WIP: Refine and clean up overlapping meets logic --- mypy/meet.py | 49 ++++++++++++++++++++----------------------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/mypy/meet.py b/mypy/meet.py index 113529630a19..7041185de89e 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -8,7 +8,8 @@ UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, ) from mypy.subtypes import ( - is_equivalent, is_subtype, is_protocol_implementation, is_callable_compatible, is_proper_subtype + is_equivalent, is_subtype, is_protocol_implementation, is_callable_compatible, + is_proper_subtype, ) from mypy.erasetype import erase_type from mypy.maptype import map_instance_to_supertype @@ -119,6 +120,8 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: # We check for complete overlaps first as a general-purpose failsafe. # If this check fails, we start checking to see if there exists a # *partial* overlap between types. + # + # These checks will also handle the NoneTyp and UninhabitedType cases for us. if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions) or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)): @@ -126,21 +129,22 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: # See the docstring for 'get_possible_variants' for more info on what the # following lines are doing. - # - # Note that we use 'left_possible' and 'right_possible' in two different - # locations: immediately after to handle TypeVars, and near the end of - # 'is_overlapping_types' to handle types like Unions or Overloads. left_possible = get_possible_variants(left) right_possible = get_possible_variants(right) - # We start by checking TypeVars first: this is because in some of the checks - # below, it's convenient to just return early in certain cases. + # We start by checking multi-variant types like Unions first. We also perform + # the same logic if either type happens to be a TypeVar. # - # If we were to defer checking TypeVars to down below, that would end up - # causing issues since the TypeVars would never have the opportunity to - # try binding to the relevant types. - if isinstance(left, TypeVarType) or isinstance(right, TypeVarType): + # Handling the TypeVars now lets us simulate having them bind to the corresponding + # type -- if we deferred these checks, the "return-early" logic of the other + # checks will prevent us from detecting certain overlaps. + # + # If both types are singleton variants (and are not TypeVars), we've hit the base case: + # we skip these checks to avoid infinitely recursing. + + if (len(left_possible) > 1 or len(right_possible) > 1 + or isinstance(left, TypeVarType) or isinstance(right, TypeVarType)): for l in left_possible: for r in right_possible: if _is_overlapping_types(l, r): @@ -151,6 +155,7 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: # if one one of the types is None and we're running in strict-optional # mode. (We must perform this check after the TypeVar checks because # a TypeVar could be bound to None, for example.) + if experiments.STRICT_OPTIONAL: if isinstance(left, NoneTyp) != isinstance(right, NoneTyp): return False @@ -179,12 +184,12 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: # Next, we handle single-variant types that cannot be inherently partially overlapping, # but do require custom logic to inspect. + # + # As before, we degrade into 'Instance' whenever possible. if isinstance(left, TypeType) and isinstance(right, TypeType): - return _is_overlapping_types(left.item, right.item) - elif isinstance(left, TypeType) or isinstance(right, TypeType): # TODO: Can Callable[[...], T] and Type[T] be partially overlapping? - return False + return _is_overlapping_types(left.item, right.item) if isinstance(left, CallableType) and isinstance(right, CallableType): return is_callable_compatible(left, right, @@ -196,7 +201,7 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: elif isinstance(right, CallableType): right = right.fallback - # Next, we check if left and right are instances + # Finally, we handle the case where left and right are instances. if isinstance(left, Instance) and isinstance(right, Instance): if left.type.is_protocol and is_protocol_implementation(right, left): @@ -220,20 +225,6 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: if _is_overlapping_types(left_arg, right_arg): return True - return False - - # We handle all remaining types here: in particular, types like - # UnionType, Overloaded, NoneTyp, and UninhabitedType. - # - # We deliberately skip comparing two singleton variant types to avoid - # infinitely recursing. - - if len(left_possible) >= 1 or len(right_possible) >= 1: - for a in left_possible: - for b in right_possible: - if _is_overlapping_types(a, b): - return True - # We ought to have handled every case by now: we conclude the # two types are not overlapping, either completely or partially. From 55a01cf1f38397fe6173196f2ebc23623c57c2ee Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Fri, 20 Jul 2018 13:40:59 -0700 Subject: [PATCH 12/21] Add unit tests to cover a failure fixed by previous commit --- test-data/unit/check-overloading.test | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index e0d8862625e3..5184ecfad049 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -4533,3 +4533,16 @@ def g(x: Union[int, float]) -> Union[List[int], List[float]]: return floats [builtins fixtures/isinstancelist.pyi] + +[case testOverloadsTypesAndUnions] +from typing import overload, Type, Union + +class A: pass +class B: pass + +@overload +def f(x: Type[A]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def f(x: Union[Type[A], Type[B]]) -> str: ... +def f(x: Union[Type[A], Type[B]]) -> Union[int, str]: + return 1 From 7e2eb9b9fbf526d14a6eb4eddf27f166cca64a65 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 8 Aug 2018 10:34:23 -0700 Subject: [PATCH 13/21] WIP: refactor operator code --- mypy/checker.py | 119 ++++++++++++++----------- test-data/unit/check-classes.test | 78 +++++++++++++++- test-data/unit/check-expressions.test | 2 +- test-data/unit/fixtures/isinstance.pyi | 4 +- test-data/unit/typexport-basic.test | 2 + 5 files changed, 149 insertions(+), 56 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index e497d6aeb09b..8d500e45f6f7 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1041,72 +1041,91 @@ def check_overlapping_op_methods(self, """Check for overlapping method and reverse method signatures. Assume reverse method has valid argument count and kinds. + + Precondition: + If the reverse operator method accepts some argument of type + X, the forward operator method must belong to class X. + + For example, if we have the reverse operator `A.__radd__(B)`, then the + corresponding forward operator must have the type `B.__add__(...)`. """ - # Reverse operator method that overlaps unsafely with the - # forward operator method can result in type unsafety. This is - # similar to overlapping overload variants. + # Note: Suppose we have two operator methods "A.__rOP__(B) -> R1" and + # "B.__OP__(C) -> R2". We check if these two methods are unsafely overlapping + # by using the following algorithm: + # + # 1. Rewrite "B.__OP__(C) -> R1" to "temp1(B, C) -> R1" + # + # 2. Rewrite "A.__rOP__(B) -> R2" to "temp2(B, A) -> R2" # - # This example illustrates the issue: + # 3. Treat temp1 and temp2 as if they were both variants in the same + # overloaded function. (This mirrors how the Python runtime calls + # operator methods: we first try __OP__, then __rOP__.) # - # class X: pass - # class A: - # def __add__(self, x: X) -> int: - # if isinstance(x, X): - # return 1 - # return NotImplemented - # class B: - # def __radd__(self, x: A) -> str: return 'x' - # class C(X, B): pass - # def f(b: B) -> None: - # A() + b # Result is 1, even though static type seems to be str! - # f(C()) + # If the first signature is unsafely overlapping with the second, + # report an error. # - # The reason for the problem is that B and X are overlapping - # types, and the return types are different. Also, if the type - # of x in __radd__ would not be A, the methods could be - # non-overlapping. + # 4. However, if temp1 shadows temp2 (e.g. the __rOP__ method can never + # be called), do NOT report an error. + # + # This behavior deviates from how we handle overloads -- many of the + # modules in typeshed seem to define __OP__ methods that shadow the + # corresponding __rOP__ method. + # + # Note: we do not attempt to handle unsafe overlaps related to multiple + # inheritance. for forward_item in union_items(forward_type): if isinstance(forward_item, CallableType): - # TODO check argument kinds - if len(forward_item.arg_types) < 1: - # Not a valid operator method -- can't succeed anyway. - return - - # Construct normalized function signatures corresponding to the - # operator methods. The first argument is the left operand and the - # second operand is the right argument -- we switch the order of - # the arguments of the reverse method. - forward_tweaked = CallableType( - [forward_base, forward_item.arg_types[0]], - [nodes.ARG_POS] * 2, - [None] * 2, - forward_item.ret_type, - forward_item.fallback, - name=forward_item.name) - reverse_args = reverse_type.arg_types - reverse_tweaked = CallableType( - [reverse_args[1], reverse_args[0]], - [nodes.ARG_POS] * 2, - [None] * 2, - reverse_type.ret_type, - fallback=self.named_type('builtins.function'), - name=reverse_type.name) - - if is_unsafe_overlapping_operator_signatures( - forward_tweaked, reverse_tweaked): + if self.is_unsafe_overlapping_op(forward_item, forward_base, reverse_type): self.msg.operator_method_signatures_overlap( reverse_class, reverse_name, forward_base, forward_name, context) elif isinstance(forward_item, Overloaded): for item in forward_item.items(): - self.check_overlapping_op_methods( - reverse_type, reverse_name, reverse_class, - item, forward_name, forward_base, context) + if not self.is_unsafe_overlapping_op(item, forward_base, reverse_type): + return + self.msg.operator_method_signatures_overlap( + reverse_class, reverse_name, + forward_base, forward_name, + context) elif not isinstance(forward_item, AnyType): self.msg.forward_operator_not_callable(forward_name, context) + def is_unsafe_overlapping_op(self, + forward_item: CallableType, + forward_base: Type, + reverse_type: CallableType) -> bool: + # TODO check argument kinds + if len(forward_item.arg_types) < 1: + # Not a valid operator method -- can't succeed anyway. + return False + + # Erase the type if necessary to make sure we don't have a dangling + # TypeVar in forward_tweaked + forward_base_erased = forward_base + if isinstance(forward_base, TypeVarType): + forward_base_erased = erase_to_bound(forward_base) + + # Construct normalized function signatures corresponding to the + # operator methods. The first argument is the left operand and the + # second operand is the right argument -- we switch the order of + # the arguments of the reverse method. + + forward_tweaked = forward_item.copy_modified( + arg_types=[forward_base_erased, forward_item.arg_types[0]], + arg_kinds=[nodes.ARG_POS] * 2, + arg_names=[None] * 2, + ) + reverse_tweaked = reverse_type.copy_modified( + arg_types=[reverse_type.arg_types[1], reverse_type.arg_types[0]], + arg_kinds=[nodes.ARG_POS] * 2, + arg_names=[None] * 2, + ) + + return is_unsafe_partially_overlapping_overload_signatures( + forward_tweaked, reverse_tweaked) + def check_inplace_operator_method(self, defn: FuncBase) -> None: """Check an inplace operator method such as __iadd__. diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 8dfe4717a479..84f08b3a472f 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -1599,6 +1599,45 @@ class A: class B(A): def __add__(self, x): pass +[case testOperatorMethodAgainstSameType] +class A: + def __add__(self, x: int) -> 'A': + if isinstance(x, int): + return A() + else: + return NotImplemented + + def __radd__(self, x: 'A') -> 'A': + if isinstance(x, A): + return A() + else: + return NotImplemented + +class B(A): pass + +# Note: This is actually a runtime error. If we run x.__add__(y) +# where x and y are *not* the same type, Python will not try +# calling __radd__. +# +# It is, however, difficult to detect this statically -- for example, +# if we have a variable of type A, that doesn't mean that variable is +# *exactly* an instance of A. It could be a subclass, for example. +# +# We also can't really print a warning when checking the definition +# of A.__add__ and A.__radd__ -- the user might intentionally be defining +# those two methods to have non-conventional semantics. +# +# As a result, we don't try and handle this case and pretend that +# Python actually *will* call `__radd__` even when the types are +# the same. +reveal_type(A() + A()) # E: Revealed type is '__main__.A' + +# Here, Python *will* call __radd__(...) so the revealed type +# and the runtime behavior match. +reveal_type(B() + A()) # E: Revealed type is '__main__.A' +reveal_type(A() + B()) # E: Revealed type is '__main__.A' +[builtins fixtures/isinstance.pyi] + [case testOperatorMethodOverrideWithIdenticalOverloadedType] from foo import * [file foo.pyi] @@ -1755,20 +1794,51 @@ class B: def __radd__(*self) -> int: pass def __rsub__(*self: 'B') -> int: pass -[case testReverseOperatorTypeVar] +[case testReverseOperatorTypeVar1] +from typing import TypeVar, Any +T = TypeVar("T", bound='Real') +class Real: + def __add__(self, other: Any) -> str: ... +class Fraction(Real): + def __radd__(self, other: T) -> T: ... # E: Signatures "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping + +reveal_type(Real() + Fraction()) # E: Revealed type is 'builtins.str' +reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' + +[case testReverseOperatorTypeVar2] from typing import TypeVar T = TypeVar("T", bound='Real') class Real: - def __add__(self, other) -> str: ... + def __add__(self, other: Fraction) -> str: ... class Fraction(Real): def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping +reveal_type(Real() + Fraction()) # E: Revealed type is 'builtins.str' +reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' + +[case testReverseOperatorTypeVar3] +from typing import TypeVar, Any +T = TypeVar("T", bound='Real') +class Real: + def __add__(self, other: FractionChild) -> str: ... +class Fraction(Real): + def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping +class FractionChild(Fraction): pass + +reveal_type(Real() + Fraction()) # E: Revealed type is '__main__.Real*' +reveal_type(FractionChild() + Fraction()) # E: Revealed type is '__main__.FractionChild*' +reveal_type(FractionChild() + FractionChild()) # E: Revealed type is 'builtins.str' + +# Technically not correct +reveal_type(Fraction() + Fraction()) # E: Revealed type is '__main__.Fraction*' + [case testReverseOperatorTypeType] from typing import TypeVar, Type class Real(type): - def __add__(self, other) -> str: ... + def __add__(self, other: FractionChild) -> str: ... class Fraction(Real): def __radd__(self, other: Type['A']) -> Real: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "Type[A]" are unsafely overlapping +class FractionChild(Fraction): pass class A(metaclass=Real): pass @@ -1811,7 +1881,7 @@ class B: @overload def __radd__(self, x: A) -> str: pass # Error class X: - def __add__(self, x): pass + def __add__(self, x: B) -> int: pass [out] tmp/foo.pyi:6: error: Signatures of "__radd__" of "B" and "__add__" of "X" are unsafely overlapping diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 93023dbb3ac4..39a2b11e3d35 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -664,7 +664,7 @@ A() + cast(Any, 1) class C: def __gt__(self, x: 'A') -> object: pass class A: - def __lt__(self, x: C) -> int: pass + def __lt__(self, x: C) -> int: pass # E: Signatures of "__lt__" of "A" and "__gt__" of "C" are unsafely overlapping class B: def __gt__(self, x: A) -> str: pass s = None # type: str diff --git a/test-data/unit/fixtures/isinstance.pyi b/test-data/unit/fixtures/isinstance.pyi index ded946ce73fe..35535b9a588f 100644 --- a/test-data/unit/fixtures/isinstance.pyi +++ b/test-data/unit/fixtures/isinstance.pyi @@ -1,4 +1,4 @@ -from typing import Tuple, TypeVar, Generic, Union +from typing import Tuple, TypeVar, Generic, Union, cast, Any T = TypeVar('T') @@ -22,3 +22,5 @@ class bool(int): pass class str: def __add__(self, other: 'str') -> 'str': pass class ellipsis: pass + +NotImplemented = cast(Any, None) diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index 263be9837616..890f3d7e2ed0 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -147,6 +147,8 @@ class type: pass class function: pass class str: pass [out] +tmp/builtins.py:5: error: Signatures of "__lt__" of "int" and "__gt__" of "int" are unsafely overlapping +tmp/builtins.py:6: error: Signatures of "__gt__" of "int" and "__lt__" of "int" are unsafely overlapping ComparisonExpr(3) : builtins.bool ComparisonExpr(4) : builtins.bool ComparisonExpr(5) : builtins.bool From f098191866aaad0242d7c47f48edce95f711fa7e Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 8 Aug 2018 15:18:32 -0700 Subject: [PATCH 14/21] Fix typo in test --- test-data/unit/check-classes.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 84f08b3a472f..8c75e1396cae 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -1800,7 +1800,7 @@ T = TypeVar("T", bound='Real') class Real: def __add__(self, other: Any) -> str: ... class Fraction(Real): - def __radd__(self, other: T) -> T: ... # E: Signatures "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping + def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping reveal_type(Real() + Fraction()) # E: Revealed type is 'builtins.str' reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' From 539ed6aa4fdf43583cc619d6bbac832e2d25b9dc Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Sat, 11 Aug 2018 09:49:27 -0700 Subject: [PATCH 15/21] WIP: Try adding more precise support for reversable operators --- mypy/checker.py | 33 ++++++--- mypy/checkexpr.py | 115 +++++++++++++++++++++++++++++- mypy/messages.py | 9 +++ mypy/nodes.py | 9 +++ test-data/unit/check-classes.test | 22 +++++- 5 files changed, 176 insertions(+), 12 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index d23903b63512..58899f2af839 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1022,6 +1022,12 @@ def check_reverse_op_method(self, defn: FuncItem, opt_meta = item.type.metaclass_type if opt_meta is not None: forward_inst = opt_meta + #if (isinstance(forward_inst, (Instance, UnionType)) + # and not forward_inst.has_readable_member(forward_name)): + # self.msg + if isinstance(forward_inst, Instance) and forward_inst.type.fullname() == defn.info.fullname() and not forward_inst.has_readable_member(forward_name): + self.msg.object_with_reverse_operator_missing_forward_operator(forward_name, forward_inst.type, reverse_name, defn) + return if not (isinstance(forward_inst, (Instance, UnionType)) and forward_inst.has_readable_member(forward_name)): return @@ -1085,12 +1091,11 @@ def check_overlapping_op_methods(self, forward_base, forward_name, context) elif isinstance(forward_item, Overloaded): for item in forward_item.items(): - if not self.is_unsafe_overlapping_op(item, forward_base, reverse_type): - return - self.msg.operator_method_signatures_overlap( - reverse_class, reverse_name, - forward_base, forward_name, - context) + if self.is_unsafe_overlapping_op(item, forward_base, reverse_type): + self.msg.operator_method_signatures_overlap( + reverse_class, reverse_name, + forward_base, forward_name, + context) elif not isinstance(forward_item, AnyType): self.msg.forward_operator_not_callable(forward_name, context) @@ -1125,8 +1130,20 @@ def is_unsafe_overlapping_op(self, arg_names=[None] * 2, ) - return is_unsafe_partially_overlapping_overload_signatures( - forward_tweaked, reverse_tweaked) + reverse_base_erased = reverse_type.arg_types[1] + if isinstance(reverse_base_erased, TypeVarType): + reverse_base_erased = erase_to_bound(reverse_base_erased) + + if is_same_type(reverse_base_erased, forward_base_erased): + return False + elif is_proper_subtype(reverse_base_erased, forward_base_erased): + first = reverse_tweaked + second = forward_tweaked + else: + first = forward_tweaked + second = reverse_tweaked + + return is_unsafe_partially_overlapping_overload_signatures(first, second) def check_inplace_operator_method(self, defn: FuncBase) -> None: """Check an inplace operator method such as __iadd__. diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 5a816c80783a..29d042bf25a2 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -42,7 +42,9 @@ from mypy import join from mypy.meet import narrow_declared_type from mypy.maptype import map_instance_to_supertype -from mypy.subtypes import is_subtype, is_equivalent, find_member, non_method_protocol_members +from mypy.subtypes import ( + is_subtype, is_proper_subtype, is_equivalent, find_member, non_method_protocol_members, +) from mypy import applytype from mypy import erasetype from mypy.checkmember import analyze_member_access, type_object_type, bind_self @@ -1904,6 +1906,112 @@ def check_op_local(self, method: str, base_type: Type, arg: Expression, context, arg_messages=local_errors, callable_name=callable_name, object_type=object_type) + def check_op_reversible(self, + op_name: str, + left_type: Type, + left_expr: Expression, + right_type: Type, + right_expr: Expression, + context: Context) -> Tuple[Type, Type]: + def lookup_operator(op_name: str, base_type: Type) -> Optional[CallableType]: + if not self.has_member(base_type, op_name): + return None + local_errors = self.msg.clean_copy() + local_errors.disable_count = 0 + method = analyze_member_access( + name=op_name, + typ=base_type, + node=context, + is_lvalue=False, + is_super=False, + is_operator=True, + builtin_type=self.named_type, + not_ready_callback=self.not_ready_callback, + msg=local_errors, + original_type=base_type, + chk=self.chk, + ) + if local_errors.is_errors(): + return None + elif isinstance(method, CallableType): + return method + else: + return None + + if isinstance(left_type, AnyType): + # If either side is Any, we can't necessarily conclude anything. + any_type = AnyType(TypeOfAny.from_another_any, source_any=left_type) + return any_type, any_type + if isinstance(right_type, AnyType): + any_type = AnyType(TypeOfAny.from_another_any, source_any=right_type) + return any_type, any_type + + rev_op_name = self.get_reverse_op_method(op_name) + + # Stage 1: Get all variants + variants_raw = [] # type: List[Tuple[Optional[CallableType], Type, Expression]] + + left_op = lookup_operator(op_name, left_type) + right_op = lookup_operator(rev_op_name, right_type) + + bias_right = is_proper_subtype(right_type, left_type) + + #if is_same_type(left_type, right_type): + # variants_raw.append((left_op, left_type, right_expr)) + #''' + if (is_proper_subtype(right_type, left_type) + and isinstance(left_type, Instance) + and isinstance(right_type, Instance) + and left_type.type.get_definition_mro_name(op_name) != right_type.type.get_definition_mro_name(rev_op_name)): + variants_raw.append((right_op, right_type, left_expr)) + variants_raw.append((left_op, left_type, right_expr)) + else: + variants_raw.append((left_op, left_type, right_expr)) + variants_raw.append((right_op, right_type, left_expr)) + + is_python_2 = self.chk.options.python_version[0] == 2 + if is_python_2 and op_name in nodes.ops_falling_back_to_cmp: + cmp_method = nodes.comparison_fallback_method + left_cmp_op = lookup_operator(cmp_method, left_type) + right_cmp_op = lookup_operator(cmp_method, right_type) + + if bias_right: + variants_raw.append((right_cmp_op, right_type, left_expr)) + variants_raw.append((left_cmp_op, left_type, right_expr)) + else: + variants_raw.append((left_cmp_op, left_type, right_expr)) + variants_raw.append((right_cmp_op, right_type, left_expr)) + + variants = [(op, obj, arg) for (op, obj, arg) in variants_raw if op is not None] + + results = [] + errors = [] + for method, obj, arg in variants: + local_errors = self.msg.clean_copy() + local_errors.disable_count = 0 + + if isinstance(obj, Instance): + # TODO: Find out in which class the method was defined originally? + # TODO: Support non-Instance types. + callable_name = '{}.{}'.format(obj.type.fullname(), method) # type: Optional[str] + else: + callable_name = None + result = self.check_call(method, [arg], [nodes.ARG_POS], + context, arg_messages=local_errors, + callable_name=callable_name, object_type=obj) + if local_errors.is_errors(): + results.append(result) + errors.append(local_errors) + else: + return result + + if len(errors) > 0: + self.msg.add_errors(errors[0]) + return results[0] + else: + return self.check_op_local(op_name, left_type, right_expr, context, + self.msg) + def check_op(self, method: str, base_type: Type, arg: Expression, context: Context, allow_reverse: bool = False) -> Tuple[Type, Type]: @@ -1911,6 +2019,11 @@ def check_op(self, method: str, base_type: Type, arg: Expression, Return tuple (result type, inferred operator method type). """ + + if allow_reverse: + return self.check_op_reversible(method, base_type, TempNode(base_type), self.accept(arg), arg, context) + + # Use a local error storage for errors related to invalid argument # type (but NOT other errors). This error may need to be suppressed # for operators which support __rX methods. diff --git a/mypy/messages.py b/mypy/messages.py index 3f5d647ff6c1..aa154587964e 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1000,6 +1000,15 @@ def overloaded_signatures_ret_specific(self, index: int, context: Context) -> No self.fail('Overloaded function implementation cannot produce return type ' 'of signature {}'.format(index), context) + def object_with_reverse_operator_missing_forward_operator(self, + forward_method: str, + reverse_class: TypeInfo, + reverse_method: str, + context: Context) -> None: + self.fail("Cannot define a reverse operator method {} that accepts an instance of " + "the containing class {} without also defining the corresponding forward" + "operator method {}".format(reverse_method, reverse_class.name(), forward_method), context) + def operator_method_signatures_overlap( self, reverse_class: TypeInfo, reverse_method: str, forward_class: Type, forward_method: str, context: Context) -> None: diff --git a/mypy/nodes.py b/mypy/nodes.py index 9f8d10666251..31db25103f0c 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2240,6 +2240,15 @@ def __getitem__(self, name: str) -> 'SymbolTableNode': def __repr__(self) -> str: return '' % self.fullname() + def get_definition_mro_name(self, name: str) -> Optional[str]: + if self.mro is None: + return None + + for cls in self.mro: + if cls.names.get(name): + return cls.fullname() + return None + def has_readable_member(self, name: str) -> bool: return self.get(name) is not None diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 8c75e1396cae..6c846a405dce 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -1802,10 +1802,14 @@ class Real: class Fraction(Real): def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping -reveal_type(Real() + Fraction()) # E: Revealed type is 'builtins.str' +# Note: When doing A + B and if B is a subtype of A, we will always call B.__radd__(A) first +# and only try A.__add__(B) second if necessary. +reveal_type(Real() + Fraction()) # E: Revealed type is '__main__.Real*' + +# Note: When doing A + A, we only ever call A.__add__(A), never A.__radd__(A). reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' -[case testReverseOperatorTypeVar2] +[case testReverseOperatorTypeVar2a] from typing import TypeVar T = TypeVar("T", bound='Real') class Real: @@ -1813,7 +1817,19 @@ class Real: class Fraction(Real): def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping -reveal_type(Real() + Fraction()) # E: Revealed type is 'builtins.str' +reveal_type(Real() + Fraction()) # E: Revealed type is '__main__.Real*' +reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' + + +[case testReverseOperatorTypeVar2b] +from typing import TypeVar +T = TypeVar("T", Real, Fraction) +class Real: + def __add__(self, other: Fraction) -> str: ... +class Fraction: + def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping + +reveal_type(Real() + Fraction()) # E: Revealed type is '__main__.Real*' reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' [case testReverseOperatorTypeVar3] From b5e42d30b29724ba14b1f39ca8bafa21cdcf62cc Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 13 Aug 2018 16:13:16 -0700 Subject: [PATCH 16/21] WIP: operators refinement (incomplete) --- mypy/checker.py | 10 +---- mypy/checkexpr.py | 64 ++++++++++++++++++--------- mypy/messages.py | 23 ++++++---- mypy/nodes.py | 22 +++++++++ test-data/unit/check-attr.test | 8 ++-- test-data/unit/check-classes.test | 32 +++++--------- test-data/unit/check-expressions.test | 4 +- test-data/unit/check-namedtuple.test | 2 +- test-data/unit/check-statements.test | 2 +- test-data/unit/pythoneval.test | 8 ++-- test-data/unit/typexport-basic.test | 2 - 11 files changed, 104 insertions(+), 73 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 58899f2af839..d1d511cab247 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1022,12 +1022,6 @@ def check_reverse_op_method(self, defn: FuncItem, opt_meta = item.type.metaclass_type if opt_meta is not None: forward_inst = opt_meta - #if (isinstance(forward_inst, (Instance, UnionType)) - # and not forward_inst.has_readable_member(forward_name)): - # self.msg - if isinstance(forward_inst, Instance) and forward_inst.type.fullname() == defn.info.fullname() and not forward_inst.has_readable_member(forward_name): - self.msg.object_with_reverse_operator_missing_forward_operator(forward_name, forward_inst.type, reverse_name, defn) - return if not (isinstance(forward_inst, (Instance, UnionType)) and forward_inst.has_readable_member(forward_name)): return @@ -1130,13 +1124,13 @@ def is_unsafe_overlapping_op(self, arg_names=[None] * 2, ) - reverse_base_erased = reverse_type.arg_types[1] + reverse_base_erased = reverse_type.arg_types[0] if isinstance(reverse_base_erased, TypeVarType): reverse_base_erased = erase_to_bound(reverse_base_erased) if is_same_type(reverse_base_erased, forward_base_erased): return False - elif is_proper_subtype(reverse_base_erased, forward_base_erased): + elif is_subtype(reverse_base_erased, forward_base_erased): first = reverse_tweaked second = forward_tweaked else: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 29d042bf25a2..79215c2b5fb3 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1913,12 +1913,21 @@ def check_op_reversible(self, right_type: Type, right_expr: Expression, context: Context) -> Tuple[Type, Type]: - def lookup_operator(op_name: str, base_type: Type) -> Optional[CallableType]: - if not self.has_member(base_type, op_name): - return None + # TODO: Document this kludge + unions_present = isinstance(left_type, UnionType) + + def make_local_errors() -> MessageBuilder: local_errors = self.msg.clean_copy() local_errors.disable_count = 0 - method = analyze_member_access( + if unions_present: + local_errors.disable_type_names += 1 + return local_errors + + def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: + if not self.has_member(base_type, op_name): + return None + local_errors = make_local_errors() + member = analyze_member_access( name=op_name, typ=base_type, node=context, @@ -1933,10 +1942,8 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[CallableType]: ) if local_errors.is_errors(): return None - elif isinstance(method, CallableType): - return method else: - return None + return member if isinstance(left_type, AnyType): # If either side is Any, we can't necessarily conclude anything. @@ -1954,12 +1961,14 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[CallableType]: left_op = lookup_operator(op_name, left_type) right_op = lookup_operator(rev_op_name, right_type) + warn_about_uncalled_reverse_operator = False bias_right = is_proper_subtype(right_type, left_type) - #if is_same_type(left_type, right_type): - # variants_raw.append((left_op, left_type, right_expr)) - #''' - if (is_proper_subtype(right_type, left_type) + if op_name in nodes.op_methods_that_shortcut and is_same_type(left_type, right_type): + variants_raw.append((left_op, left_type, right_expr)) + if right_op is not None: + warn_about_uncalled_reverse_operator = True + elif (is_subtype(right_type, left_type) and isinstance(left_type, Instance) and isinstance(right_type, Instance) and left_type.type.get_definition_mro_name(op_name) != right_type.type.get_definition_mro_name(rev_op_name)): @@ -1984,33 +1993,44 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[CallableType]: variants = [(op, obj, arg) for (op, obj, arg) in variants_raw if op is not None] - results = [] errors = [] for method, obj, arg in variants: - local_errors = self.msg.clean_copy() - local_errors.disable_count = 0 + local_errors = make_local_errors() if isinstance(obj, Instance): # TODO: Find out in which class the method was defined originally? # TODO: Support non-Instance types. - callable_name = '{}.{}'.format(obj.type.fullname(), method) # type: Optional[str] + callable_name = '{}.{}'.format(obj.type.fullname(), op_name) # type: Optional[str] else: callable_name = None result = self.check_call(method, [arg], [nodes.ARG_POS], context, arg_messages=local_errors, callable_name=callable_name, object_type=obj) if local_errors.is_errors(): - results.append(result) errors.append(local_errors) else: return result - if len(errors) > 0: - self.msg.add_errors(errors[0]) - return results[0] - else: - return self.check_op_local(op_name, left_type, right_expr, context, - self.msg) + if len(errors) == 0: + local_errors = make_local_errors() + result = self.check_op_local(op_name, left_type, right_expr, context, local_errors) + + if local_errors.is_errors(): + errors.append(local_errors) + else: + return result + + self.msg.add_errors(errors[0]) + if warn_about_uncalled_reverse_operator: + self.msg.reverse_operator_method_never_called( + nodes.op_methods_to_symbols[op_name], + op_name, + right_type, + rev_op_name, + context, + ) + error_any = AnyType(TypeOfAny.from_error) + return error_any, error_any def check_op(self, method: str, base_type: Type, arg: Expression, context: Context, diff --git a/mypy/messages.py b/mypy/messages.py index aa154587964e..a33debee153a 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1000,14 +1000,21 @@ def overloaded_signatures_ret_specific(self, index: int, context: Context) -> No self.fail('Overloaded function implementation cannot produce return type ' 'of signature {}'.format(index), context) - def object_with_reverse_operator_missing_forward_operator(self, - forward_method: str, - reverse_class: TypeInfo, - reverse_method: str, - context: Context) -> None: - self.fail("Cannot define a reverse operator method {} that accepts an instance of " - "the containing class {} without also defining the corresponding forward" - "operator method {}".format(reverse_method, reverse_class.name(), forward_method), context) + def reverse_operator_method_never_called(self, + op: str, + forward_method: str, + reverse_type: Type, + reverse_method: str, + context: Context) -> None: + msg = "{rfunc} will not be called when running '{cls} {op} {cls}': must define {ffunc}" + self.note( + msg.format( + op=op, + ffunc=forward_method, + rfunc=reverse_method, + cls=self.format_bare(reverse_type), + ), + context=context) def operator_method_signatures_overlap( self, reverse_class: TypeInfo, reverse_method: str, forward_class: Type, diff --git a/mypy/nodes.py b/mypy/nodes.py index 31db25103f0c..6c2afcd09f1f 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1471,6 +1471,8 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: 'in': '__contains__', } # type: Dict[str, str] +op_methods_to_symbols = {v: k for (k, v) in op_methods.items()} + comparison_fallback_method = '__cmp__' ops_falling_back_to_cmp = {'__ne__', '__eq__', '__lt__', '__le__', @@ -1506,6 +1508,26 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: '__le__': '__ge__', } +# Suppose we have some class A. When we do A() + A(), Python will only check +# the output of A().__add__(A()) and skip calling the __radd__ method entirely. +# This shortcut is used only for the following methods: +op_methods_that_shortcut = { + '__add__', + '__sub__', + '__mul__', + '__truediv__', + '__mod__', + '__divmod__', + '__floordiv__', + '__pow__', + '__matmul__', + '__and__', + '__or__', + '__xor__', + '__lshift__', + '__rshift__', +} + normal_from_reverse_op = dict((m, n) for n, m in reverse_op_methods.items()) reverse_op_method_set = set(reverse_op_methods.values()) diff --git a/test-data/unit/check-attr.test b/test-data/unit/check-attr.test index 8efa81be346c..8ffb516a8a2e 100644 --- a/test-data/unit/check-attr.test +++ b/test-data/unit/check-attr.test @@ -661,16 +661,16 @@ reveal_type(D.__lt__) # E: Revealed type is 'def [AT] (self: AT`1, other: AT`1) A() < A() B() < B() -A() < B() # E: Unsupported operand types for > ("B" and "A") +A() < B() # E: Unsupported operand types for < ("A" and "B") C() > A() C() > B() C() > C() -C() > D() # E: Unsupported operand types for < ("D" and "C") +C() > D() # E: Unsupported operand types for > ("C" and "D") D() >= A() -D() >= B() # E: Unsupported operand types for <= ("B" and "D") -D() >= C() # E: Unsupported operand types for <= ("C" and "D") +D() >= B() # E: Unsupported operand types for >= ("D" and "B") +D() >= C() # E: Unsupported operand types for >= ("D" and "C") D() >= D() A() <= 1 # E: Unsupported operand types for <= ("A" and "int") diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 6c846a405dce..e48fa70550ec 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -1615,25 +1615,13 @@ class A: class B(A): pass -# Note: This is actually a runtime error. If we run x.__add__(y) +# Note: This is a runtime error. If we run x.__add__(y) # where x and y are *not* the same type, Python will not try # calling __radd__. -# -# It is, however, difficult to detect this statically -- for example, -# if we have a variable of type A, that doesn't mean that variable is -# *exactly* an instance of A. It could be a subclass, for example. -# -# We also can't really print a warning when checking the definition -# of A.__add__ and A.__radd__ -- the user might intentionally be defining -# those two methods to have non-conventional semantics. -# -# As a result, we don't try and handle this case and pretend that -# Python actually *will* call `__radd__` even when the types are -# the same. -reveal_type(A() + A()) # E: Revealed type is '__main__.A' - -# Here, Python *will* call __radd__(...) so the revealed type -# and the runtime behavior match. +A() + A() # E: Unsupported operand types for + ("A" and "A") \ + # N: __radd__ will not be called when running 'A + A': must define __add__ + +# Here, Python *will* call __radd__(...) reveal_type(B() + A()) # E: Revealed type is '__main__.A' reveal_type(A() + B()) # E: Revealed type is '__main__.A' [builtins fixtures/isinstance.pyi] @@ -1826,8 +1814,8 @@ from typing import TypeVar T = TypeVar("T", Real, Fraction) class Real: def __add__(self, other: Fraction) -> str: ... -class Fraction: - def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "T" are unsafely overlapping +class Fraction(Real): + def __radd__(self, other: T) -> T: ... # E: Signatures of "__radd__" of "Fraction" and "__add__" of "Real" are unsafely overlapping reveal_type(Real() + Fraction()) # E: Revealed type is '__main__.Real*' reveal_type(Fraction() + Fraction()) # E: Revealed type is 'builtins.str' @@ -1845,8 +1833,10 @@ reveal_type(Real() + Fraction()) # E: Revealed type is '__main__. reveal_type(FractionChild() + Fraction()) # E: Revealed type is '__main__.FractionChild*' reveal_type(FractionChild() + FractionChild()) # E: Revealed type is 'builtins.str' -# Technically not correct -reveal_type(Fraction() + Fraction()) # E: Revealed type is '__main__.Fraction*' +# Runtime error: we try calling __add__, it doesn't match, and we don't try __radd__ since +# the LHS and the RHS are not the same. +Fraction() + Fraction() # E: Unsupported operand types for + ("Fraction" and "Fraction") \ + # N: __radd__ will not be called when running 'Fraction + Fraction': must define __add__ [case testReverseOperatorTypeType] from typing import TypeVar, Type diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 39a2b11e3d35..fd2bc496deb1 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -537,9 +537,9 @@ class B: def __gt__(self, o: 'B') -> bool: pass [builtins fixtures/bool.pyi] [out] -main:3: error: Unsupported operand types for > ("A" and "A") -main:5: error: Unsupported operand types for > ("A" and "A") +main:3: error: Unsupported operand types for < ("A" and "A") main:5: error: Unsupported operand types for < ("A" and "A") +main:5: error: Unsupported operand types for > ("A" and "A") [case testChainedCompBoolRes] diff --git a/test-data/unit/check-namedtuple.test b/test-data/unit/check-namedtuple.test index cf1c3a31d8de..5a1f5869b568 100644 --- a/test-data/unit/check-namedtuple.test +++ b/test-data/unit/check-namedtuple.test @@ -685,7 +685,7 @@ my_eval(A([B(1), B(2)])) # OK from typing import NamedTuple class Real(NamedTuple): - def __sub__(self, other) -> str: return "" + def __sub__(self, other: Real) -> str: return "" class Fraction(Real): def __rsub__(self, other: Real) -> Real: return other # E: Signatures of "__rsub__" of "Fraction" and "__sub__" of "Real" are unsafely overlapping diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index df8bc6548f14..850ec9ba6f38 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -1578,7 +1578,7 @@ d = {'weight0': 65.5} reveal_type(d['weight0']) # E: Revealed type is 'builtins.float*' d['weight0'] = 65 reveal_type(d['weight0']) # E: Revealed type is 'builtins.float*' -d['weight0'] *= 'a' # E: Unsupported operand types for * ("float" and "str") # E: Incompatible types in assignment (expression has type "str", target has type "float") +d['weight0'] *= 'a' # E: Unsupported operand types for * ("float" and "str") d['weight0'] *= 0.5 reveal_type(d['weight0']) # E: Revealed type is 'builtins.float*' d['weight0'] *= object() # E: Unsupported operand types for * ("float" and "object") diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index daefaadddcfe..c0543819e9f1 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -428,10 +428,10 @@ b'' < '' '' < bytearray() bytearray() < '' [out] -_program.py:2: error: Unsupported operand types for > ("bytes" and "str") -_program.py:3: error: Unsupported operand types for > ("str" and "bytes") -_program.py:4: error: Unsupported operand types for > ("bytearray" and "str") -_program.py:5: error: Unsupported operand types for > ("str" and "bytearray") +_program.py:2: error: Unsupported operand types for < ("str" and "bytes") +_program.py:3: error: Unsupported operand types for < ("bytes" and "str") +_program.py:4: error: Unsupported operand types for < ("str" and "bytearray") +_program.py:5: error: Unsupported operand types for < ("bytearray" and "str") [case testInplaceOperatorMethod] import typing diff --git a/test-data/unit/typexport-basic.test b/test-data/unit/typexport-basic.test index 890f3d7e2ed0..263be9837616 100644 --- a/test-data/unit/typexport-basic.test +++ b/test-data/unit/typexport-basic.test @@ -147,8 +147,6 @@ class type: pass class function: pass class str: pass [out] -tmp/builtins.py:5: error: Signatures of "__lt__" of "int" and "__gt__" of "int" are unsafely overlapping -tmp/builtins.py:6: error: Signatures of "__gt__" of "int" and "__lt__" of "int" are unsafely overlapping ComparisonExpr(3) : builtins.bool ComparisonExpr(4) : builtins.bool ComparisonExpr(5) : builtins.bool From e12fa0bee398f31d3256b16d8d2f5be07f88d407 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 13 Aug 2018 18:55:59 -0700 Subject: [PATCH 17/21] WIP: Modify Any fallback from reverse operators --- mypy/checkexpr.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 5626e69cfc44..d4030eee3e2d 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1930,6 +1930,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: variants = [(op, obj, arg) for (op, obj, arg) in variants_raw if op is not None] errors = [] + results = [] for method, obj, arg in variants: local_errors = make_local_errors() @@ -1944,6 +1945,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: callable_name=callable_name, object_type=obj) if local_errors.is_errors(): errors.append(local_errors) + results.append(result) else: return result @@ -1953,6 +1955,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: if local_errors.is_errors(): errors.append(local_errors) + results.append(result) else: return result @@ -1965,8 +1968,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: rev_op_name, context, ) - error_any = AnyType(TypeOfAny.from_error) - return error_any, error_any + return results[0] def check_op(self, method: str, base_type: Type, arg: Expression, context: Context, From e564df36a6faed1a55b6d465d55e82d65e0a4aa9 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 13 Aug 2018 21:27:36 -0700 Subject: [PATCH 18/21] WIP: Fix last failing tests? --- mypy/checker.py | 2 +- mypy/checkexpr.py | 124 ++++++++++++++++------------------------------ mypy/meet.py | 24 +++++++-- mypy/sametypes.py | 3 +- 4 files changed, 66 insertions(+), 87 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 15390c7aa424..4ffffff21ae3 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3475,7 +3475,7 @@ def conditional_type_map(expr: Expression, and is_proper_subtype(current_type, proposed_type)): # Expression is always of one of the types in proposed_type_ranges return {}, None - elif not is_overlapping_types(current_type, proposed_type): + elif not is_overlapping_types(current_type, proposed_type, prohibit_none_typevar_overlap=True): # Expression is never of any type in proposed_type_ranges return None, {} else: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d4030eee3e2d..91a7840e96d8 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1849,10 +1849,14 @@ def check_op_reversible(self, right_type: Type, right_expr: Expression, context: Context) -> Tuple[Type, Type]: - # TODO: Document this kludge + # Note: this kludge exists mostly to maintain compatibility with + # existing error messages. Apparently, if the left-hand-side is a + # union and we have a type mismatch, we print out a special, + # abbreviated error message. (See messages.unsupported_operand_types). unions_present = isinstance(left_type, UnionType) def make_local_errors() -> MessageBuilder: + """Creates a new MessageBuilder object.""" local_errors = self.msg.clean_copy() local_errors.disable_count = 0 if unions_present: @@ -1860,6 +1864,8 @@ def make_local_errors() -> MessageBuilder: return local_errors def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: + """Looks up the given operator and returns the corresponding type, + if it exists.""" if not self.has_member(base_type, op_name): return None local_errors = make_local_errors() @@ -1881,8 +1887,11 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: else: return member + # If either the LHS or the RHS are Any, we can't really concluding anything + # about the operation since the Any type may or may not define an + # __op__ or __rop__ method. So, we punt and return Any instead. + if isinstance(left_type, AnyType): - # If either side is Any, we can't necessarily conclude anything. any_type = AnyType(TypeOfAny.from_another_any, source_any=left_type) return any_type, any_type if isinstance(right_type, AnyType): @@ -1891,7 +1900,9 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: rev_op_name = self.get_reverse_op_method(op_name) - # Stage 1: Get all variants + # Step 1: We start by getting the __op__ and __rop__ methods, if they exist. + + # Records the method type, the base type, and the argument. variants_raw = [] # type: List[Tuple[Optional[CallableType], Type, Expression]] left_op = lookup_operator(op_name, left_type) @@ -1901,6 +1912,13 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: bias_right = is_proper_subtype(right_type, left_type) if op_name in nodes.op_methods_that_shortcut and is_same_type(left_type, right_type): + # When we do "A() + A()", for example, Python will only call the __add__ method, + # never the __radd__ method. + # + # This is the case even if the __add__ method is completely missing and the __radd__ + # method is defined. + # + # We report this error message here instead of in the definition checks variants_raw.append((left_op, left_type, right_expr)) if right_op is not None: warn_about_uncalled_reverse_operator = True @@ -1968,7 +1986,13 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: rev_op_name, context, ) - return results[0] + if len(results) == 1: + return results[0] + else: + error_any = AnyType(TypeOfAny.from_error) + result = error_any, error_any + return result + def check_op(self, method: str, base_type: Type, arg: Expression, context: Context, @@ -1979,85 +2003,21 @@ def check_op(self, method: str, base_type: Type, arg: Expression, """ if allow_reverse: - return self.check_op_reversible(method, base_type, TempNode(base_type), self.accept(arg), arg, context) - - - # Use a local error storage for errors related to invalid argument - # type (but NOT other errors). This error may need to be suppressed - # for operators which support __rX methods. - local_errors = self.msg.copy() - local_errors.disable_count = 0 - if not allow_reverse or self.has_member(base_type, method): - result = self.check_op_local(method, base_type, arg, context, - local_errors) - if allow_reverse: - arg_type = self.chk.type_map[arg] - if isinstance(arg_type, AnyType): - # If the right operand has type Any, we can't make any - # conjectures about the type of the result, since the - # operand could have a __r method that returns anything. - any_type = AnyType(TypeOfAny.from_another_any, source_any=arg_type) - result = any_type, result[1] - success = not local_errors.is_errors() - else: - error_any = AnyType(TypeOfAny.from_error) - result = error_any, error_any - success = False - if success or not allow_reverse or isinstance(base_type, AnyType): - # We were able to call the normal variant of the operator method, - # or there was some problem not related to argument type - # validity, or the operator has no __rX method. In any case, we - # don't need to consider the __rX method. - self.msg.add_errors(local_errors) - return result + return self.check_op_reversible( + op_name=method, + left_type=base_type, + left_expr=TempNode(base_type), + right_type=self.accept(arg), + right_expr=arg, + context=context) else: - # Calling the operator method was unsuccessful. Try the __rX - # method of the other operand instead. - rmethod = self.get_reverse_op_method(method) - arg_type = self.accept(arg) - base_arg_node = TempNode(base_type) - # In order to be consistent with showing an error about the lhs not matching if neither - # the lhs nor the rhs have a compatible signature, we keep track of the first error - # message generated when considering __rX methods and __cmp__ methods for Python 2. - first_error = None # type: Optional[Tuple[Tuple[Type, Type], MessageBuilder]] - if self.has_member(arg_type, rmethod): - result, local_errors = self._check_op_for_errors(rmethod, arg_type, - base_arg_node, context) - if not local_errors.is_errors(): - return result - first_error = first_error or (result, local_errors) - # If we've failed to find an __rX method and we're checking Python 2, check to see if - # there is a __cmp__ method on the lhs or on the rhs. - if (self.chk.options.python_version[0] == 2 and - method in nodes.ops_falling_back_to_cmp): - cmp_method = nodes.comparison_fallback_method - if self.has_member(base_type, cmp_method): - # First check the if the lhs has a __cmp__ method that works - result, local_errors = self._check_op_for_errors(cmp_method, base_type, - arg, context) - if not local_errors.is_errors(): - return result - first_error = first_error or (result, local_errors) - if self.has_member(arg_type, cmp_method): - # Failed to find a __cmp__ method on the lhs, check if - # the rhs as a __cmp__ method that can operate on lhs - result, local_errors = self._check_op_for_errors(cmp_method, arg_type, - base_arg_node, context) - if not local_errors.is_errors(): - return result - first_error = first_error or (result, local_errors) - if first_error: - # We found either a __rX method, a __cmp__ method on the base_type, or a __cmp__ - # method on the rhs and failed match. Return the error for the first of these to - # fail. - self.msg.add_errors(first_error[1]) - return first_error[0] - else: - # No __rX method or __cmp__. Do deferred type checking to - # produce error message that we may have missed previously. - # TODO Fix type checking an expression more than once. - return self.check_op_local(method, base_type, arg, context, - self.msg) + self.check_op_local( + method=method, + base_type=base_type, + arg=arg, + context=context, + local_errors=self.msg, + ) def get_reverse_op_method(self, method: str) -> str: if method == '__div__' and self.chk.options.python_version[0] == 2: diff --git a/mypy/meet.py b/mypy/meet.py index 24952df3083e..4ee31e39abe4 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -96,14 +96,25 @@ def get_possible_variants(typ: Type) -> List[Type]: return [typ] -def is_overlapping_types(left: Type, right: Type, ignore_promotions: bool = False) -> bool: - """Can a value of type 'left' also be of type 'right' or vice-versa?""" +def is_overlapping_types(left: Type, + right: Type, + ignore_promotions: bool = False, + prohibit_none_typevar_overlap: bool = False) -> bool: + """Can a value of type 'left' also be of type 'right' or vice-versa? + + If 'ignore_promotions' is True, we ignore promotions while checking for overlaps. + If 'prohibit_none_typevar_overlap' is True, we disallow None from overlapping with + TypeVars (in both strict-optional and non-strict-optional mode). + """ def _is_overlapping_types(left: Type, right: Type) -> bool: '''Encode the kind of overlapping check to perform. This function mostly exists so we don't have to repeat keyword arguments everywhere.''' - return is_overlapping_types(left, right, ignore_promotions=ignore_promotions) + return is_overlapping_types( + left, right, + ignore_promotions=ignore_promotions, + prohibit_none_typevar_overlap=prohibit_none_typevar_overlap) # We should never encounter these types, but if we do, we handle # them in the same way we handle 'Any'. @@ -143,6 +154,13 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: # If both types are singleton variants (and are not TypeVars), we've hit the base case: # we skip these checks to avoid infinitely recursing. + def is_none_typevar_overlap(t1: Type, t2: Type) -> bool: + return isinstance(t1, NoneTyp) and isinstance(t2, TypeVarType) + + if prohibit_none_typevar_overlap: + if is_none_typevar_overlap(left, right) or is_none_typevar_overlap(right, left): + return False + if (len(left_possible) > 1 or len(right_possible) > 1 or isinstance(left, TypeVarType) or isinstance(right, TypeVarType)): for l in left_possible: diff --git a/mypy/sametypes.py b/mypy/sametypes.py index b382c632ffe3..ef053a5b4b19 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -98,7 +98,8 @@ def visit_callable_type(self, left: CallableType) -> bool: def visit_tuple_type(self, left: TupleType) -> bool: if isinstance(self.right, TupleType): - return is_same_types(left.items, self.right.items) + return (is_same_type(left.fallback, self.right.fallback) + and is_same_types(left.items, self.right.items)) else: return False From 4e3c8b8a51ff14f0bebaa48f8fef4c61367511b9 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 13 Aug 2018 21:28:48 -0700 Subject: [PATCH 19/21] Add missing return --- mypy/checkexpr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 91a7840e96d8..7bf49f70f065 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1918,7 +1918,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: # This is the case even if the __add__ method is completely missing and the __radd__ # method is defined. # - # We report this error message here instead of in the definition checks + # We report this error message here instead of in the definition checks variants_raw.append((left_op, left_type, right_expr)) if right_op is not None: warn_about_uncalled_reverse_operator = True @@ -2011,7 +2011,7 @@ def check_op(self, method: str, base_type: Type, arg: Expression, right_expr=arg, context=context) else: - self.check_op_local( + return self.check_op_local( method=method, base_type=base_type, arg=arg, From 71437c40233ddd80472f7421343b9a2705cb8322 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 13 Aug 2018 21:55:29 -0700 Subject: [PATCH 20/21] WIP: Misc cleanup --- mypy/checker.py | 95 +--------------------------- mypy/checkexpr.py | 82 +++++++++++++++++------- test-data/unit/check-isinstance.test | 28 ++++++++ 3 files changed, 90 insertions(+), 115 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 4ffffff21ae3..f5ecde4af015 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3475,7 +3475,8 @@ def conditional_type_map(expr: Expression, and is_proper_subtype(current_type, proposed_type)): # Expression is always of one of the types in proposed_type_ranges return {}, None - elif not is_overlapping_types(current_type, proposed_type, prohibit_none_typevar_overlap=True): + elif not is_overlapping_types(current_type, proposed_type, + prohibit_none_typevar_overlap=True): # Expression is never of any type in proposed_type_ranges return None, {} else: @@ -3874,69 +3875,6 @@ def overload_can_never_match(signature: CallableType, other: CallableType) -> bo ignore_return=True) -def is_unsafe_overlapping_operator_signatures(signature: Type, other: Type) -> bool: - """Check if two operator method signatures may be unsafely overlapping. - - Two signatures s and t are overlapping if both can be valid for the same - statically typed values and the return types are incompatible. - - Assume calls are first checked against 'signature', then against 'other'. - Thus if 'signature' is more general than 'other', there is no unsafe - overlapping. - - TODO: Clean up this function and make it not perform type erasure. - - Context: This function was previously used to make sure both overloaded - functions and operator methods were not unsafely overlapping. - - We changed the semantics for we should handle overloaded definitions, - but not operator functions. (We can't reuse the same semantics for both: - the overload semantics are too restrictive here). - - We should rewrite this method so that: - - 1. It uses many of the improvements made to overloads: in particular, - eliminating type erasure. - - 2. It contains just the logic necessary for operator methods. - """ - if isinstance(signature, CallableType): - if isinstance(other, CallableType): - # TODO varargs - # TODO keyword args - # TODO erasure - # TODO allow to vary covariantly - # Check if the argument counts are overlapping. - min_args = max(signature.min_args, other.min_args) - max_args = min(len(signature.arg_types), len(other.arg_types)) - if min_args > max_args: - # Argument counts are not overlapping. - return False - # Signatures are overlapping iff if they are overlapping for the - # smallest common argument count. - for i in range(min_args): - t1 = signature.arg_types[i] - t2 = other.arg_types[i] - if not is_overlapping_erased_types(t1, t2): - return False - # All arguments types for the smallest common argument count are - # overlapping => the signature is overlapping. The overlapping is - # safe if the return types are identical. - if is_same_type(signature.ret_type, other.ret_type): - return False - # If the first signature has more general argument types, the - # latter will never be called - if is_more_general_arg_prefix(signature, other): - return False - # Special case: all args are subtypes, and returns are subtypes - if (all(is_proper_subtype(s, o) - for (s, o) in zip(signature.arg_types, other.arg_types)) and - is_subtype(signature.ret_type, other.ret_type)): - return False - return not is_more_precise_signature(signature, other) - return True - - def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool: """Does t have wider arguments than s?""" # TODO should an overload with additional items be allowed to be more @@ -3954,20 +3892,6 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool: return False -def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool: - """Are type variable definitions equivalent? - - Ignore ids, locations in source file and names. - """ - return ( - tv1.variance == tv2.variance - and is_same_types(tv1.values, tv2.values) - and ((tv1.upper_bound is None and tv2.upper_bound is None) - or (tv1.upper_bound is not None - and tv2.upper_bound is not None - and is_same_type(tv1.upper_bound, tv2.upper_bound)))) - - def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool: return is_callable_compatible(t, s, is_compat=is_same_type, @@ -3976,21 +3900,6 @@ def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool: ignore_pos_arg_names=True) -def is_more_precise_signature(t: CallableType, s: CallableType) -> bool: - """Is t more precise than s? - A signature t is more precise than s if all argument types and the return - type of t are more precise than the corresponding types in s. - Assume that the argument kinds and names are compatible, and that the - argument counts are overlapping. - """ - # TODO generic function types - # Only consider the common prefix of argument types. - for argt, args in zip(t.arg_types, s.arg_types): - if not is_more_precise(argt, args): - return False - return is_more_precise(t.ret_type, s.ret_type) - - def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, str]: """Determine if operator assignment on given value type is in-place, and the method name. diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7bf49f70f065..024056c9e983 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1808,20 +1808,6 @@ def get_operator_method(self, op: str) -> str: else: return nodes.op_methods[op] - def _check_op_for_errors(self, method: str, base_type: Type, arg: Expression, - context: Context - ) -> Tuple[Tuple[Type, Type], MessageBuilder]: - """Type check a binary operation which maps to a method call. - - Return ((result type, inferred operator method type), error message). - """ - local_errors = self.msg.copy() - local_errors.disable_count = 0 - result = self.check_op_local(method, base_type, - arg, context, - local_errors) - return result, local_errors - def check_op_local(self, method: str, base_type: Type, arg: Expression, context: Context, local_errors: MessageBuilder) -> Tuple[Type, Type]: """Type check a binary operation which maps to a method call. @@ -1887,6 +1873,26 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: else: return member + def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: + """Returns the name of the class that contains the actual definition of attr_name. + + So if class A defines foo and class B subclasses A, running + 'get_class_defined_in(B, "foo")` would return the full name of A. + + However, if B were to override and redefine foo, that method call would + return the full name of B instead. + + If the attr name is not present in the given class or its MRO, returns None. + """ + mro = typ.type.mro + if mro is None: + return None + + for cls in mro: + if cls.names.get(attr_name): + return cls.fullname() + return None + # If either the LHS or the RHS are Any, we can't really concluding anything # about the operation since the Any type may or may not define an # __op__ or __rop__ method. So, we punt and return Any instead. @@ -1900,17 +1906,22 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: rev_op_name = self.get_reverse_op_method(op_name) - # Step 1: We start by getting the __op__ and __rop__ methods, if they exist. + # STEP 1: + # We start by getting the __op__ and __rop__ methods, if they exist. # Records the method type, the base type, and the argument. - variants_raw = [] # type: List[Tuple[Optional[CallableType], Type, Expression]] + variants_raw = [] # type: List[Tuple[Optional[Type], Type, Expression]] left_op = lookup_operator(op_name, left_type) right_op = lookup_operator(rev_op_name, right_type) + # STEP 2a: + # We figure out in which order Python will call the operator methods. As it + # turns out, it's not as simple as just trying to call __op__ first and + # __rop__ second. + warn_about_uncalled_reverse_operator = False bias_right = is_proper_subtype(right_type, left_type) - if op_name in nodes.op_methods_that_shortcut and is_same_type(left_type, right_type): # When we do "A() + A()", for example, Python will only call the __add__ method, # never the __radd__ method. @@ -1919,19 +1930,33 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: # method is defined. # # We report this error message here instead of in the definition checks + variants_raw.append((left_op, left_type, right_expr)) if right_op is not None: warn_about_uncalled_reverse_operator = True elif (is_subtype(right_type, left_type) and isinstance(left_type, Instance) and isinstance(right_type, Instance) - and left_type.type.get_definition_mro_name(op_name) != right_type.type.get_definition_mro_name(rev_op_name)): + and lookup_definer(left_type, op_name) != lookup_definer(right_type, rev_op_name)): + # When we do "A() + B()" where B is a subclass of B, we'll actually try calling + # B's __radd__ method first, but ONLY if B explicitly defines or overrides the + # __radd__ method. + # + # This mechanism lets subclasses "refine" the expected outcome of the operation, even + # if they're located on the RHS. + variants_raw.append((right_op, right_type, left_expr)) variants_raw.append((left_op, left_type, right_expr)) else: + # In all other cases, we do the usual thing and call __add__ first and + # __radd__ second when doing "A() + B()". + variants_raw.append((left_op, left_type, right_expr)) variants_raw.append((right_op, right_type, left_expr)) + # STEP 2b: + # When running Python 2, we might also try calling the __cmp__ method. + is_python_2 = self.chk.options.python_version[0] == 2 if is_python_2 and op_name in nodes.ops_falling_back_to_cmp: cmp_method = nodes.comparison_fallback_method @@ -1945,19 +1970,29 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: variants_raw.append((left_cmp_op, left_type, right_expr)) variants_raw.append((right_cmp_op, right_type, left_expr)) + # STEP 3: + # We now filter out all non-existant operators. The 'variants' list contains + # all operator methods that are actually present, in the order that Python + # attempts to invoke them. + variants = [(op, obj, arg) for (op, obj, arg) in variants_raw if op is not None] + # STEP 4: + # We now try invoking each one. If an operation succeeds, end early and return + # the corresponding result. Otherwise, return the result and errors associated + # with the first entry. + errors = [] results = [] for method, obj, arg in variants: local_errors = make_local_errors() + callable_name = None # type: Optional[str] if isinstance(obj, Instance): # TODO: Find out in which class the method was defined originally? # TODO: Support non-Instance types. - callable_name = '{}.{}'.format(obj.type.fullname(), op_name) # type: Optional[str] - else: - callable_name = None + callable_name = '{}.{}'.format(obj.type.fullname(), op_name) + result = self.check_call(method, [arg], [nodes.ARG_POS], context, arg_messages=local_errors, callable_name=callable_name, object_type=obj) @@ -1967,6 +2002,10 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: else: return result + # STEP 4b: + # Sometimes, the variants list is empty. In that case, we fall-back to attempting to + # call the __op__ method (even though it's missing). + if len(errors) == 0: local_errors = make_local_errors() result = self.check_op_local(op_name, left_type, right_expr, context, local_errors) @@ -1993,7 +2032,6 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: result = error_any, error_any return result - def check_op(self, method: str, base_type: Type, arg: Expression, context: Context, allow_reverse: bool = False) -> Tuple[Type, Type]: diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index a78fa8d52d58..7157c0ef58a4 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -2055,3 +2055,31 @@ def foo2(x: Optional[str]) -> None: else: reveal_type(x) # E: Revealed type is 'Union[builtins.str, None]' [builtins fixtures/isinstance.pyi] + +[case testNoneCheckDoesNotNarrowWhenUsingTypeVars] +# flags: --strict-optional +from typing import TypeVar + +T = TypeVar('T') + +def foo(x: T) -> T: + out = None + out = x + if out is None: + pass + return out +[builtins fixtures/isinstance.pyi] + +[case testNoneCheckDoesNotNarrowWhenUsingTypeVarsNoStrictOptional] +# flags: --no-strict-optional +from typing import TypeVar + +T = TypeVar('T') + +def foo(x: T) -> T: + out = None + out = x + if out is None: + pass + return out +[builtins fixtures/isinstance.pyi] From 99d48b7aa4c77b90d18bf715dfcfbcc8dd1bbb6c Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 13 Aug 2018 22:09:24 -0700 Subject: [PATCH 21/21] Misc cleanup --- mypy/checker.py | 19 ------------------- mypy/nodes.py | 9 --------- 2 files changed, 28 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index f5ecde4af015..d093293a5b2c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3804,16 +3804,8 @@ def detach_callable(typ: CallableType) -> CallableType: appear_map[var.fullname] = [] appear_map[var.fullname].append(i) - from mypy.erasetype import erase_type - used_type_var_names = set() for var_name, appearances in appear_map.items(): - '''if len(appearances) == 1: - entry = appearances[0] - type_list[entry] = erase_type(type_list[entry]) - else: - used_type_var_names.add(var_name)''' - used_type_var_names.add(var_name) all_type_vars = typ.accept(TypeVarExtractor()) @@ -3821,7 +3813,6 @@ def detach_callable(typ: CallableType) -> CallableType: for var in set(all_type_vars): if var.fullname not in used_type_var_names: continue - # new_variables.append(var) new_variables.append(TypeVarDef( name=var.name, fullname=var.fullname, @@ -3835,16 +3826,6 @@ def detach_callable(typ: CallableType) -> CallableType: arg_types=type_list[:-1], ret_type=type_list[-1], ) - ''' - print(typ.name) - print(' before:', typ) - print(' after: ', out) - print(' type list (old):', old_type_list) - print(' type list (new):', type_list) - print(' old_vars:', typ.variables) - print(' new_vars:', out.variables) - print(' appear_map:', appear_map) - #''' return out diff --git a/mypy/nodes.py b/mypy/nodes.py index 6c2afcd09f1f..a0acd73714b4 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2262,15 +2262,6 @@ def __getitem__(self, name: str) -> 'SymbolTableNode': def __repr__(self) -> str: return '' % self.fullname() - def get_definition_mro_name(self, name: str) -> Optional[str]: - if self.mro is None: - return None - - for cls in self.mro: - if cls.names.get(name): - return cls.fullname() - return None - def has_readable_member(self, name: str) -> bool: return self.get(name) is not None