Skip to content

Commit 48733bb

Browse files
committed
outer heuristic
ghstack-source-id: e1d6fa6 Pull Request resolved: #159093
1 parent 2259dbe commit 48733bb

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2454,7 +2454,7 @@ def pointwise(
24542454

24552455

24562456
def _reduction_configs(
2457-
*, size_hints: dict[str, int], inductor_meta: dict[str, Any]
2457+
*, size_hints: dict[str, int], inductor_meta: dict[str, Any], is_dynamic=False
24582458
) -> list[Config]:
24592459
reduction_hint = inductor_meta.get("reduction_hint", None)
24602460

@@ -2507,12 +2507,40 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
25072507
register_intensive=register_intensive,
25082508
)
25092509

2510+
def make_outer_config():
# Build the config used for outer reductions: choose an x block from
# size_hints["x"] and an r block from rnumel, then delegate to make_config.
# NOTE(review): indentation is not preserved in this rendering, so the
# nesting of the branches below cannot be read off this view.
2511+
max_x_block = 256
2512+
load_factor = inductor_meta.get("num_load", 0)
2513+
x = size_hints["x"]
2514+
2515+
# Pick the x block from the x size hint; large x gets a power-of-two block
# capped at max_x_block.
if x <= 8 * 4096:
2516+
x_block = 8
2517+
else:
2518+
x_block = min(max_x_block, next_power_of_2(x // 4096))
2519+
# NOTE(review): if this check sits at function level rather than inside the
# `else`, the `x_block = 8` assignment above is always overwritten -- confirm.
if x_block < 64:
2520+
x_block = 64
2521+
if is_dynamic:
2522+
# Dynamic shapes introduce a lot of register pressure for indexing,
# so the r block is 1 when load_factor >= 3 and is capped at 16 otherwise.
2523+
outer_r_block = (
2524+
1
2525+
if load_factor >= 3
2526+
else min(next_power_of_2(max(rnumel, 128) // 128), 16)
2527+
)
2528+
else:
2529+
# Try to do reduction in 1 pass (r block capped at 128)
2530+
outer_r_block = min(next_power_of_2(rnumel), 128)
2531+
2532+
# Shrink the x block when the tile (x_block * outer_r_block) exceeds 4096.
# NOTE(review): the check uses 4096 but the new x_block is derived from
# 2048 -- confirm the differing constants are intentional.
if x_block * outer_r_block > 4096:
2533+
x_block = 2048 // outer_r_block
2534+
2535+
# Set register_intensive to True by default, as this heuristic tries to maximize tile sizes
2536+
return make_config(x_block, outer_r_block, register_intensive=True)
2537+
25102538
contiguous_config = make_config(
25112539
1,
25122540
min(rnumel, MAX_R0_BLOCK),
25132541
register_intensive=register_intensive,
25142542
)
2515-
outer_config = make_config(64, 8, register_intensive=register_intensive)
2543+
outer_config = make_outer_config()
25162544
tiny_config = make_config(
25172545
2 * (256 // rnumel) if rnumel <= 256 else 1,
25182546
min(rnumel, MAX_R0_BLOCK),
@@ -2637,7 +2665,11 @@ def reduction(
26372665

26382666
assert triton_meta is not None
26392667

2640-
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
2668+
is_dynamic = any("ks" in k for k in triton_meta["signature"].keys())
2669+
configs = _reduction_configs(
2670+
size_hints=size_hints, inductor_meta=inductor_meta, is_dynamic=is_dynamic
2671+
)
2672+
26412673
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
26422674
return cached_autotune(
26432675
size_hints,

0 commit comments

Comments
 (0)