@@ -2454,7 +2454,7 @@ def pointwise(
2454
2454
2455
2455
2456
2456
def _reduction_configs (
2457
- * , size_hints : dict [str , int ], inductor_meta : dict [str , Any ]
2457
+ * , size_hints : dict [str , int ], inductor_meta : dict [str , Any ], is_dynamic = False
2458
2458
) -> list [Config ]:
2459
2459
reduction_hint = inductor_meta .get ("reduction_hint" , None )
2460
2460
@@ -2507,12 +2507,40 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
2507
2507
register_intensive = register_intensive ,
2508
2508
)
2509
2509
2510
+ def make_outer_config ():
2511
+ max_x_block = 256
2512
+ load_factor = inductor_meta .get ("num_load" , 0 )
2513
+ x = size_hints ["x" ]
2514
+
2515
+ if x <= 8 * 4096 :
2516
+ x_block = 8
2517
+ else :
2518
+ x_block = min (max_x_block , next_power_of_2 (x // 4096 ))
2519
+ if x_block < 64 :
2520
+ x_block = 64
2521
+ if is_dynamic :
2522
+ # Dynamic shapes introduce a lot register pressure for indexing
2523
+ outer_r_block = (
2524
+ 1
2525
+ if load_factor >= 3
2526
+ else min (next_power_of_2 (max (rnumel , 128 ) // 128 ), 16 )
2527
+ )
2528
+ else :
2529
+ # Try to do reduction in 1 pass
2530
+ outer_r_block = min (next_power_of_2 (rnumel ), 128 )
2531
+
2532
+ if x_block * outer_r_block > 4096 :
2533
+ x_block = 2048 // outer_r_block
2534
+
2535
+ # Set register intensive to true by default as we try to maximize tiles with heuristic
2536
+ return make_config (x_block , outer_r_block , register_intensive = True )
2537
+
2510
2538
contiguous_config = make_config (
2511
2539
1 ,
2512
2540
min (rnumel , MAX_R0_BLOCK ),
2513
2541
register_intensive = register_intensive ,
2514
2542
)
2515
- outer_config = make_config ( 64 , 8 , register_intensive = register_intensive )
2543
+ outer_config = make_outer_config ( )
2516
2544
tiny_config = make_config (
2517
2545
2 * (256 // rnumel ) if rnumel <= 256 else 1 ,
2518
2546
min (rnumel , MAX_R0_BLOCK ),
@@ -2637,7 +2665,11 @@ def reduction(
2637
2665
2638
2666
assert triton_meta is not None
2639
2667
2640
- configs = _reduction_configs (size_hints = size_hints , inductor_meta = inductor_meta )
2668
+ is_dynamic = any ("ks" in k for k in triton_meta ["signature" ].keys ())
2669
+ configs = _reduction_configs (
2670
+ size_hints = size_hints , inductor_meta = inductor_meta , is_dynamic = is_dynamic
2671
+ )
2672
+
2641
2673
configs = _maybe_filter_configs_for_tma_restrictions (inductor_meta , configs )
2642
2674
return cached_autotune (
2643
2675
size_hints ,
0 commit comments