Skip to content

Adds support for basic union math with overloads #4842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,8 +1801,8 @@ def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: E
expr = expr.expr
types, declared_types = zip(*items)
self.binder.assign_type(expr,
UnionType.make_simplified_union(types),
UnionType.make_simplified_union(declared_types),
UnionType.make_simplified_union(list(types)),
UnionType.make_simplified_union(list(declared_types)),
False)
for union, lv in zip(union_types, self.flatten_lvalues(lvalues)):
# Properly store the inferred types.
Expand Down
226 changes: 168 additions & 58 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,64 @@ def check_call(self, callee: Type, args: List[Expression],
arg_types = self.infer_arg_types_in_context(None, args)
self.msg.enable_errors()

target = self.overload_call_target(arg_types, arg_kinds, arg_names,
callee, context,
messages=arg_messages)
return self.check_call(target, args, arg_kinds, context, arg_names,
overload_messages = arg_messages.copy()
targets = self.overload_call_targets(arg_types, arg_kinds, arg_names,
callee, context,
messages=overload_messages)

# If there are multiple targets, that means that there were
# either multiple possible matches or the types were overlapping in some
# way. In either case, we default to picking the first match and
# see what happens if we try using it.
#
# Note: if we pass in an argument that inherits from two overloaded
# types, we default to picking the first match. For example:
#
# class A: pass
# class B: pass
# class C(A, B): pass
#
# @overload
# def f(x: A) -> int: ...
# @overload
# def f(x: B) -> str: ...
# def f(x): ...
#
# reveal_type(f(C())) # Will be 'int', not 'Union[int, str]'
#
# It's unclear if this is really the best thing to do, but multiple
# inheritance is rare. See the docstring of mypy.meet.is_overlapping_types
# for more about this.

original_output = self.check_call(targets[0], args, arg_kinds, context, arg_names,
arg_messages=overload_messages,
callable_name=callable_name,
object_type=object_type)

if not overload_messages.is_errors() or len(targets) == 1:
# If there were no errors or if there was only one match, we can end now.
#
# Note that if we have only one target, there's nothing else we
# can try doing. In that case, we just give up and return early
# and skip the below steps.
arg_messages.add_errors(overload_messages)
return original_output

# Otherwise, we attempt to synthesize together a new callable by combining
# together the different matches by union-ing together their arguments
# and return type.

targets = cast(List[CallableType], targets)
unioned_callable = self.union_overload_matches(
targets, args, arg_kinds, arg_names, context)
if unioned_callable is None:
# If it was not possible to actually combine together the
# callables in a sound way, we give up and return the original
# error message.
arg_messages.add_errors(overload_messages)
return original_output

return self.check_call(unioned_callable, args, arg_kinds, context, arg_names,
arg_messages=arg_messages,
callable_name=callable_name,
object_type=object_type)
Expand Down Expand Up @@ -1089,83 +1143,139 @@ def check_arg(self, caller_type: Type, original_caller_type: Type,
(callee_type.item.type.is_abstract or callee_type.item.type.is_protocol) and
# ...except for classmethod first argument
not caller_type.is_classmethod_class):
self.msg.concrete_only_call(callee_type, context)
messages.concrete_only_call(callee_type, context)
elif not is_subtype(caller_type, callee_type):
if self.chk.should_suppress_optional_error([caller_type, callee_type]):
return
messages.incompatible_argument(n, m, callee, original_caller_type,
caller_kind, context)
if (isinstance(original_caller_type, (Instance, TupleType, TypedDictType)) and
isinstance(callee_type, Instance) and callee_type.type.is_protocol):
self.msg.report_protocol_problems(original_caller_type, callee_type, context)
messages.report_protocol_problems(original_caller_type, callee_type, context)
if (isinstance(callee_type, CallableType) and
isinstance(original_caller_type, Instance)):
call = find_member('__call__', original_caller_type, original_caller_type)
if call:
self.msg.note_call(original_caller_type, call, context)

def overload_call_target(self, arg_types: List[Type], arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
overload: Overloaded, context: Context,
messages: Optional[MessageBuilder] = None) -> Type:
"""Infer the correct overload item to call with given argument types.

The return value may be CallableType or AnyType (if an unique item
could not be determined).
messages.note_call(original_caller_type, call, context)

def overload_call_targets(self, arg_types: List[Type], arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
overload: Overloaded, context: Context,
messages: Optional[MessageBuilder] = None) -> Sequence[Type]:
"""Infer all possible overload targets to call with given argument types.
The list is guaranteed be one of the following:

1. A List[CallableType] of length 1 if we were able to find an
unambiguous best match.
2. A List[AnyType] of length 1 if we were unable to find any match
or discovered the match was ambiguous due to conflicting Any types.
3. A List[CallableType] of length 2 or more if there were multiple
plausible matches. The matches are returned in the order they
were defined.
"""
messages = messages or self.msg
# TODO: For overlapping signatures we should try to get a more precise
# result than 'Any'.
match = [] # type: List[CallableType]
best_match = 0
for typ in overload.items():
similarity = self.erased_signature_similarity(arg_types, arg_kinds, arg_names,
typ, context=context)
if similarity > 0 and similarity >= best_match:
if (match and not is_same_type(match[-1].ret_type,
typ.ret_type) and
(not mypy.checker.is_more_precise_signature(match[-1], typ)
or (any(isinstance(arg, AnyType) for arg in arg_types)
and any_arg_causes_overload_ambiguity(
match + [typ], arg_types, arg_kinds, arg_names)))):
# Ambiguous return type. Either the function overload is
# overlapping (which we don't handle very well here) or the
# caller has provided some Any argument types; in either
# case we'll fall back to Any. It's okay to use Any types
# in calls.
#
# Overlapping overload items are generally fine if the
# overlapping is only possible when there is multiple
# inheritance, as this is rare. See docstring of
# mypy.meet.is_overlapping_types for more about this.
#
# Note that there is no ambiguity if the items are
# covariant in both argument types and return types with
# respect to type precision. We'll pick the best/closest
# match.
#
# TODO: Consider returning a union type instead if the
# overlapping is NOT due to Any types?
return AnyType(TypeOfAny.special_form)
else:
match.append(typ)
if (match and not is_same_type(match[-1].ret_type, typ.ret_type)
and any(isinstance(arg, AnyType) for arg in arg_types)
and any_arg_causes_overload_ambiguity(
match + [typ], arg_types, arg_kinds, arg_names)):
# Ambiguous return type. The caller has provided some
# Any argument types (which are okay to use in calls),
# so we fall back to returning 'Any'.
return [AnyType(TypeOfAny.special_form)]
match.append(typ)
best_match = max(best_match, similarity)
if not match:

if len(match) == 0:
if not self.chk.should_suppress_optional_error(arg_types):
messages.no_variant_matches_arguments(overload, arg_types, context)
return AnyType(TypeOfAny.from_error)
return [AnyType(TypeOfAny.from_error)]
elif len(match) == 1:
return match
else:
if len(match) == 1:
return match[0]
else:
# More than one signature matches. Pick the first *non-erased*
# matching signature, or default to the first one if none
# match.
for m in match:
if self.match_signature_types(arg_types, arg_kinds, arg_names, m,
context=context):
return m
return match[0]
# More than one signature matches or the signatures are
# overlapping. In either case, we return all of the matching
# signatures and let the caller decide what to do with them.
out = [m for m in match if self.match_signature_types(
arg_types, arg_kinds, arg_names, m, context=context)]
return out if len(out) >= 1 else match

def union_overload_matches(self, callables: List[CallableType],
args: List[Expression],
arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
context: Context) -> Optional[CallableType]:
"""Accepts a list of overload signatures and attempts to combine them together into a
new CallableType consisting of the union of all of the given arguments and return types.

Returns None if it is not possible to combine the different callables together in a
sound manner."""
new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]]
new_returns = [] # type: List[Type]

expected_names = callables[0].arg_names
expected_kinds = callables[0].arg_kinds

for target in callables:
if target.arg_names != expected_names or target.arg_kinds != expected_kinds:
# We conservatively end if the overloads do not have the exact same signature.
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
return None

if target.is_generic():
formal_to_actual = map_actuals_to_formals(
arg_kinds, arg_names,
target.arg_kinds, target.arg_names,
lambda i: self.accept(args[i]))

target = freshen_function_type_vars(target)
target = self.infer_function_type_arguments_using_context(target, context)
target = self.infer_function_type_arguments(
target, args, arg_kinds, formal_to_actual, context)

for i, arg in enumerate(target.arg_types):
new_args[i].append(arg)
new_returns.append(target.ret_type)

union_count = 0
final_args = []
for args_list in new_args:
new_type = UnionType.make_simplified_union(args_list)
union_count += 1 if isinstance(new_type, UnionType) else 0
final_args.append(new_type)

# TODO: Modify this check to be less conservative.
#
# Currently, we permit only one union union in the arguments because if we allow
# multiple, we can't always guarantee the synthesized callable will be correct.
#
# For example, suppose we had the following two overloads:
#
# @overload
# def f(x: A, y: B) -> None: ...
# @overload
# def f(x: B, y: A) -> None: ...
#
# If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...",
# then we'd incorrectly accept calls like "f(A(), A())" when they really ought to
# be rejected.
#
# However, that means we'll also give up if the original overloads contained
# any unions. This is likely unnecessary -- we only really need to give up if
# there are more then one *synthesized* union arguments.
if union_count >= 2:
return None

return callables[0].copy_modified(
arg_types=final_args,
ret_type=UnionType.make_simplified_union(new_returns),
implicit=True,
from_overloads=True)

def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
Expand Down
14 changes: 12 additions & 2 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,19 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type:
expected_type = callee.arg_types[m - 1]
except IndexError: # Varargs callees
expected_type = callee.arg_types[-1]

arg_type_str, expected_type_str = self.format_distinctly(
arg_type, expected_type, bare=True)
expected_type_str = self.quote_type_string(expected_type_str)

if callee.from_overloads and isinstance(expected_type, UnionType):
expected_formatted = []
for e in expected_type.items:
type_str = self.format_distinctly(arg_type, e, bare=True)[1]
expected_formatted.append(self.quote_type_string(type_str))
expected_type_str = 'one of {} based on available overloads'.format(
', '.join(expected_formatted))

if arg_kind == ARG_STAR:
arg_type_str = '*' + arg_type_str
elif arg_kind == ARG_STAR2:
Expand All @@ -645,8 +656,7 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type:
arg_label = '"{}"'.format(arg_name)

msg = 'Argument {} {}has incompatible type {}; expected {}'.format(
arg_label, target, self.quote_type_string(arg_type_str),
self.quote_type_string(expected_type_str))
arg_label, target, self.quote_type_string(arg_type_str), expected_type_str)
if isinstance(arg_type, Instance) and isinstance(expected_type, Instance):
notes = append_invariance_notes(notes, arg_type, expected_type)
self.fail(msg, context)
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2204,7 +2204,7 @@ def process_module_assignment(self, lvals: List[Lvalue], rval: Expression,
# about the length mismatch in type-checking.
elementwise_assignments = zip(rval.items, *[v.items for v in seq_lvals])
for rv, *lvs in elementwise_assignments:
self.process_module_assignment(lvs, rv, ctx)
self.process_module_assignment(list(lvs), rv, ctx)
elif isinstance(rval, RefExpr):
rnode = self.lookup_type_node(rval)
if rnode and rnode.kind == MODULE_REF:
Expand Down
Loading