Skip to content

Commit 50eac81

Browse files
benjaminglass1pytorchmergebot
authored andcommitted
[typing] Constrain OrderedSet generic to be Hashable (#159684)
Ran across this typing bug while creating an OrderedSet from a type I didn't realize wasn't hashable, which failed at runtime. With this constraint, typing would've failed pre-runtime. Pull Request resolved: #159684 Approved by: https://github.com/Skylion007
1 parent 4e0f179 commit 50eac81

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

torch/_inductor/codegen/simd.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,7 +1991,7 @@ def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr:
19911991
@classmethod
19921992
def create_tiling(
19931993
cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr]
1994-
) -> dict[str, sympy.Expr]:
1994+
) -> immutable_dict[str, sympy.Expr]:
19951995
"""
19961996
Create a tiling dict from pointwise and reduction splits.
19971997
"""
@@ -2006,7 +2006,7 @@ def create_partial_tiling(
20062006
cls,
20072007
tiling: Sequence[sympy.Expr],
20082008
is_pointwise: bool,
2009-
) -> dict[str, sympy.Expr]:
2009+
) -> immutable_dict[str, sympy.Expr]:
20102010
return cls.create_tiling(
20112011
tiling if is_pointwise else [],
20122012
tiling if not is_pointwise else [],
@@ -2018,7 +2018,7 @@ def complete_partial_tiling(
20182018
tiling: dict[str, sympy.Expr],
20192019
numel: sympy.Expr,
20202020
reduction_numel: sympy.Expr,
2021-
) -> dict[str, sympy.Expr]:
2021+
) -> immutable_dict[str, sympy.Expr]:
20222022
"""
20232023
Given a tiling for only pointwise or reduction dimensions, adds the missing one.
20242024
"""
@@ -2039,15 +2039,15 @@ def get_nd_tilings(
20392039
node_schedule,
20402040
pointwise_numel,
20412041
reduction_numel,
2042-
) -> list[dict[str, tuple[sympy.Expr]]]:
2042+
) -> list[immutable_dict[str, sympy.Expr]]:
20432043
"""
20442044
Creates N-dimensional tiling candidates, attempting to simplify loads/stores
20452045
by tiling the kernel into higher dimensions.
20462046
20472047
Returns a list of tilings ranked by dimensionality.
20482048
"""
20492049
is_pointwise = reduction_numel == 1
2050-
tilings = OrderedSet[dict[str, sympy.Expr]]()
2050+
tilings = OrderedSet[immutable_dict[str, sympy.Expr]]()
20512051
for node in EnableReduction.filter(node_schedule):
20522052
if not isinstance(node, scheduler.SchedulerNode):
20532053
continue
@@ -2312,7 +2312,7 @@ def process_node_vars(
23122312
)
23132313
)
23142314

2315-
tilings: list[tuple[CandidateTiling, dict[str, sympy.Expr]]] = []
2315+
tilings: list[tuple[CandidateTiling, immutable_dict[str, sympy.Expr]]] = []
23162316
for (pw_split, pw_score), (red_split, red_score) in score_split:
23172317
candidate = CandidateTiling(
23182318
cls.create_tiling(pw_split, red_split),

torch/utils/_ordered_set.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import (
4+
Hashable,
45
Iterable,
56
Iterator,
67
MutableSet,
@@ -10,8 +11,8 @@
1011
from typing import Any, cast, Optional, TypeVar
1112

1213

13-
T = TypeVar("T")
14-
T_co = TypeVar("T_co", covariant=True)
14+
T = TypeVar("T", bound=Hashable)
15+
T_co = TypeVar("T_co", bound=Hashable, covariant=True)
1516

1617
__all__ = ["OrderedSet"]
1718

0 commit comments

Comments
 (0)