|
57 | 57 | UninhabitedType,
|
58 | 58 | UnionType,
|
59 | 59 | UnpackType,
|
| 60 | + flatten_nested_unions, |
60 | 61 | get_proper_type,
|
61 | 62 | is_named_instance,
|
62 | 63 | )
|
@@ -877,19 +878,50 @@ def visit_overloaded(self, left: Overloaded) -> bool:
|
877 | 878 | return False
|
878 | 879 |
|
879 | 880 | def visit_union_type(self, left: UnionType) -> bool:
|
880 |
| - if isinstance(self.right, Instance): |
| 881 | + if isinstance(self.right, (UnionType, Instance)): |
| 882 | + # prune literals early to avoid nasty quadratic behavior which would otherwise arise when checking |
| 883 | + # subtype relationships between slightly different narrowings of an Enum |
| 884 | + # we achieve O(N+M) instead of O(N*M) |
| 885 | + |
| 886 | + right_lit_types: set[Instance] = set() |
| 887 | + right_lit_values: set[LiteralType] = set() |
| 888 | + |
| 889 | + if isinstance(self.right, UnionType): |
| 890 | + for item in flatten_nested_unions( |
| 891 | + self.right.relevant_items(), handle_type_alias_type=True |
| 892 | + ): |
| 893 | + p_item = get_proper_type(item) |
| 894 | + if isinstance(p_item, LiteralType): |
| 895 | + right_lit_values.add(p_item) |
| 896 | + elif isinstance(p_item, Instance): |
| 897 | + if p_item.last_known_value is None: |
| 898 | + right_lit_types.add(p_item) |
| 899 | + else: |
| 900 | + right_lit_values.add(p_item.last_known_value) |
| 901 | + elif isinstance(self.right, Instance): |
| 902 | + if self.right.last_known_value is None: |
| 903 | + right_lit_types.add(self.right) |
| 904 | + else: |
| 905 | + right_lit_values.add(self.right.last_known_value) |
| 906 | + |
881 | 907 | literal_types: set[Instance] = set()
|
882 |
| - # avoid redundant check for union of literals |
883 | 908 | for item in left.relevant_items():
|
884 | 909 | p_item = get_proper_type(item)
|
| 910 | + if p_item in right_lit_types or p_item in right_lit_values: |
| 911 | + continue |
885 | 912 | lit_type = mypy.typeops.simple_literal_type(p_item)
|
886 | 913 | if lit_type is not None:
|
887 |
| - if lit_type in literal_types: |
| 914 | + if lit_type in right_lit_types: |
888 | 915 | continue
|
889 |
| - literal_types.add(lit_type) |
890 |
| - item = lit_type |
| 916 | + if isinstance(self.right, Instance): |
| 917 | + if lit_type in literal_types: |
| 918 | + continue |
| 919 | + literal_types.add(lit_type) |
| 920 | + item = lit_type |
| 921 | + |
891 | 922 | if not self._is_subtype(item, self.orig_right):
|
892 | 923 | return False
|
| 924 | + |
893 | 925 | return True
|
894 | 926 | return all(self._is_subtype(item, self.orig_right) for item in left.items)
|
895 | 927 |
|
|
0 commit comments