Skip to content

Commit 3dce3fd

Browse files
authored
Add support for narrowing Literals using equality (#8151)
This pull request (finally) adds support for narrowing expressions using Literal types by equality, instead of just identity. For example, the following "tagged union" pattern is now supported: ```python class Foo(TypedDict): key: Literal["A"] blah: int class Bar(TypedDict): key: Literal["B"] something: str x: Union[Foo, Bar] if x.key == "A": reveal_type(x) # Revealed type is 'Foo' else: reveal_type(x) # Revealed type is 'Bar' ``` Previously, this was possible to do only with Enum Literals and the `is` operator, which is perhaps not very intuitive. The main limitation with this pull request is that it'll perform narrowing only if either the LHS or RHS contains an explicit Literal type somewhere. If this limitation is not present, we end up breaking a decent amount of real-world code -- mostly tests -- that do something like this: ```python def some_test_case() -> None: worker = Worker() # Without the limitation, we narrow 'worker.state' to # Literal['ready'] in this assert... assert worker.state == 'ready' worker.start() # ...which subsequently causes this second assert to narrow # worker.state to <uninhabited>, causing the last line to be # unreachable. assert worker.state == 'running' worker.query() ``` I tried for several weeks to find a more intelligent way around this problem, but everything I tried ended up being either insufficient or super-hacky, so I gave up and went for this brute-force solution. The other main limitation is that we perform narrowing only if both the LHS and RHS do not define custom `__eq__` or `__ne__` methods, but this seems like a more reasonable one to me. Resolves #7944.
1 parent 35b5039 commit 3dce3fd

File tree

7 files changed

+755
-174
lines changed

7 files changed

+755
-174
lines changed

mypy/checker.py

+122-92
Large diffs are not rendered by default.

mypy/checkexpr.py

+3-20
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,16 @@
5050
from mypy import erasetype
5151
from mypy.checkmember import analyze_member_access, type_object_type
5252
from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
53-
from mypy.checkstrformat import StringFormatterChecker, custom_special_method
53+
from mypy.checkstrformat import StringFormatterChecker
5454
from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars
5555
from mypy.util import split_module_names
5656
from mypy.typevars import fill_typevars
5757
from mypy.visitor import ExpressionVisitor
5858
from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext
5959
from mypy.typeops import (
6060
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
61-
function_type, callable_type, try_getting_str_literals
61+
function_type, callable_type, try_getting_str_literals, custom_special_method,
62+
is_literal_type_like,
6263
)
6364
import mypy.errorcodes as codes
6465

@@ -4265,24 +4266,6 @@ def merge_typevars_in_callables_by_name(
42654266
return output, variables
42664267

42674268

4268-
def is_literal_type_like(t: Optional[Type]) -> bool:
4269-
"""Returns 'true' if the given type context is potentially either a LiteralType,
4270-
a Union of LiteralType, or something similar.
4271-
"""
4272-
t = get_proper_type(t)
4273-
if t is None:
4274-
return False
4275-
elif isinstance(t, LiteralType):
4276-
return True
4277-
elif isinstance(t, UnionType):
4278-
return any(is_literal_type_like(item) for item in t.items)
4279-
elif isinstance(t, TypeVarType):
4280-
return (is_literal_type_like(t.upper_bound)
4281-
or any(is_literal_type_like(item) for item in t.values))
4282-
else:
4283-
return False
4284-
4285-
42864269
def try_getting_literal(typ: Type) -> ProperType:
42874270
"""If possible, get a more precise literal type for a given type."""
42884271
typ = get_proper_type(typ)

mypy/checkstrformat.py

+3-32
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020
from mypy.types import (
2121
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType,
22-
CallableType, LiteralType, get_proper_types
22+
LiteralType, get_proper_types
2323
)
2424
from mypy.nodes import (
2525
StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr,
2626
IndexExpr, MemberExpr, TempNode, ARG_POS, ARG_STAR, ARG_NAMED, ARG_STAR2,
27-
SYMBOL_FUNCBASE_TYPES, Decorator, Var, Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
27+
Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
2828
)
2929
import mypy.errorcodes as codes
3030

@@ -35,7 +35,7 @@
3535
from mypy import message_registry
3636
from mypy.messages import MessageBuilder
3737
from mypy.maptype import map_instance_to_supertype
38-
from mypy.typeops import tuple_fallback
38+
from mypy.typeops import custom_special_method
3939
from mypy.subtypes import is_subtype
4040
from mypy.parse import parse
4141

@@ -961,32 +961,3 @@ def has_type_component(typ: Type, fullname: str) -> bool:
961961
elif isinstance(typ, UnionType):
962962
return any(has_type_component(t, fullname) for t in typ.relevant_items())
963963
return False
964-
965-
966-
def custom_special_method(typ: Type, name: str,
967-
check_all: bool = False) -> bool:
968-
"""Does this type have a custom special method such as __format__() or __eq__()?
969-
970-
If check_all is True ensure all items of a union have a custom method, not just some.
971-
"""
972-
typ = get_proper_type(typ)
973-
if isinstance(typ, Instance):
974-
method = typ.type.get(name)
975-
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
976-
if method.node.info:
977-
return not method.node.info.fullname.startswith('builtins.')
978-
return False
979-
if isinstance(typ, UnionType):
980-
if check_all:
981-
return all(custom_special_method(t, name, check_all) for t in typ.items)
982-
return any(custom_special_method(t, name) for t in typ.items)
983-
if isinstance(typ, TupleType):
984-
return custom_special_method(tuple_fallback(typ), name)
985-
if isinstance(typ, CallableType) and typ.is_type_obj():
986-
# Look up __method__ on the metaclass for class objects.
987-
return custom_special_method(typ.fallback, name)
988-
if isinstance(typ, AnyType):
989-
# Avoid false positives in uncertain cases.
990-
return True
991-
# TODO: support other types (see ExpressionChecker.has_member())?
992-
return False

mypy/typeops.py

+50-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from mypy.nodes import (
1919
FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS,
20-
Expression, StrExpr, Var
20+
Expression, StrExpr, Var, Decorator, SYMBOL_FUNCBASE_TYPES
2121
)
2222
from mypy.maptype import map_instance_to_supertype
2323
from mypy.expandtype import expand_type_by_instance, expand_type
@@ -564,6 +564,24 @@ def try_getting_literals_from_type(typ: Type,
564564
return literals
565565

566566

567+
def is_literal_type_like(t: Optional[Type]) -> bool:
568+
"""Returns 'true' if the given type context is potentially either a LiteralType,
569+
a Union of LiteralType, or something similar.
570+
"""
571+
t = get_proper_type(t)
572+
if t is None:
573+
return False
574+
elif isinstance(t, LiteralType):
575+
return True
576+
elif isinstance(t, UnionType):
577+
return any(is_literal_type_like(item) for item in t.items)
578+
elif isinstance(t, TypeVarType):
579+
return (is_literal_type_like(t.upper_bound)
580+
or any(is_literal_type_like(item) for item in t.values))
581+
else:
582+
return False
583+
584+
567585
def get_enum_values(typ: Instance) -> List[str]:
568586
"""Return the list of values for an Enum."""
569587
return [name for name, sym in typ.type.names.items() if isinstance(sym.node, Var)]
@@ -640,10 +658,11 @@ class Status(Enum):
640658
return typ
641659

642660

643-
def coerce_to_literal(typ: Type) -> ProperType:
661+
def coerce_to_literal(typ: Type) -> Type:
644662
"""Recursively converts any Instances that have a last_known_value or are
645663
instances of enum types with a single value into the corresponding LiteralType.
646664
"""
665+
original_type = typ
647666
typ = get_proper_type(typ)
648667
if isinstance(typ, UnionType):
649668
new_items = [coerce_to_literal(item) for item in typ.items]
@@ -655,7 +674,7 @@ def coerce_to_literal(typ: Type) -> ProperType:
655674
enum_values = get_enum_values(typ)
656675
if len(enum_values) == 1:
657676
return LiteralType(value=enum_values[0], fallback=typ)
658-
return typ
677+
return original_type
659678

660679

661680
def get_type_vars(tp: Type) -> List[TypeVarType]:
@@ -674,3 +693,31 @@ def _merge(self, iter: Iterable[List[TypeVarType]]) -> List[TypeVarType]:
674693

675694
def visit_type_var(self, t: TypeVarType) -> List[TypeVarType]:
676695
return [t]
696+
697+
698+
def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool:
699+
"""Does this type have a custom special method such as __format__() or __eq__()?
700+
701+
If check_all is True ensure all items of a union have a custom method, not just some.
702+
"""
703+
typ = get_proper_type(typ)
704+
if isinstance(typ, Instance):
705+
method = typ.type.get(name)
706+
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
707+
if method.node.info:
708+
return not method.node.info.fullname.startswith('builtins.')
709+
return False
710+
if isinstance(typ, UnionType):
711+
if check_all:
712+
return all(custom_special_method(t, name, check_all) for t in typ.items)
713+
return any(custom_special_method(t, name) for t in typ.items)
714+
if isinstance(typ, TupleType):
715+
return custom_special_method(tuple_fallback(typ), name, check_all)
716+
if isinstance(typ, CallableType) and typ.is_type_obj():
717+
# Look up __method__ on the metaclass for class objects.
718+
return custom_special_method(typ.fallback, name, check_all)
719+
if isinstance(typ, AnyType):
720+
# Avoid false positives in uncertain cases.
721+
return True
722+
# TODO: support other types (see ExpressionChecker.has_member())?
723+
return False

test-data/unit/check-enum.test

+42-27
Original file line numberDiff line numberDiff line change
@@ -978,32 +978,43 @@ class Foo(Enum):
978978
x: Foo
979979
y: Foo
980980

981+
# We can't narrow anything in the else cases -- what if
982+
# x is Foo.A and y is Foo.B or vice versa, for example?
981983
if x is y is Foo.A:
982-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
983-
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
984+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
985+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
986+
elif x is y is Foo.B:
987+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
988+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
984989
else:
985-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
986-
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
987-
reveal_type(x) # N: Revealed type is '__main__.Foo'
988-
reveal_type(y) # N: Revealed type is '__main__.Foo'
990+
reveal_type(x) # N: Revealed type is '__main__.Foo'
991+
reveal_type(y) # N: Revealed type is '__main__.Foo'
992+
reveal_type(x) # N: Revealed type is '__main__.Foo'
993+
reveal_type(y) # N: Revealed type is '__main__.Foo'
989994

990995
if x is Foo.A is y:
991-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
992-
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
996+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
997+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
998+
elif x is Foo.B is y:
999+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
1000+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
9931001
else:
994-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
995-
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
996-
reveal_type(x) # N: Revealed type is '__main__.Foo'
997-
reveal_type(y) # N: Revealed type is '__main__.Foo'
1002+
reveal_type(x) # N: Revealed type is '__main__.Foo'
1003+
reveal_type(y) # N: Revealed type is '__main__.Foo'
1004+
reveal_type(x) # N: Revealed type is '__main__.Foo'
1005+
reveal_type(y) # N: Revealed type is '__main__.Foo'
9981006

9991007
if Foo.A is x is y:
1000-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
1001-
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
1008+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
1009+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
1010+
elif Foo.B is x is y:
1011+
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
1012+
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
10021013
else:
1003-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
1004-
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
1005-
reveal_type(x) # N: Revealed type is '__main__.Foo'
1006-
reveal_type(y) # N: Revealed type is '__main__.Foo'
1014+
reveal_type(x) # N: Revealed type is '__main__.Foo'
1015+
reveal_type(y) # N: Revealed type is '__main__.Foo'
1016+
reveal_type(x) # N: Revealed type is '__main__.Foo'
1017+
reveal_type(y) # N: Revealed type is '__main__.Foo'
10071018

10081019
[builtins fixtures/primitives.pyi]
10091020

@@ -1026,8 +1037,10 @@ if x is Foo.A < y is Foo.B:
10261037
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
10271038
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]'
10281039
else:
1029-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]'
1030-
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
1040+
# Note: we can't narrow in this case. What if both x and y
1041+
# are Foo.A, for example?
1042+
reveal_type(x) # N: Revealed type is '__main__.Foo'
1043+
reveal_type(y) # N: Revealed type is '__main__.Foo'
10311044
reveal_type(x) # N: Revealed type is '__main__.Foo'
10321045
reveal_type(y) # N: Revealed type is '__main__.Foo'
10331046

@@ -1109,11 +1122,13 @@ if x0 is x1 is Foo.A is x2 < x3 is Foo.B is x4 is x5:
11091122
reveal_type(x4) # N: Revealed type is 'Literal[__main__.Foo.B]'
11101123
reveal_type(x5) # N: Revealed type is 'Literal[__main__.Foo.B]'
11111124
else:
1112-
reveal_type(x0) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
1113-
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
1114-
reveal_type(x2) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
1115-
1116-
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]'
1117-
reveal_type(x4) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]'
1118-
reveal_type(x5) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]'
1125+
# We unfortunately can't narrow away anything. For example,
1126+
# what if x0 == Foo.A and x1 == Foo.B or vice versa?
1127+
reveal_type(x0) # N: Revealed type is '__main__.Foo'
1128+
reveal_type(x1) # N: Revealed type is '__main__.Foo'
1129+
reveal_type(x2) # N: Revealed type is '__main__.Foo'
1130+
1131+
reveal_type(x3) # N: Revealed type is '__main__.Foo'
1132+
reveal_type(x4) # N: Revealed type is '__main__.Foo'
1133+
reveal_type(x5) # N: Revealed type is '__main__.Foo'
11191134
[builtins fixtures/primitives.pyi]

0 commit comments

Comments
 (0)