From 057e329e465732d6a40a820d959157b040d3ac34 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 20 Jun 2018 18:35:13 +0100 Subject: [PATCH 01/12] Alternative algorithm for union math --- mypy/checkexpr.py | 172 ++++++++++---------------- test-data/unit/check-overloading.test | 90 +++++++++----- 2 files changed, 130 insertions(+), 132 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ae2a4907e0c2..9ee70f6b0413 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -118,6 +118,7 @@ def __init__(self, self.msg = msg self.plugin = plugin self.type_context = [None] + self.type_overrides = {} # type: Dict[Expression, Type] self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg) def visit_name_expr(self, e: NameExpr) -> Type: @@ -519,7 +520,9 @@ def check_call(self, callee: Type, args: List[Expression], callable_node: Optional[Expression] = None, arg_messages: Optional[MessageBuilder] = None, callable_name: Optional[str] = None, - object_type: Optional[Type] = None) -> Tuple[Type, Type]: + object_type: Optional[Type] = None, + *, + arg_types_override: Optional[List[Type]] = None) -> Tuple[Type, Type]: """Type check a call. Also infer type arguments if the callee is a generic function. @@ -575,9 +578,11 @@ def check_call(self, callee: Type, args: List[Expression], callee, context) callee = self.infer_function_type_arguments( callee, args, arg_kinds, formal_to_actual, context) - - arg_types = self.infer_arg_types_in_context2( - callee, args, arg_kinds, formal_to_actual) + if arg_types_override is not None: + arg_types = arg_types_override.copy() + else: + arg_types = self.infer_arg_types_in_context2( + callee, args, arg_kinds, formal_to_actual) self.check_argument_count(callee, arg_types, arg_kinds, arg_names, formal_to_actual, context, self.msg) @@ -1130,22 +1135,15 @@ def check_overload_call(self, unioned_result = None # type: Optional[Tuple[Type, Type]] unioned_errors = None # type: Optional[MessageBuilder] union_success = False - if any(isinstance(arg, UnionType) and len(arg.relevant_items()) > 1 # "real" union - for arg in arg_types): - erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, - arg_kinds, arg_names, context) - unioned_callable = self.union_overload_matches(erased_targets) - - if unioned_callable is not None: - unioned_errors = arg_messages.clean_copy() - unioned_result = self.check_call(unioned_callable, args, arg_kinds, - context, arg_names, - arg_messages=unioned_errors, - callable_name=callable_name, - object_type=object_type) - # Record if we succeeded. Next we need to see if maybe normal procedure - # gives a narrower type. - union_success = unioned_result is not None and not unioned_errors.is_errors() + if any(self.real_union(arg) for arg in arg_types): + unioned_errors = arg_messages.clean_copy() + unioned_result = self.union_overload_result(plausible_targets, args, arg_types, + arg_kinds, arg_names, + callable_name, object_type, + context, arg_messages=unioned_errors) + # Record if we succeeded. Next we need to see if maybe normal procedure + # gives a narrower type. + union_success = unioned_result is not None and not unioned_errors.is_errors() # Step 3: We try checking each branch one-by-one. inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types, @@ -1173,9 +1171,8 @@ def check_overload_call(self, # # Neither alternative matches, but we can guess the user probably wants the # second one. - if erased_targets is None: - erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, - arg_kinds, arg_names, context) + erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, + arg_kinds, arg_names, context) # Step 5: We try and infer a second-best alternative if possible. If not, fall back # to using 'Any'. @@ -1350,91 +1347,56 @@ def overload_erased_call_targets(self, matches.append(typ) return matches - def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]: - """Accepts a list of overload signatures and attempts to combine them together into a - new CallableType consisting of the union of all of the given arguments and return types. - - Returns None if it is not possible to combine the different callables together in a - sound manner. - - Assumes all of the given callables have argument counts compatible with the caller. + def union_overload_result(self, + plausible_targets: List[CallableType], + args: List[Expression], + arg_types: List[Type], + arg_kinds: List[int], + arg_names: Optional[Sequence[Optional[str]]], + callable_name: Optional[str], + object_type: Optional[Type], + context: Context, + arg_messages: Optional[MessageBuilder] = None, + ) -> Optional[Tuple[Type, Type]]: + """Accepts a list of overload signatures and attempts to match calls by destructuring + the first union. Returns None if there is no match. """ - if len(callables) == 0: - return None - elif len(callables) == 1: - return callables[0] - - # Note: we are assuming here that if a user uses some TypeVar 'T' in - # two different overloads, they meant for that TypeVar to mean the - # same thing. - # - # This function will make sure that all instances of that TypeVar 'T' - # refer to the same underlying TypeVarType and TypeVarDef objects to - # simplify the union-ing logic below. - # - # (If the user did *not* mean for 'T' to be consistently bound to the - # same type in their overloads, well, their code is probably too - # confusing and ought to be re-written anyways.) - callables, variables = merge_typevars_in_callables_by_name(callables) - - new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]] - new_kinds = list(callables[0].arg_kinds) - new_returns = [] # type: List[Type] - - for target in callables: - # We conservatively end if the overloads do not have the exact same signature. - # The only exception is if one arg is optional and the other is positional: in that - # case, we continue unioning (and expect a positional arg). - # TODO: Enhance the union overload logic to handle a wider variety of signatures. - if len(new_kinds) != len(target.arg_kinds): + if not any(self.real_union(typ) for typ in arg_types): + # No unions in args, just fall back to normal inference + for arg, typ in zip(args, arg_types): + self.type_overrides[arg] = typ + res = self.infer_overload_return_type(plausible_targets, args, arg_types, + arg_kinds, arg_names, callable_name, + object_type, context, arg_messages) + for arg, typ in zip(args, arg_types): + del self.type_overrides[arg] + return res + first_union = next(typ for typ in arg_types if self.real_union(typ)) + idx = arg_types.index(first_union) + assert isinstance(first_union, UnionType) + returns = [] + inferred_types = [] + for item in first_union.relevant_items(): + new_arg_types = arg_types.copy() + new_arg_types[idx] = item + sub_result = self.union_overload_result(plausible_targets, args, new_arg_types, + arg_kinds, arg_names, callable_name, + object_type, context, arg_messages) + if sub_result is not None: + ret, inferred = sub_result + returns.append(ret) + inferred_types.append(inferred) + else: return None - for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)): - if new_kind == target_kind: - continue - elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT): - new_kinds[i] = ARG_POS - else: - return None - for i, arg in enumerate(target.arg_types): - new_args[i].append(arg) - new_returns.append(target.ret_type) - - union_count = 0 - final_args = [] - for args_list in new_args: - new_type = UnionType.make_simplified_union(args_list) - union_count += 1 if isinstance(new_type, UnionType) else 0 - final_args.append(new_type) - - # TODO: Modify this check to be less conservative. - # - # Currently, we permit only one union in the arguments because if we allow - # multiple, we can't always guarantee the synthesized callable will be correct. - # - # For example, suppose we had the following two overloads: - # - # @overload - # def f(x: A, y: B) -> None: ... - # @overload - # def f(x: B, y: A) -> None: ... - # - # If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...", - # then we'd incorrectly accept calls like "f(A(), A())" when they really ought to - # be rejected. - # - # However, that means we'll also give up if the original overloads contained - # any unions. This is likely unnecessary -- we only really need to give up if - # there are more then one *synthesized* union arguments. - if union_count >= 2: - return None + if returns: + print('un', returns, UnionType.make_simplified_union(returns, context.line, context.column)) + return (UnionType.make_simplified_union(returns, context.line, context.column), + UnionType.make_simplified_union(inferred_types, context.line, context.column)) + return None - return callables[0].copy_modified( - arg_types=final_args, - arg_kinds=new_kinds, - ret_type=UnionType.make_simplified_union(new_returns), - variables=variables, - implicit=True) + def real_union(self, typ: Type) -> bool: + return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1 def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int], arg_names: Optional[Sequence[Optional[str]]], @@ -2666,6 +2628,8 @@ def accept(self, is True and this expression is a call, allow it to return None. This applies only to this expression and not any subexpressions. """ + if node in self.type_overrides: + return self.type_overrides[node] self.type_context.append(type_context) try: if allow_none_return and isinstance(node, CallExpr): diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index c74ec5e32958..5ced8c09bf59 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -2352,7 +2352,8 @@ reveal_type(f2(arg1, C())) main:15: error: Revealed type is '__main__.B' main:15: error: Argument 1 to "f1" has incompatible type "Union[A, C]"; expected "A" main:15: error: Argument 2 to "f1" has incompatible type "Union[A, C]"; expected "C" -main:23: error: Revealed type is 'Union[__main__.B, __main__.D]' +main:23: error: Revealed type is '__main__.B' +main:23: error: Argument 1 to "f2" has incompatible type "Union[A, C]"; expected "A" main:23: error: Argument 2 to "f2" has incompatible type "Union[A, C]"; expected "C" main:24: error: Revealed type is 'Union[__main__.B, __main__.D]' @@ -2379,7 +2380,7 @@ compat: Union[WrapperCo[C], WrapperContra[A]] reveal_type(foo(compat)) # E: Revealed type is 'Union[builtins.int, builtins.str]' not_compat: Union[WrapperCo[A], WrapperContra[C]] -foo(not_compat) # E: Argument 1 to "foo" has incompatible type "Union[WrapperCo[A], WrapperContra[C]]"; expected "Union[WrapperCo[B], WrapperContra[B]]" +foo(not_compat) # E: Argument 1 to "foo" has incompatible type "Union[WrapperCo[A], WrapperContra[C]]"; expected "WrapperCo[B]" [case testOverloadInferUnionIfParameterNamesAreDifferent] from typing import overload, Union @@ -2439,10 +2440,8 @@ def f(x: A) -> Child: ... def f(x: B, y: B = B()) -> Parent: ... def f(*args): ... -# TODO: It would be nice if we could successfully do union math -# in this case. See comments in checkexpr.union_overload_matches. x: Union[A, B] -f(x) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "A" +f(x) # OK f(x, B()) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "B" [case testOverloadInferUnionWithMixOfPositionalAndOptionalArgs] @@ -2489,7 +2488,7 @@ x: Union[A, B] reveal_type(obj.f(A())) # E: Revealed type is '__main__.C' reveal_type(obj.f(B())) # E: Revealed type is '__main__.B' -reveal_type(obj.f(x)) # E: Revealed type is 'Union[__main__.B, __main__.C]' +reveal_type(obj.f(x)) # E: Revealed type is 'Union[__main__.C, __main__.B]' [case testOverloadingInferUnionReturnWithFunctionTypevarReturn] from typing import overload, Union, TypeVar, Generic @@ -2519,10 +2518,10 @@ def wrapper() -> None: obj2: Union[W1[A], W2[B]] - foo(obj2) # E: Cannot infer type argument 1 of "foo" + foo(obj2) # OK bar(obj2) # E: Cannot infer type argument 1 of "bar" - b1_overload: A = foo(obj2) # E: Cannot infer type argument 1 of "foo" + b1_overload: A = foo(obj2) # E: Incompatible types in assignment (expression has type "Union[A, B]", variable has type "A") b1_union: A = bar(obj2) # E: Cannot infer type argument 1 of "bar" [case testOverloadingInferUnionReturnWithObjectTypevarReturn] @@ -2552,10 +2551,10 @@ def wrapper() -> None: # Note: These should be fine, but mypy has an unrelated bug # that makes them error out? - a2_overload: A = SomeType().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[A]]"; expected "Union[W1[], W2[]]" + a2_overload: A = SomeType().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[A]]"; expected "W1[]" a2_union: A = SomeType().bar(obj1) # E: Argument 1 to "bar" of "SomeType" has incompatible type "Union[W1[A], W2[A]]"; expected "Union[W1[], W2[]]" - SomeType().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[A]]"; expected "Union[W1[], W2[]]" + SomeType().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[A]]"; expected "W1[]" SomeType().bar(obj1) # E: Argument 1 to "bar" of "SomeType" has incompatible type "Union[W1[A], W2[A]]"; expected "Union[W1[], W2[]]" [case testOverloadingInferUnionReturnWithBadObjectTypevarReturn] @@ -2580,13 +2579,13 @@ class SomeType(Generic[T]): def wrapper(mysterious: T) -> T: obj1: Union[W1[A], W2[B]] - SomeType().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[B]]"; expected "Union[W1[], W2[]]" + SomeType().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[B]]"; expected "W1[]" SomeType().bar(obj1) # E: Argument 1 to "bar" of "SomeType" has incompatible type "Union[W1[A], W2[B]]"; expected "Union[W1[], W2[]]" - SomeType[A]().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[B]]"; expected "Union[W1[A], W2[A]]" + SomeType[A]().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[B]]"; expected "W1[A]" SomeType[A]().bar(obj1) # E: Argument 1 to "bar" of "SomeType" has incompatible type "Union[W1[A], W2[B]]"; expected "Union[W1[A], W2[A]]" - SomeType[T]().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[B]]"; expected "Union[W1[T], W2[T]]" + SomeType[T]().foo(obj1) # E: Argument 1 to "foo" of "SomeType" has incompatible type "Union[W1[A], W2[B]]"; expected "W1[T]" SomeType[T]().bar(obj1) # E: Argument 1 to "bar" of "SomeType" has incompatible type "Union[W1[A], W2[B]]"; expected "Union[W1[T], W2[T]]" return mysterious @@ -2613,7 +2612,7 @@ T1 = TypeVar('T1', bound=A) def t_is_same_bound(arg1: T1, arg2: S) -> Tuple[T1, S]: x1: Union[List[S], List[Tuple[T1, S]]] y1: S - reveal_type(Dummy[T1]().foo(x1, y1)) # E: Revealed type is 'Union[T1`-1, S`-2]' + reveal_type(Dummy[T1]().foo(x1, y1)) # E: Revealed type is 'Union[S`-2, T1`-1]' x2: Union[List[T1], List[Tuple[T1, T1]]] y2: T1 @@ -2646,12 +2645,13 @@ def t_is_same_bound(arg1: T1, arg2: S) -> Tuple[T1, S]: # The arguments in the tuple are swapped x3: Union[List[S], List[Tuple[S, T1]]] y3: S - Dummy[T1]().foo(x3, y3) # E: Argument 1 to "foo" of "Dummy" has incompatible type "Union[List[S], List[Tuple[S, T1]]]"; expected "Union[List[Tuple[T1, S]], List[S]]" + Dummy[T1]().foo(x3, y3) # E: Cannot infer type argument 1 of "foo" of "Dummy" \ + # E: Argument 1 to "foo" of "Dummy" has incompatible type "Union[List[S], List[Tuple[S, T1]]]"; expected "List[Tuple[T1, Any]]" x4: Union[List[int], List[Tuple[C, int]]] y4: int - reveal_type(Dummy[C]().foo(x4, y4)) # E: Revealed type is 'Union[__main__.C, builtins.int*]' - Dummy[A]().foo(x4, y4) # E: Cannot infer type argument 1 of "foo" of "Dummy" + reveal_type(Dummy[C]().foo(x4, y4)) # E: Revealed type is 'Union[builtins.int*, __main__.C]' + Dummy[A]().foo(x4, y4) # E: Argument 1 to "foo" of "Dummy" has incompatible type "Union[List[int], List[Tuple[C, int]]]"; expected "List[Tuple[A, int]]" return arg1, arg2 @@ -2679,7 +2679,7 @@ T1 = TypeVar('T1', bound=B) def t_is_tighter_bound(arg1: T1, arg2: S) -> Tuple[T1, S]: x1: Union[List[S], List[Tuple[T1, S]]] y1: S - reveal_type(Dummy[T1]().foo(x1, y1)) # E: Revealed type is 'Union[T1`-1, S`-2]' + reveal_type(Dummy[T1]().foo(x1, y1)) # E: Revealed type is 'Union[S`-2, T1`-1]' x2: Union[List[T1], List[Tuple[T1, T1]]] y2: T1 @@ -2721,10 +2721,10 @@ def t_is_compatible_bound(arg1: T3, arg2: S) -> Tuple[T3, S]: [builtins fixtures/list.pyi] [out] -main:22: error: Revealed type is 'Union[__main__.B, S`-2]' -main:22: error: Revealed type is 'Union[__main__.C, S`-2]' -main:26: error: Revealed type is '__main__.B' -main:26: error: Revealed type is '__main__.C' +main:22: error: Revealed type is 'Union[S`-2, __main__.B]' +main:22: error: Revealed type is 'Union[S`-2, __main__.C]' +main:26: error: Revealed type is '__main__.B*' +main:26: error: Revealed type is '__main__.C*' [case testOverloadInferUnionReturnWithInconsistentTypevarNames] from typing import overload, TypeVar, Union @@ -2751,10 +2751,8 @@ def test(x: T) -> T: reveal_type(consistent(x, y)) # E: Revealed type is 'T`-1' - # TODO: Should we try and handle this differently? - # On one hand, this overload is defined in a weird way so it's arguably - # the user's fault; on the other, there's nothing overtly wrong with it. - inconsistent(x, y) # E: Argument 2 to "inconsistent" has incompatible type "Union[str, int]"; expected "str" + # On one hand, this overload is defined in a weird way; on the other, there's technically nothing wrong with it. + inconsistent(x, y) return x @@ -2805,7 +2803,7 @@ b: int c: Optional[int] reveal_type(g(a)) # E: Revealed type is 'builtins.int' reveal_type(g(b)) # E: Revealed type is 'builtins.str' -reveal_type(g(c)) # E: Revealed type is 'Union[builtins.int, builtins.str]' +reveal_type(g(c)) # E: Revealed type is 'Union[builtins.str, builtins.int]' [case testOverloadsNoneAndTypeVarsWithNoStrictOptional] # flags: --no-strict-optional @@ -2852,7 +2850,7 @@ f3: Optional[Callable[[int], str]] reveal_type(mymap(f1, seq)) # E: Revealed type is 'typing.Iterable[builtins.str*]' reveal_type(mymap(f2, seq)) # E: Revealed type is 'typing.Iterable[builtins.int*]' -reveal_type(mymap(f3, seq)) # E: Revealed type is 'Union[typing.Iterable[builtins.int], typing.Iterable[builtins.str*]]' +reveal_type(mymap(f3, seq)) # E: Revealed type is 'Union[typing.Iterable[builtins.str*], typing.Iterable[builtins.int*]]' [builtins fixtures/list.pyi] [typing fixtures/typing-full.pyi] @@ -3717,3 +3715,39 @@ def relpath(path: str) -> str: ... @overload def relpath(path: unicode) -> unicode: ... [out] + +[case testUnionMathTrickyOverload1] +from typing import Union, overload + +@overload +def f(x: int, y: int) -> int: ... +@overload +def f(x: object, y: str) -> str: ... +def f(x): + pass + +x: Union[int, str] +y: Union[int, str] +f(x, y) +[out] +main:12: error: Argument 1 to "f" has incompatible type "Union[int, str]"; expected "int" +main:12: error: Argument 2 to "f" has incompatible type "Union[int, str]"; expected "int" + +[case testUnionMathTrickyOverload2] +from typing import overload, Union, Any + +class C: + def f(self, other: C) -> C: ... + +class D(C): + @overload # type: ignore + def f(self, other: D) -> D: ... + @overload + def f(self, other: C) -> C: ... + def f(self, other): ... + +x: D +y: Union[D, Any] +# TODO: update after we decide on https://github.com/python/mypy/pull/5254 +reveal_type(x.f(y)) # E: Revealed type is 'Any' +[out] From b796af77c4c865bbc8888e6ef769903b1911ed72 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 21 Jun 2018 08:33:40 +0100 Subject: [PATCH 02/12] Return soon on success --- mypy/checkexpr.py | 8 +++++++- test-data/unit/check-overloading.test | 23 +++++++++++++++++++++++ test-data/unit/pythoneval.test | 27 +++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9ee70f6b0413..1720fdd929ff 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1371,6 +1371,13 @@ def union_overload_result(self, for arg, typ in zip(args, arg_types): del self.type_overrides[arg] return res + # Try direct match before splitting + direct = self.infer_overload_return_type(plausible_targets, args, arg_types, + arg_kinds, arg_names, callable_name, + object_type, context, arg_messages) + if direct is not None and not isinstance(direct[0], UnionType): + # We only return non-unions soon, to avoid gredy match. + return direct first_union = next(typ for typ in arg_types if self.real_union(typ)) idx = arg_types.index(first_union) assert isinstance(first_union, UnionType) @@ -1390,7 +1397,6 @@ def union_overload_result(self, return None if returns: - print('un', returns, UnionType.make_simplified_union(returns, context.line, context.column)) return (UnionType.make_simplified_union(returns, context.line, context.column), UnionType.make_simplified_union(inferred_types, context.line, context.column)) return None diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 5ced8c09bf59..eb8277ec8177 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -3751,3 +3751,26 @@ y: Union[D, Any] # TODO: update after we decide on https://github.com/python/mypy/pull/5254 reveal_type(x.f(y)) # E: Revealed type is 'Any' [out] + +[case testManyUnionsInOverload-skip] +from typing import overload, TypeVar, Union + +T = TypeVar('T') + +@overload +def f(x: int, y: object, z: object, t: object, u: object, w: object, v: object, s: object) -> int: ... +@overload +def f(x: str, y: object, z: object, t: object, u: object, w: object, v: object, s: object) -> str: ... +@overload +def f(x: T, y: object, z: object, t: object, u: object, w: object, v: object, s: object) -> T: ... +def f(*args, **kwargs): + pass + +class A: pass +class B: pass +x: Union[int, str, A, B] +y = f(x, x, x, x, x, x, x, x) # 8 args + +reveal_type(y) +[builtins fixtures/dict.pyi] +[out] diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 0bb9da3525fa..411838eb82a4 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1272,3 +1272,30 @@ def d() -> Dict[int, int]: return {} [out] _testDictWithStarStarSpecialCase.py:4: error: Argument 1 to "update" of "dict" has incompatible type "Dict[int, int]"; expected "Mapping[int, str]" + +[case testLoadsOfOverloads] +from typing import overload, Any, TypeVar, Iterable, List, Dict, Callable, Union + +S = TypeVar('S') +T = TypeVar('T') + +@overload +def simple_map() -> None: ... +@overload +def simple_map(func: Callable[[T], S], one: Iterable[T]) -> S: ... +@overload +def simple_map(func: Callable[..., S], *iterables: Iterable[Any]) -> S: ... +def simple_map(*args): pass + +def format_row(*entries: object) -> str: pass + +class DateTime: pass +JsonBlob = Dict[str, Any] +Column = Union[List[str], List[int], List[bool], List[float], List[DateTime], List[JsonBlob]] + +def print_custom_table() -> None: + a: Column + + for row in simple_map(format_row, a, a, a, a, a, a, a, a): # 8 columns + print(row) +[out] From e98cb76eaca1c1d67411d5352db48826f1b01181 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 1 Jul 2018 18:18:09 +0100 Subject: [PATCH 03/12] Minor fixes --- mypy/checkexpr.py | 23 ++++++++++++----------- test-data/unit/check-overloading.test | 13 ++++++------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a0167620ed55..523de4baa552 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -530,9 +530,7 @@ def check_call(self, callee: Type, args: List[Expression], callable_node: Optional[Expression] = None, arg_messages: Optional[MessageBuilder] = None, callable_name: Optional[str] = None, - object_type: Optional[Type] = None, - *, - arg_types_override: Optional[List[Type]] = None) -> Tuple[Type, Type]: + object_type: Optional[Type] = None) -> Tuple[Type, Type]: """Type check a call. Also infer type arguments if the callee is a generic function. @@ -588,11 +586,9 @@ def check_call(self, callee: Type, args: List[Expression], callee, context) callee = self.infer_function_type_arguments( callee, args, arg_kinds, formal_to_actual, context) - if arg_types_override is not None: - arg_types = arg_types_override.copy() - else: - arg_types = self.infer_arg_types_in_context2( - callee, args, arg_kinds, formal_to_actual) + + arg_types = self.infer_arg_types_in_context2( + callee, args, arg_kinds, formal_to_actual) self.check_argument_count(callee, arg_types, arg_kinds, arg_names, formal_to_actual, context, self.msg) @@ -1165,7 +1161,8 @@ def check_overload_call(self, return inferred_result else: assert unioned_result is not None - if is_subtype(inferred_result[0], unioned_result[0]): + if (is_subtype(inferred_result[0], unioned_result[0]) and + not isinstance(inferred_result[0], AnyType)): return inferred_result return unioned_result elif union_success: @@ -1380,11 +1377,15 @@ def union_overload_result(self, del self.type_overrides[arg] return res # Try direct match before splitting + for arg, typ in zip(args, arg_types): + self.type_overrides[arg] = typ direct = self.infer_overload_return_type(plausible_targets, args, arg_types, arg_kinds, arg_names, callable_name, object_type, context, arg_messages) - if direct is not None and not isinstance(direct[0], UnionType): - # We only return non-unions soon, to avoid gredy match. + for arg, typ in zip(args, arg_types): + del self.type_overrides[arg] + if direct is not None and not isinstance(direct[0], (UnionType, AnyType)): + # We only return non-unions soon, to avoid greedy match. return direct first_union = next(typ for typ in arg_types if self.real_union(typ)) idx = arg_types.index(first_union) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index f9b250162b63..86e5d2139ff4 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -2579,7 +2579,7 @@ def f(x: B, y: B = B()) -> Parent: ... def f(*args): ... x: Union[A, B] -f(x) # OK +reveal_type(f(x)) # E: Revealed type is '__main__.Parent' f(x, B()) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "B" [case testOverloadInferUnionWithMixOfPositionalAndOptionalArgs] @@ -2656,7 +2656,7 @@ def wrapper() -> None: obj2: Union[W1[A], W2[B]] - foo(obj2) # OK + reveal_type(foo(obj2)) # E: Revealed type is 'Union[__main__.A*, __main__.B*]' bar(obj2) # E: Cannot infer type argument 1 of "bar" b1_overload: A = foo(obj2) # E: Incompatible types in assignment (expression has type "Union[A, B]", variable has type "A") @@ -3878,7 +3878,7 @@ class C: def f(self, other: C) -> C: ... class D(C): - @overload # type: ignore + @overload def f(self, other: D) -> D: ... @overload def f(self, other: C) -> C: ... @@ -3886,11 +3886,10 @@ class D(C): x: D y: Union[D, Any] -# TODO: update after we decide on https://github.com/python/mypy/pull/5254 -reveal_type(x.f(y)) # E: Revealed type is 'Any' +reveal_type(x.f(y)) # E: Revealed type is 'Union[__main__.D, Any]' [out] -[case testManyUnionsInOverload-skip] +[case testManyUnionsInOverload] from typing import overload, TypeVar, Union T = TypeVar('T') @@ -3909,7 +3908,7 @@ class B: pass x: Union[int, str, A, B] y = f(x, x, x, x, x, x, x, x) # 8 args -reveal_type(y) +reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.str, __main__.A, __main__.B]' [builtins fixtures/dict.pyi] [out] From 70577960684923d1e3403f44a4cb4253440bc4fd Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 1 Jul 2018 20:25:13 +0100 Subject: [PATCH 04/12] More clean-up; switch to unioned callable --- mypy/checkexpr.py | 125 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 109 insertions(+), 16 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 523de4baa552..a592e816eb99 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1,7 +1,11 @@ """Expression type checker. This file is conceptually part of TypeChecker.""" from collections import OrderedDict -from typing import cast, Dict, Set, List, Tuple, Callable, Union, Optional, Iterable, Sequence, Any +from contextlib import contextmanager +from typing import ( + cast, Dict, Set, List, Tuple, Callable, Union, Optional, Iterable, + Sequence, Any, Iterator +) from mypy.errors import report_internal_error from mypy.typeanal import ( @@ -1368,25 +1372,22 @@ def union_overload_result(self, """ if not any(self.real_union(typ) for typ in arg_types): # No unions in args, just fall back to normal inference - for arg, typ in zip(args, arg_types): - self.type_overrides[arg] = typ - res = self.infer_overload_return_type(plausible_targets, args, arg_types, - arg_kinds, arg_names, callable_name, - object_type, context, arg_messages) - for arg, typ in zip(args, arg_types): - del self.type_overrides[arg] + with self.type_overrides_set(args, arg_types): + res = self.infer_overload_return_type(plausible_targets, args, arg_types, + arg_kinds, arg_names, callable_name, + object_type, context, arg_messages) return res + # Try direct match before splitting - for arg, typ in zip(args, arg_types): - self.type_overrides[arg] = typ - direct = self.infer_overload_return_type(plausible_targets, args, arg_types, - arg_kinds, arg_names, callable_name, - object_type, context, arg_messages) - for arg, typ in zip(args, arg_types): - del self.type_overrides[arg] + with self.type_overrides_set(args, arg_types): + direct = self.infer_overload_return_type(plausible_targets, args, arg_types, + arg_kinds, arg_names, callable_name, + object_type, context, arg_messages) if direct is not None and not isinstance(direct[0], (UnionType, AnyType)): # We only return non-unions soon, to avoid greedy match. return direct + + # Split the first remaining union type in arguments first_union = next(typ for typ in arg_types if self.real_union(typ)) idx = arg_types.index(first_union) assert isinstance(first_union, UnionType) @@ -1407,12 +1408,104 @@ def union_overload_result(self, if returns: return (UnionType.make_simplified_union(returns, context.line, context.column), - UnionType.make_simplified_union(inferred_types, context.line, context.column)) + self.union_overload_matches(inferred_types)) return None def real_union(self, typ: Type) -> bool: return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1 + @contextmanager + def type_overrides_set(self, exprs: Iterable[Expression], + overrides: Iterable[Type]) -> Iterator[None]: + """Set _temporary_ type overrides for given expressions.""" + assert len(exprs) == len(overrides) + for expr, typ in zip(exprs, overrides): + self.type_overrides[expr] = typ + try: + yield + finally: + for expr in exprs: + del self.type_overrides[expr] + + def union_overload_matches(self, callables: List[Type]) -> Union[AnyType, CallableType]: + """Accepts a list of overload signatures and attempts to combine them together into a + new CallableType consisting of the union of all of the given arguments and return types. + + If there is at least one non-callabe type, return Any (this can happen if there is + an ambiguity because of Any in arguments). + """ + assert callables, "Trying to merge no callables" + if not all(isinstance(c, CallableType) for c in callables): + return AnyType(TypeOfAny.special_form) + if len(callables) == 1: + return callables[0] + + # Note: we are assuming here that if a user uses some TypeVar 'T' in + # two different overloads, they meant for that TypeVar to mean the + # same thing. + # + # This function will make sure that all instances of that TypeVar 'T' + # refer to the same underlying TypeVarType and TypeVarDef objects to + # simplify the union-ing logic below. + # + # (If the user did *not* mean for 'T' to be consistently bound to the + # same type in their overloads, well, their code is probably too + # confusing and ought to be re-written anyways.) + callables, variables = merge_typevars_in_callables_by_name(callables) + + new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]] + new_kinds = list(callables[0].arg_kinds) + new_returns = [] # type: List[Type] + + too_complex = False + for target in callables: + # We fall back to Callable[..., Union[]] if the overloads do not have + # the exact same signature. The only exception is if one arg is optional and + # the other is positional: in that case, we continue unioning (and expect a + # positional arg). + # TODO: Enhance the merging logic to handle a wider variety of signatures. + if len(new_kinds) != len(target.arg_kinds): + too_complex = True + break + for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)): + if new_kind == target_kind: + continue + elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT): + new_kinds[i] = ARG_POS + else: + too_complex = True + break + + if too_complex: + break # outer loop + + for i, arg in enumerate(target.arg_types): + new_args[i].append(arg) + new_returns.append(target.ret_type) + + union_return = UnionType.make_simplified_union(new_returns) + if too_complex: + any = AnyType(TypeOfAny.special_form) + return callables[0].copy_modified( + arg_types=[any, any], + arg_kinds=[ARG_STAR, ARG_STAR2], + arg_names=[None, None], + ret_type=union_return, + variables=variables, + implicit=True) + + final_args = [] + for args_list in new_args: + new_type = UnionType.make_simplified_union(args_list) + final_args.append(new_type) + + return callables[0].copy_modified( + arg_types=final_args, + arg_kinds=new_kinds, + ret_type=union_return, + variables=variables, + implicit=True) + def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int], arg_names: Optional[Sequence[Optional[str]]], callee: CallableType, From 4e6a03dd0d77cfbf6f111d42077a4653f1808dbf Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 1 Jul 2018 21:11:37 +0100 Subject: [PATCH 05/12] Postpone creation of unions; introduce nesting level cutoff --- mypy/checkexpr.py | 77 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a592e816eb99..a008d79b4c82 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -62,6 +62,16 @@ ArgChecker = Callable[[Type, Type, int, Type, int, int, CallableType, Context, MessageBuilder], None] +# Maximum nesting level for math union in overloads, setting this to large values +# may cause performance issues. +MAX_UNIONS = 5 + + +class TooManyUnions(Exception): + """Indicates that we need to stop splitting unions in an attempt + to match an overload in order to save performance. + """ + def extract_refexpr_names(expr: RefExpr) -> Set[str]: """Recursively extracts all module references from a reference expression. @@ -1147,13 +1157,25 @@ def check_overload_call(self, union_success = False if any(self.real_union(arg) for arg in arg_types): unioned_errors = arg_messages.clean_copy() - unioned_result = self.union_overload_result(plausible_targets, args, arg_types, - arg_kinds, arg_names, - callable_name, object_type, - context, arg_messages=unioned_errors) - # Record if we succeeded. Next we need to see if maybe normal procedure - # gives a narrower type. - union_success = unioned_result is not None and not unioned_errors.is_errors() + union_interrupted = False + try: + unioned_return = self.union_overload_result(plausible_targets, args, + arg_types, arg_kinds, arg_names, + callable_name, object_type, + context, + arg_messages=unioned_errors) + except TooManyUnions: + union_interrupted = True + else: + # Record if we succeeded. Next we need to see if maybe normal procedure + # gives a narrower type. + union_success = unioned_return is not None and not unioned_errors.is_errors() + if unioned_return: + returns, inferred_types = zip(*unioned_return) + unioned_result = (UnionType.make_simplified_union(returns, + context.line, + context.column), + self.union_overload_matches(inferred_types)) # Step 3: We try checking each branch one-by-one. inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types, @@ -1366,17 +1388,22 @@ def union_overload_result(self, object_type: Optional[Type], context: Context, arg_messages: Optional[MessageBuilder] = None, - ) -> Optional[Tuple[Type, Type]]: + level: int = 0 + ) -> Optional[List[Tuple[Type, Type]]]: """Accepts a list of overload signatures and attempts to match calls by destructuring the first union. Returns None if there is no match. """ + if level >= MAX_UNIONS: + raise TooManyUnions if not any(self.real_union(typ) for typ in arg_types): # No unions in args, just fall back to normal inference with self.type_overrides_set(args, arg_types): res = self.infer_overload_return_type(plausible_targets, args, arg_types, arg_kinds, arg_names, callable_name, object_type, context, arg_messages) - return res + if res is not None: + return [res] + return None # Try direct match before splitting with self.type_overrides_set(args, arg_types): @@ -1385,30 +1412,35 @@ def union_overload_result(self, object_type, context, arg_messages) if direct is not None and not isinstance(direct[0], (UnionType, AnyType)): # We only return non-unions soon, to avoid greedy match. - return direct + return [direct] # Split the first remaining union type in arguments first_union = next(typ for typ in arg_types if self.real_union(typ)) idx = arg_types.index(first_union) assert isinstance(first_union, UnionType) - returns = [] - inferred_types = [] + res_items = [] for item in first_union.relevant_items(): new_arg_types = arg_types.copy() new_arg_types[idx] = item sub_result = self.union_overload_result(plausible_targets, args, new_arg_types, arg_kinds, arg_names, callable_name, - object_type, context, arg_messages) + object_type, context, arg_messages, + level + 1) if sub_result is not None: - ret, inferred = sub_result - returns.append(ret) - inferred_types.append(inferred) + res_items.append(sub_result) else: return None - if returns: - return (UnionType.make_simplified_union(returns, context.line, context.column), - self.union_overload_matches(inferred_types)) + # Flatten union results into a single list of unique items + if res_items: + seen = set() # type: Set[Tuple[Type, Type]] + result = [] + for sub_result in res_items: + for pair in sub_result: + if pair not in seen: + seen.add(pair) + result.append(pair) + return result return None def real_union(self, typ: Type) -> bool: @@ -1427,16 +1459,17 @@ def type_overrides_set(self, exprs: Iterable[Expression], for expr in exprs: del self.type_overrides[expr] - def union_overload_matches(self, callables: List[Type]) -> Union[AnyType, CallableType]: + def union_overload_matches(self, types: List[Type]) -> Union[AnyType, CallableType]: """Accepts a list of overload signatures and attempts to combine them together into a new CallableType consisting of the union of all of the given arguments and return types. If there is at least one non-callabe type, return Any (this can happen if there is an ambiguity because of Any in arguments). """ - assert callables, "Trying to merge no callables" - if not all(isinstance(c, CallableType) for c in callables): + assert types, "Trying to merge no callables" + if not all(isinstance(c, CallableType) for c in types): return AnyType(TypeOfAny.special_form) + callables = cast(List[CallableType], types) if len(callables) == 1: return callables[0] From 33e845f014e90307ba51969d18971857fb537394 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 1 Jul 2018 21:26:16 +0100 Subject: [PATCH 06/12] Add a note when we interrupt union math (with a test) --- mypy/checkexpr.py | 13 ++++++++----- test-data/unit/check-overloading.test | 23 +++++++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a008d79b4c82..3ec9f2498cd3 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1224,11 +1224,14 @@ def check_overload_call(self, if not self.chk.should_suppress_optional_error(arg_types): arg_messages.no_variant_matches_arguments(callee, arg_types, context) target = AnyType(TypeOfAny.from_error) - - return self.check_call(target, args, arg_kinds, context, arg_names, - arg_messages=arg_messages, - callable_name=callable_name, - object_type=object_type) + result = self.check_call(target, args, arg_kinds, context, arg_names, + arg_messages=arg_messages, + callable_name=callable_name, + object_type=object_type) + if union_interrupted: + self.chk.msg.note("Not all union combinations were tried" + " because there are too many unions", context) + return result def plausible_overload_call_targets(self, arg_types: List[Type], diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 86e5d2139ff4..9433f338ecc3 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -3957,3 +3957,26 @@ def none_loose_impl(x: int) -> int: [out] main:22: error: Overloaded function implementation does not accept all possible arguments of signature 1 main:22: error: Overloaded function implementation cannot produce return type of signature 1 + +[case testTooManyUnionsException] +from typing import overload, Union + +@overload +def f(*args: int) -> int: ... +@overload +def f(*args: str) -> str: ... +def f(*args): + pass + +x: Union[int, str] +f(x, x, x, x, x, x, x, x) +[out] +main:11: error: Argument 1 to "f" has incompatible type "Union[int, str]"; expected "int" +main:11: error: Argument 2 to "f" has incompatible type "Union[int, str]"; expected "int" +main:11: error: Argument 3 to "f" has incompatible type "Union[int, str]"; expected "int" +main:11: error: Argument 4 to "f" has incompatible type "Union[int, str]"; expected "int" +main:11: error: Argument 5 to "f" has incompatible type "Union[int, str]"; expected "int" +main:11: error: Argument 6 to "f" has incompatible type "Union[int, str]"; expected "int" +main:11: error: Argument 7 to "f" has incompatible type "Union[int, str]"; expected "int" +main:11: error: Argument 8 to "f" has incompatible type "Union[int, str]"; expected "int" +main:11: note: Not all union combinations were tried because there are too many unions From dd8adba2d44d724d1a3f1f72c83f95c33cabf01c Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 1 Jul 2018 22:05:19 +0100 Subject: [PATCH 07/12] Some logic simplification --- mypy/checkexpr.py | 78 ++++++++++++++++++++++++----------------------- 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3ec9f2498cd3..e8bf47d0086e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1153,8 +1153,6 @@ def check_overload_call(self, # typevar. See https://github.com/python/mypy/issues/4063 for related discussion. erased_targets = None # type: Optional[List[CallableType]] unioned_result = None # type: Optional[Tuple[Type, Type]] - unioned_errors = None # type: Optional[MessageBuilder] - union_success = False if any(self.real_union(arg) for arg in arg_types): unioned_errors = arg_messages.clean_copy() union_interrupted = False @@ -1169,7 +1167,6 @@ def check_overload_call(self, else: # Record if we succeeded. Next we need to see if maybe normal procedure # gives a narrower type. - union_success = unioned_return is not None and not unioned_errors.is_errors() if unioned_return: returns, inferred_types = zip(*unioned_return) unioned_result = (UnionType.make_simplified_union(returns, @@ -1181,19 +1178,17 @@ def check_overload_call(self, inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types, arg_kinds, arg_names, callable_name, object_type, context, arg_messages) - if inferred_result is not None: - # Success! Stop early by returning the best among normal and unioned. - if not union_success: + # If any of checks succeed, stop early. + if inferred_result is not None and unioned_result is not None: + # Both unioned and direct checks succeeded, choose the more precise type. + if (is_subtype(inferred_result[0], unioned_result[0]) and + not isinstance(inferred_result[0], AnyType)): return inferred_result - else: - assert unioned_result is not None - if (is_subtype(inferred_result[0], unioned_result[0]) and - not isinstance(inferred_result[0], AnyType)): - return inferred_result - return unioned_result - elif union_success: - assert unioned_result is not None return unioned_result + elif unioned_result is not None: + return unioned_result + elif inferred_result is not None: + return inferred_result # Step 4: Failure. At this point, we know there is no match. We fall back to trying # to find a somewhat plausible overload target using the erased types @@ -1209,13 +1204,7 @@ def check_overload_call(self, # Step 5: We try and infer a second-best alternative if possible. If not, fall back # to using 'Any'. - if unioned_result is not None: - # When possible, return the error messages generated from the union-math attempt: - # they tend to be a little nicer. - assert unioned_errors is not None - arg_messages.add_errors(unioned_errors) - return unioned_result - elif len(erased_targets) > 0: + if len(erased_targets) > 0: # Pick the first plausible erased target as the fallback # TODO: Adjust the error message here to make it clear there was no match. target = erased_targets[0] # type: Type @@ -1394,11 +1383,23 @@ def union_overload_result(self, level: int = 0 ) -> Optional[List[Tuple[Type, Type]]]: """Accepts a list of overload signatures and attempts to match calls by destructuring - the first union. Returns None if there is no match. + the first union. + + Return a list of (, ) if call succeeds for every + item of the desctructured union. Returns None if there is no match. """ + # Step 1: If we are already too deep, then stop immediately. Otherwise mypy might + # hang for long time because of a weird overload call. The caller will get + # the exception and generate an appropriate note message, if needed. if level >= MAX_UNIONS: raise TooManyUnions - if not any(self.real_union(typ) for typ in arg_types): + + # Step 2: Find position of the first union in arguments. Return the normal infered + # type if no more unions left. + for idx, typ in enumerate(arg_types): + if self.real_union(typ): + break + else: # No unions in args, just fall back to normal inference with self.type_overrides_set(args, arg_types): res = self.infer_overload_return_type(plausible_targets, args, arg_types, @@ -1408,7 +1409,8 @@ def union_overload_result(self, return [res] return None - # Try direct match before splitting + # Step 3: Try a direct match before splitting to avoid unnecessary union splits + # and save performance. with self.type_overrides_set(args, arg_types): direct = self.infer_overload_return_type(plausible_targets, args, arg_types, arg_kinds, arg_names, callable_name, @@ -1417,9 +1419,9 @@ def union_overload_result(self, # We only return non-unions soon, to avoid greedy match. return [direct] - # Split the first remaining union type in arguments - first_union = next(typ for typ in arg_types if self.real_union(typ)) - idx = arg_types.index(first_union) + # Step 4: Split the first remaining union type in arguments into items and + # try to match each item individually (recursive). + first_union = arg_types[idx] assert isinstance(first_union, UnionType) res_items = [] for item in first_union.relevant_items(): @@ -1432,19 +1434,19 @@ def union_overload_result(self, if sub_result is not None: res_items.append(sub_result) else: + # Some item doesn't match, return soon. return None - # Flatten union results into a single list of unique items - if res_items: - seen = set() # type: Set[Tuple[Type, Type]] - result = [] - for sub_result in res_items: - for pair in sub_result: - if pair not in seen: - seen.add(pair) - result.append(pair) - return result - return None + # Step 5: If spliting succeeded, then flatten union results into a single + # list of unique items. + seen = set() # type: Set[Tuple[Type, Type]] + result = [] + for sub_result in res_items: + for pair in sub_result: + if pair not in seen: + seen.add(pair) + result.append(pair) + return result def real_union(self, typ: Type) -> bool: return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1 From c93446508467333cf67e2ae67fd57b9cd86ad3cf Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 1 Jul 2018 22:40:33 +0100 Subject: [PATCH 08/12] Add a comment --- mypy/checkexpr.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e8bf47d0086e..20e6c9c88259 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -134,6 +134,10 @@ def __init__(self, self.msg = msg self.plugin = plugin self.type_context = [None] + # Temporary overrides for expression types. This is currently + # used by the union math in overloads. + # TODO: refactor this to use a pattern similar to one in + # multiassign_from_union, or maybe even combine the two? self.type_overrides = {} # type: Dict[Expression, Type] self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg) From 5f918772c6f232d9b46b75ec1036f74d511af638 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 1 Jul 2018 22:56:44 +0100 Subject: [PATCH 09/12] Minor fixes --- mypy/checkexpr.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 20e6c9c88259..f3cb2ea7ffd0 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1157,9 +1157,9 @@ def check_overload_call(self, # typevar. See https://github.com/python/mypy/issues/4063 for related discussion. erased_targets = None # type: Optional[List[CallableType]] unioned_result = None # type: Optional[Tuple[Type, Type]] + union_interrupted = False # did we try all union combinations? if any(self.real_union(arg) for arg in arg_types): unioned_errors = arg_messages.clean_copy() - union_interrupted = False try: unioned_return = self.union_overload_result(plausible_targets, args, arg_types, arg_kinds, arg_names, @@ -1173,7 +1173,7 @@ def check_overload_call(self, # gives a narrower type. if unioned_return: returns, inferred_types = zip(*unioned_return) - unioned_result = (UnionType.make_simplified_union(returns, + unioned_result = (UnionType.make_simplified_union(list(returns), context.line, context.column), self.union_overload_matches(inferred_types)) @@ -1456,8 +1456,8 @@ def real_union(self, typ: Type) -> bool: return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1 @contextmanager - def type_overrides_set(self, exprs: Iterable[Expression], - overrides: Iterable[Type]) -> Iterator[None]: + def type_overrides_set(self, exprs: Sequence[Expression], + overrides: Sequence[Type]) -> Iterator[None]: """Set _temporary_ type overrides for given expressions.""" assert len(exprs) == len(overrides) for expr, typ in zip(exprs, overrides): @@ -1468,7 +1468,7 @@ def type_overrides_set(self, exprs: Iterable[Expression], for expr in exprs: del self.type_overrides[expr] - def union_overload_matches(self, types: List[Type]) -> Union[AnyType, CallableType]: + def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, CallableType]: """Accepts a list of overload signatures and attempts to combine them together into a new CallableType consisting of the union of all of the given arguments and return types. From b9659e1b10ef9ffd2dbe221b724458b66e44c64d Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 1 Jul 2018 23:21:48 +0100 Subject: [PATCH 10/12] Fix syntax in pythoneval test --- test-data/unit/pythoneval.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 411838eb82a4..489d76b44dfc 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1294,7 +1294,7 @@ JsonBlob = Dict[str, Any] Column = Union[List[str], List[int], List[bool], List[float], List[DateTime], List[JsonBlob]] def print_custom_table() -> None: - a: Column + a = None # type: Column for row in simple_map(format_row, a, a, a, a, a, a, a, a): # 8 columns print(row) From 7d02ba8d595df73f8969d30a78cd70d99a413ae6 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 2 Jul 2018 01:09:54 +0100 Subject: [PATCH 11/12] 3.5.1 --- test-data/unit/pythoneval.test | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 489d76b44dfc..87629c789aca 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1297,5 +1297,6 @@ def print_custom_table() -> None: a = None # type: Column for row in simple_map(format_row, a, a, a, a, a, a, a, a): # 8 columns - print(row) + reveal_type(row) [out] +_testLoadsOfOverloads.py:24: error: Revealed type is 'builtins.str*' From 643c3b8942daca73e35d698801230067b9bf36e4 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 3 Jul 2018 12:35:48 +0100 Subject: [PATCH 12/12] Address CR --- mypy/checkexpr.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index f3cb2ea7ffd0..ae48fb77c39f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -63,7 +63,9 @@ None] # Maximum nesting level for math union in overloads, setting this to large values -# may cause performance issues. +# may cause performance issues. The reason is that although union math algorithm we use +# nicely captures most corner cases, its worst case complexity is exponential, +# see https://github.com/python/mypy/pull/5255#discussion_r196896335 for discussion. MAX_UNIONS = 5 @@ -1173,6 +1175,10 @@ def check_overload_call(self, # gives a narrower type. if unioned_return: returns, inferred_types = zip(*unioned_return) + # Note that we use `union_overload_matches` instead of just returning + # a union of inferred callables because for example a call + # Union[int -> int, str -> str](Union[int, str]) is invalid and + # we don't want to introduce internal inconsistencies. unioned_result = (UnionType.make_simplified_union(list(returns), context.line, context.column), @@ -1398,7 +1404,7 @@ def union_overload_result(self, if level >= MAX_UNIONS: raise TooManyUnions - # Step 2: Find position of the first union in arguments. Return the normal infered + # Step 2: Find position of the first union in arguments. Return the normal inferred # type if no more unions left. for idx, typ in enumerate(arg_types): if self.real_union(typ): @@ -1436,20 +1442,18 @@ def union_overload_result(self, object_type, context, arg_messages, level + 1) if sub_result is not None: - res_items.append(sub_result) + res_items.extend(sub_result) else: # Some item doesn't match, return soon. return None - # Step 5: If spliting succeeded, then flatten union results into a single - # list of unique items. + # Step 5: If splitting succeeded, then filter out duplicate items before returning. seen = set() # type: Set[Tuple[Type, Type]] result = [] - for sub_result in res_items: - for pair in sub_result: - if pair not in seen: - seen.add(pair) - result.append(pair) + for pair in res_items: + if pair not in seen: + seen.add(pair) + result.append(pair) return result def real_union(self, typ: Type) -> bool: @@ -1472,7 +1476,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab """Accepts a list of overload signatures and attempts to combine them together into a new CallableType consisting of the union of all of the given arguments and return types. - If there is at least one non-callabe type, return Any (this can happen if there is + If there is at least one non-callable type, return Any (this can happen if there is an ambiguity because of Any in arguments). """ assert types, "Trying to merge no callables"