diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 994c4081addd..61ba7af5147f 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -57,6 +57,7 @@ UninhabitedType, UnionType, UnpackType, + _flattened, get_proper_type, is_named_instance, ) @@ -891,6 +892,35 @@ def visit_union_type(self, left: UnionType) -> bool: if not self._is_subtype(item, self.orig_right): return False return True + + elif isinstance(self.right, UnionType): + # prune literals early to avoid nasty quadratic behavior which would otherwise arise when checking + # subtype relationships between slightly different narrowings of an Enum + # we achieve O(N+M) instead of O(N*M) + + fast_check: set[ProperType] = set() + + for item in _flattened(self.right.relevant_items()): + p_item = get_proper_type(item) + if isinstance(p_item, LiteralType): + fast_check.add(p_item) + elif isinstance(p_item, Instance): + if p_item.last_known_value is None: + fast_check.add(p_item) + else: + fast_check.add(p_item.last_known_value) + + for item in left.relevant_items(): + p_item = get_proper_type(item) + if p_item in fast_check: + continue + lit_type = mypy.typeops.simple_literal_type(p_item) + if lit_type in fast_check: + continue + if not self._is_subtype(item, self.orig_right): + return False + return True + return all(self._is_subtype(item, self.orig_right) for item in left.items) def visit_partial_type(self, left: PartialType) -> bool: diff --git a/mypy/types.py b/mypy/types.py index 86a700d52469..9dc4ac4c7596 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -3343,6 +3343,15 @@ def has_recursive_types(typ: Type) -> bool: return typ.accept(_has_recursive_type) +def _flattened(types: Iterable[Type]) -> Iterable[Type]: + for t in types: + tp = get_proper_type(t) + if isinstance(tp, UnionType): + yield from _flattened(tp.items) + else: + yield t + + def flatten_nested_unions( types: Iterable[Type], handle_type_alias_type: bool = True ) -> list[Type]: