From 863da89bd1f3bbb2aff7efba65d85e740d164b85 Mon Sep 17 00:00:00 2001 From: Hugues Bruant Date: Sat, 10 Dec 2022 15:19:21 -0800 Subject: [PATCH] subtypes: fast path for Union/Union subtype check Enums are exploded into Union of Literal when narrowed. Conditional branches on enum values can result in multiple distinct narrowing of the same enum which are later subject to subtype checks (most notably via `is_same_type`, when exiting frame context in the binder). Such checks would have quadratic complexity: `O(N*M)` where `N` and `M` are the number of entries in each narrowed enum variable, and led to drastic slowdown if any of the enums involved has a large number of valuees. Implemement a linear-time fast path where literals are quickly filtered, with a fallback to the slow path for more complex values. In our codebase there is one method with a chain of a dozen if statements operating on instances of an enum with a hundreds of values. Prior to the regression it was typechecked in less than 1s. After the regression it takes over 13min to typecheck. This patch fully fixes the regression for us. Fixes #13821 --- mypy/subtypes.py | 30 ++++++++++++++++++++++++++++++ mypy/types.py | 9 +++++++++ 2 files changed, 39 insertions(+) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 994c4081addd..61ba7af5147f 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -57,6 +57,7 @@ UninhabitedType, UnionType, UnpackType, + _flattened, get_proper_type, is_named_instance, ) @@ -891,6 +892,35 @@ def visit_union_type(self, left: UnionType) -> bool: if not self._is_subtype(item, self.orig_right): return False return True + + elif isinstance(self.right, UnionType): + # prune literals early to avoid nasty quadratic behavior which would otherwise arise when checking + # subtype relationships between slightly different narrowings of an Enum + # we achieve O(N+M) instead of O(N*M) + + fast_check: set[ProperType] = set() + + for item in _flattened(self.right.relevant_items()): + p_item = get_proper_type(item) + if isinstance(p_item, LiteralType): + fast_check.add(p_item) + elif isinstance(p_item, Instance): + if p_item.last_known_value is None: + fast_check.add(p_item) + else: + fast_check.add(p_item.last_known_value) + + for item in left.relevant_items(): + p_item = get_proper_type(item) + if p_item in fast_check: + continue + lit_type = mypy.typeops.simple_literal_type(p_item) + if lit_type in fast_check: + continue + if not self._is_subtype(item, self.orig_right): + return False + return True + return all(self._is_subtype(item, self.orig_right) for item in left.items) def visit_partial_type(self, left: PartialType) -> bool: diff --git a/mypy/types.py b/mypy/types.py index 86a700d52469..9dc4ac4c7596 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -3343,6 +3343,15 @@ def has_recursive_types(typ: Type) -> bool: return typ.accept(_has_recursive_type) +def _flattened(types: Iterable[Type]) -> Iterable[Type]: + for t in types: + tp = get_proper_type(t) + if isinstance(tp, UnionType): + yield from _flattened(tp.items) + else: + yield t + + def flatten_nested_unions( types: Iterable[Type], handle_type_alias_type: bool = True ) -> list[Type]: