Skip to content

Commit 3178ece

Browse files
committed
outer heuristic
ghstack-source-id: c57d54f Pull Request resolved: #159093
1 parent abb0bf4 commit 3178ece

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2439,7 +2439,7 @@ def pointwise(
24392439

24402440

24412441
def _reduction_configs(
2442-
*, size_hints: dict[str, int], inductor_meta: dict[str, Any]
2442+
*, size_hints: dict[str, int], inductor_meta: dict[str, Any], is_dynamic=False
24432443
) -> list[Config]:
24442444
reduction_hint = inductor_meta.get("reduction_hint", None)
24452445

@@ -2491,13 +2491,40 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
24912491
num_stages=num_stages,
24922492
register_intensive=register_intensive,
24932493
)
2494+
2495+
def make_outer_config():
2496+
min_x_block, max_x_block = 8, 256
2497+
load_factor = inductor_meta.get("num_load", 0)
2498+
x = size_hints["x"]
2499+
imbalance_threshold = 256
2500+
2501+
if is_dynamic:
2502+
# Dynamic shapes introduce a lot register pressure for indexing
2503+
outer_r_block = 1 if load_factor >= 3 else min(next_power_of_2(max(rnumel, 128) // 128), 16)
2504+
max_x_block = 128
2505+
x_block = 64
2506+
else:
2507+
# Try to do reduction in 1 pass
2508+
outer_r_block = min(next_power_of_2(rnumel), 128)
2509+
# xblock * rblock shouldn't exceed 1024, maximize x_block for coalesced loads
2510+
# TODO: x_block might want to be set to the lower power of 2
2511+
x_block = min(1024 // outer_r_block, max_x_block)
2512+
2513+
x_block = max(min_x_block, x_block)
2514+
2515+
# Resolve imbalance if the grid dim of x is much larger than r
2516+
while x // x_block > (rnumel // outer_r_block) * imbalance_threshold and outer_r_block > x_block and outer_r_block > 1 and x_block < max_x_block:
2517+
x_block *= 2
2518+
outer_r_block //= 2
2519+
2520+
return make_config(x_block, outer_r_block, register_intensive=register_intensive)
24942521

24952522
contiguous_config = make_config(
24962523
1,
24972524
min(rnumel, MAX_R0_BLOCK),
24982525
register_intensive=register_intensive,
24992526
)
2500-
outer_config = make_config(64, 8, register_intensive=register_intensive)
2527+
outer_config = make_outer_config()
25012528
tiny_config = make_config(
25022529
2 * (256 // rnumel) if rnumel <= 256 else 1,
25032530
min(rnumel, MAX_R0_BLOCK),
@@ -2622,7 +2649,9 @@ def reduction(
26222649

26232650
assert triton_meta is not None
26242651

2625-
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
2652+
is_dynamic = any(["ks" in k for k in triton_meta["signature"].keys()])
2653+
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta, is_dynamic=is_dynamic)
2654+
26262655
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
26272656
return cached_autotune(
26282657
size_hints,

0 commit comments

Comments
 (0)