Skip to content

Commit 2d58dd4

Browse files
committed
WIP add support for dynamic shapes
ghstack-source-id: 6e66931 Pull Request resolved: #155557
1 parent 908c5cc commit 2d58dd4

File tree

5 files changed

+49
-39
lines changed

5 files changed

+49
-39
lines changed

test/inductor/test_loop_ordering.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,9 @@ def forward(permute):
896896
arg0_1 = torch.randn([XDIM, YDIM], device=GPU_TYPE, dtype=torch.bfloat16)
897897
permute = torch.ops.aten.permute.default(arg0_1, [1, 0])
898898

899-
out, code = run_and_get_code(torch.compile(forward), (permute))
899+
out, code = run_and_get_code(
900+
torch.compile(forward, dynamic=True), (permute)
901+
)
900902

901903
self.assertEqual(out, forward(permute))
902904
FileCheck().check("YBLOCK").check("XBLOCK").run(code[0])
@@ -937,12 +939,13 @@ def T(self, layout: str):
937939

938940
@parametrize("a", layouts)
939941
@parametrize("b", layouts)
940-
def test_pointwise(self, a, b):
942+
@parametrize("dynamic", (False, True))
943+
def test_pointwise(self, a, b, dynamic):
941944
def foo(x, y):
942945
return x + y
943946

944947
x, y = self.T(a), self.T(b)
945-
res, code = run_and_get_code(torch.compile(foo), x, y)
948+
res, code = run_and_get_code(torch.compile(foo, dynamic=dynamic), x, y)
946949

947950
if a != b:
948951
FileCheck().check("ynumel").run(code[0])
@@ -968,13 +971,14 @@ def f(a, b):
968971
).run(code[0])
969972
self.assertEqual(out, f(*inps), atol=0.001, rtol=0.04)
970973

971-
def test_3d_pointwise(self):
974+
@parametrize("dynamic", (False, True))
975+
def test_3d_pointwise(self, dynamic):
972976
inps = (self.T("cont"), self.T("T"), self.T("NHWC"))
973977

974978
def f(x, y, z):
975979
return x + y + z
976980

977-
f_c = torch.compile(f)
981+
f_c = torch.compile(f, dynamic=dynamic)
978982
out, code = run_and_get_code(f_c, *inps)
979983

980984
FileCheck().check_dag("znumel").check_dag("ynumel").check_dag("xnumel").run(

torch/_inductor/codegen/simd.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2179,12 +2179,15 @@ def compute_tiling_strategy(
21792179
pw_ranges = [ranges[v] for v in all_iter_vars]
21802180
red_ranges = [ranges[v] for v in all_red_vars]
21812181

2182+
def check_eq(a, b):
2183+
return V.graph.sizevars.atomically_apply_size_hint(a - b, fallback=32) == 0
2184+
21822185
torch._check(
2183-
sympy_product(pw_ranges) == pointwise_numel,
2186+
check_eq(sympy_product(pw_ranges), pointwise_numel),
21842187
lambda: f"{pw_ranges}, {pointwise_numel}, {node_schedule}",
21852188
)
21862189
torch._check(
2187-
sympy_product(red_ranges) == reduction_numel,
2190+
check_eq(sympy_product(red_ranges), reduction_numel),
21882191
lambda: f"{red_ranges}, {reduction_numel}, {node_schedule}",
21892192
)
21902193

@@ -2331,7 +2334,7 @@ def process_node_vars(
23312334
def score_mod(t):
23322335
score_factor = 1.0
23332336
for tile_size in t[0].tiling.values():
2334-
if not CandidateTiling.is_good_size(tile_size):
2337+
if not CandidateTiling.is_good_size(tile_size, size_hint=False):
23352338
score_factor = score_factor / bad_size_additional_tiling_penalty
23362339
else:
23372340
score_factor = score_factor / good_size_tiling_penalty
@@ -2588,8 +2591,12 @@ class CandidateTiling:
25882591
name: Optional[str] = None
25892592

25902593
@staticmethod
2591-
def is_good_size(s):
2594+
def is_good_size(s, size_hint=True):
25922595
"""Somewhat arbitrary heuristic used to boost scores for some sizes"""
2596+
sv = V.graph.sizevars
2597+
if not size_hint:
2598+
return sv.statically_known_multiple_of(s, 32) and sv.statically_known_geq(s, 32)
2599+
25932600
s = V.graph.sizevars.size_hint(s)
25942601
return s >= 32 and (s % 32 == 0)
25952602

torch/_inductor/sizevars.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,13 +380,6 @@ def statically_known_multiple_of(
380380
"""
381381
Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
382382
"""
383-
# The reason we skip compute here is to avoid the cost of trying to eval this symbolically.
384-
# see https://github.com/sympy/sympy/issues/28200
385-
if has_free_unbacked_symbols(numerator) or has_free_unbacked_symbols(
386-
denominator
387-
):
388-
return False
389-
390383
if len(free_symbols(numerator)) > 20:
391384
return False
392385

torch/_inductor/tiling_utils.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,17 @@ def get_pw_red_splits(
221221
red_numel: sympy.Expr,
222222
none_if_not_divisible: bool = False,
223223
) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]:
224-
if n.is_reduction() or sympy_product(n._body.sizes[0]) == pointwise_numel:
224+
n_pointwise_numel = V.graph.sizevars.simplify(sympy_product(n._body.sizes[0]))
225+
if n.is_reduction() or n_pointwise_numel == pointwise_numel:
225226
return (
226227
(n._body.iter_vars, n._body.sizes[0]),
227228
(n._body.reduce_vars, n._body.sizes[1]),
228229
) # type: ignore[return-value]
229230

230-
assert sympy_product(n._body.sizes[0]) == pointwise_numel * red_numel # type: ignore[operator]
231+
assert V.graph.sizevars.atomically_apply_size_hint(
232+
n_pointwise_numel - (pointwise_numel * red_numel), fallback=config.unbacked_symint_fallback
233+
) == 0
234+
231235
i = len(n._body.sizes[0]) - 1
232236
prod = 1
233237
while i >= 0:
@@ -319,6 +323,7 @@ def get_node_splits(self) -> tuple[Split, Split]:
319323

320324
if len(self.all_node_sizes) == 1:
321325
return next(iter(self.all_node_sizes))
326+
# TODO - return default pointwise, reduction
322327

323328
max_pw_split = max(self.pw_split_options.keys())
324329
for pw_split_len in range(max_pw_split, 0, -1):
@@ -478,13 +483,6 @@ def extract_normalized_read_writes(
478483
pointwise_numel: sympy.Expr = node.group[1][0]
479484
red_numel: sympy.Expr = node.group[1][1]
480485

481-
# TODO - a few dynamic shapes issues to resolve
482-
if any(
483-
(isinstance(var, sympy.Expr) and not var.is_constant())
484-
for var in (pointwise_numel, red_numel)
485-
):
486-
return None
487-
488486
pw_splits, red_splits = NodeSplitGetter(node).get_node_splits()
489487

490488
# lets use different prefix (`n`) to distinguish
@@ -663,13 +661,8 @@ def analyze_memory_coalescing(
663661
((True, item) for item in reads.items()),
664662
((False, item) for item in writes.items()),
665663
):
666-
# skip memory deps with indirect vars - todo: better handling
667-
indirect_expr = bool(
668-
memory_expr.free_symbols - norm_read_writes.var_ranges.keys()
669-
)
670-
671-
if indirect_expr:
672-
continue
664+
# TODO skip memory deps with indirect vars
665+
# handled in extract_normalized_read_writes currently
673666

674667
size = get_score(memory_expr, var_ranges)
675668
if size == 0:
@@ -699,8 +692,8 @@ def analyze_memory_coalescing(
699692
tiling_scores: dict[sympy.Expr, dict[int, int]] = defaultdict(Counter)
700693

701694
for uncoalesced_expr, addr_score in uncoalesced_addrs.items():
702-
expr_subs = dict.fromkeys(uncoalesced_expr.free_symbols, 0)
703-
for v in uncoalesced_expr.free_symbols:
695+
expr_subs = dict.fromkeys(var_ranges.keys(), 0)
696+
for v in uncoalesced_expr.free_symbols & var_ranges.keys():
704697
# skip non iter/reduce var variables
705698
if v not in var_ranges:
706699
continue
@@ -710,7 +703,13 @@ def analyze_memory_coalescing(
710703
del expr_subs[v]
711704
single_var_expr = sympy_subs(uncoalesced_expr, expr_subs)
712705
expr_subs[v] = 0
706+
707+
# TODO: skip dynamic shapes for now,
708+
if len(single_var_expr.free_symbols) != 1:
709+
continue
710+
713711
tiling_factor = solve_for_tiling(single_var_expr)
712+
714713
if (
715714
tiling_factor is None
716715
or not tiling_factor.is_constant()

torch/utils/_sympy/functions.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,9 @@ def eval(
237237
return base
238238
if base.is_integer and equal_valued(divisor, -1):
239239
return sympy.Mul(base, -1)
240+
if base is divisor:
241+
return sympy.S.One
242+
240243
if (
241244
isinstance(base, sympy.Number)
242245
and isinstance(divisor, sympy.Number)
@@ -324,11 +327,15 @@ def eval(
324327
if divisor != 1:
325328
gcd = sympy.gcd(base, divisor)
326329
if gcd != 1:
327-
return ModularIndexing(
328-
sympy.simplify(base / gcd),
329-
sympy.simplify(divisor / gcd),
330-
modulus,
331-
)
330+
try:
331+
return ModularIndexing(
332+
sympy.simplify(base / gcd),
333+
sympy.simplify(divisor / gcd),
334+
modulus,
335+
)
336+
except Exception:
337+
breakpoint()
338+
raise
332339
except sympy.PolynomialError:
333340
pass # https://github.com/pytorch/pytorch/issues/108276
334341

0 commit comments

Comments
 (0)