@@ -2439,7 +2439,7 @@ def pointwise(
2439
2439
2440
2440
2441
2441
def _reduction_configs (
2442
- * , size_hints : dict [str , int ], inductor_meta : dict [str , Any ]
2442
+ * , size_hints : dict [str , int ], inductor_meta : dict [str , Any ], is_dynamic = False
2443
2443
) -> list [Config ]:
2444
2444
reduction_hint = inductor_meta .get ("reduction_hint" , None )
2445
2445
@@ -2491,13 +2491,40 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
2491
2491
num_stages = num_stages ,
2492
2492
register_intensive = register_intensive ,
2493
2493
)
2494
+
2495
+ def make_outer_config ():
2496
+ min_x_block , max_x_block = 8 , 256
2497
+ load_factor = inductor_meta .get ("num_load" , 0 )
2498
+ x = size_hints ["x" ]
2499
+ imbalance_threshold = 256
2500
+
2501
+ if is_dynamic :
2502
+ # Dynamic shapes introduce a lot register pressure for indexing
2503
+ outer_r_block = 1 if load_factor >= 3 else min (next_power_of_2 (max (rnumel , 128 ) // 128 ), 16 )
2504
+ max_x_block = 128
2505
+ x_block = 64
2506
+ else :
2507
+ # Try to do reduction in 1 pass
2508
+ outer_r_block = min (next_power_of_2 (rnumel ), 128 )
2509
+ # xblock * rblock shouldn't exceed 1024, maximize x_block for coalesced loads
2510
+ # TODO: x_block might want to be set to the lower power of 2
2511
+ x_block = min (1024 // outer_r_block , max_x_block )
2512
+
2513
+ x_block = max (min_x_block , x_block )
2514
+
2515
+ # Resolve imbalance if the grid dim of x is much larger than r
2516
+ while x // x_block > (rnumel // outer_r_block ) * imbalance_threshold and outer_r_block > x_block and outer_r_block > 1 and x_block < max_x_block :
2517
+ x_block *= 2
2518
+ outer_r_block //= 2
2519
+
2520
+ return make_config (x_block , outer_r_block , register_intensive = register_intensive )
2494
2521
2495
2522
contiguous_config = make_config (
2496
2523
1 ,
2497
2524
min (rnumel , MAX_R0_BLOCK ),
2498
2525
register_intensive = register_intensive ,
2499
2526
)
2500
- outer_config = make_config ( 64 , 8 , register_intensive = register_intensive )
2527
+ outer_config = make_outer_config ( )
2501
2528
tiny_config = make_config (
2502
2529
2 * (256 // rnumel ) if rnumel <= 256 else 1 ,
2503
2530
min (rnumel , MAX_R0_BLOCK ),
@@ -2622,7 +2649,9 @@ def reduction(
2622
2649
2623
2650
assert triton_meta is not None
2624
2651
2625
- configs = _reduction_configs (size_hints = size_hints , inductor_meta = inductor_meta )
2652
+ is_dynamic = any (["ks" in k for k in triton_meta ["signature" ].keys ()])
2653
+ configs = _reduction_configs (size_hints = size_hints , inductor_meta = inductor_meta , is_dynamic = is_dynamic )
2654
+
2626
2655
configs = _maybe_filter_configs_for_tma_restrictions (inductor_meta , configs )
2627
2656
return cached_autotune (
2628
2657
size_hints ,
0 commit comments