 from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate

 from .. import ir, lowering as L
-from ..kernel_inputs import MMKernelInputs
 from ..select_algorithm import (
     autotune_select_algorithm,
     ExternKernelChoice,
@@ -27,6 +26,8 @@
     addmm_epilogue,
     is_batch_stride_largest,
     mm_args,
+    mm_config_kwargs,
+    mm_options,
 )

@@ -39,6 +40,13 @@ def bmm_grid(b, m, n, meta, *, cdiv):
     return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)


+def _is_large_block_for_cpu(m, n, k):
+    # Thresholds are experimentally determined to reduce Triton CPU compile times
+    if m > 128 or n > 128 or k > 128:
+        return True
+    return m * n > 2**12
+
+
 bmm_template = TritonTemplate(
     name="bmm",
     grid=bmm_grid,
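Side note on the `_is_large_block_for_cpu` heuristic restored in this hunk: it looks only at the problem shape, returning True when any dimension exceeds 128 or when the output tile `m * n` exceeds 2**12. A minimal, self-contained sketch of how it classifies a few shapes (the function body is copied from the hunk above; the sample shapes are illustrative only):

```python
# Heuristic copied from the hunk above; used to skip expensive-to-compile
# Triton CPU configs for large problems.
def _is_large_block_for_cpu(m, n, k):
    # Thresholds are experimentally determined to reduce Triton CPU compile times
    if m > 128 or n > 128 or k > 128:
        return True
    return m * n > 2**12


# Illustrative shapes (not taken from the diff):
assert _is_large_block_for_cpu(64, 64, 64) is False   # m * n == 4096, not > 2**12
assert _is_large_block_for_cpu(256, 32, 32) is True   # a single dim over 128
assert _is_large_block_for_cpu(128, 64, 128) is True  # m * n == 8192 > 2**12
```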
@@ -167,14 +175,9 @@ def may_require_contiguous(t, meta_t):
             meta_mat2 = V.graph.current_node.args[1]
             mat2 = may_require_contiguous(mat2, meta_mat2)

-    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
     m, n, k, layout, mat1, mat2 = mm_args(
         mat1, mat2, layout=layout, out_dtype=out_dtype
     )
-    name = "bmm"
-
-    # Create MMKernelInputs for BMM at the top
-    kernel_inputs = MMKernelInputs([mat1, mat2])

     # below is for getting an overview logging info of inductor mms
     batch_size = mat1.get_size()[0]  # Extract batch dimension
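The "overview logging info" block kept by this hunk only bumps shape-keyed counters; it does not affect kernel selection. A rough sketch of how those counters can be inspected after a compile run (assumes the standard `torch._dynamo.utils.counters` store that inductor uses; the exact key string follows the f-strings in this file, e.g. `aten.bmm_{batch}_{m}_{n}_{k}`):

```python
import torch
from torch._dynamo.utils import counters  # defaultdict of collections.Counter

a = torch.randn(8, 32, 16)
b = torch.randn(8, 16, 24)

# Compiling and running a bmm goes through the tuned_bmm lowering, which is
# expected to increment counters["aten_mm_info"]["aten.bmm_<batch>_<m>_<n>_<k>"].
compiled = torch.compile(lambda x, y: torch.bmm(x, y))
compiled(a, b)

print(dict(counters["aten_mm_info"]))
```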
@@ -192,65 +195,63 @@ def may_require_contiguous(t, meta_t):

     if out_dtype:
         assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA"
-        aten_func = aten_bmm_dtype.bind(
-            kernel_inputs.nodes(), layout, out_dtype=out_dtype
-        )
+        aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype)
     else:
-        aten_func = aten_bmm.bind(kernel_inputs.nodes(), layout)
+        aten_func = aten_bmm.bind((mat1, mat2), layout)

     # options to tune from
     choices = [aten_func] if use_aten_gemm_kernels() else []

+    device_type = ir.get_device_type(mat1)
+    bmm_configs = V.choices.get_base_mm_configs(device_type)
+
+    dtype = mat1.get_dtype()
     if use_triton_template(layout):
         # TODO: add out_dtype support for Triton Template
         assert out_dtype is None, "out_dtype is not supported for Triton"
-
-        for kwargs in V.choices.get_mm_configs(
-            kernel_inputs, layout, bmm_template.name, name
+        for config in bmm_configs(
+            m,
+            n,
+            k,
+            **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
         ):
             bmm_template.maybe_append_choice(
                 choices,
-                input_nodes=kernel_inputs.nodes(),
+                input_nodes=(mat1, mat2),
                 layout=layout,
-                **kwargs,
+                **mm_options(config, m, n, k, layout),
             )
     _, is_nonzero = _is_static_problem(layout)
     batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
     if (
         batch_stride_largest
         and is_nonzero
         and use_cutlass_template(layout, m, n, k)
-        and _use_cutlass_for_op(name)
+        and _use_cutlass_for_op("bmm")
     ):
         from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate

-        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
-            choices, layout, kernel_inputs.nodes()
-        )  # type: ignore[arg-type]
+        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])  # type: ignore[arg-type]

     if use_cpp_bmm_template(layout, mat1, mat2):
         from ..codegen.cpp_bmm_template import CppBmmTemplate

         CppBmmTemplate.add_choices(
             choices,
             layout,
-            kernel_inputs.nodes(),
+            [mat1, mat2],
         )

     if use_ck_gemm_template(layout, m, n, k):
-        CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
+        CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])

-    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+    return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)


 @L.register_lowering(aten.baddbmm)
 def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
-    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
     m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)

-    # Create MMKernelInputs for BadDBMM at the top
-    kernel_inputs = MMKernelInputs([inp, mat1, mat2])
-
     # below is for getting an overview logging info of inductor mms
     batch_size = mat1.get_size()[0]
     counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1
@@ -265,26 +266,29 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         inp.get_dtype(),
         layout,
     )
-    name = "baddbmm"
+
     # options to tune from
     choices = (
-        [aten_baddbmm.bind(kernel_inputs.nodes(), layout, alpha=alpha, beta=beta)]
+        [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
         if use_aten_gemm_kernels()
         else []
     )

+    device_type = ir.get_device_type(mat1)
+    bmm_configs = V.choices.get_base_mm_configs(device_type)
+
     if use_triton_template(layout):
-        for kwargs in V.choices.get_mm_configs(
-            kernel_inputs, layout, bmm_template.name, name
+        for config in bmm_configs(
+            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
         ):
             bmm_template.maybe_append_choice(
                 choices,
-                input_nodes=kernel_inputs.nodes(),
+                input_nodes=(inp, mat1, mat2),
                 layout=layout,
-                **kwargs,
+                **mm_options(config, m, n, k, layout),
                 prefix_args=1,
                 epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
                 epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]),
             )

-    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+    return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)
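For context on the `prefix_args=1` / `addmm_epilogue(layout.dtype, alpha, beta)` arguments kept in the baddbmm hunk: the Triton template itself produces only the batched matmul, and the epilogue is meant to fold the `inp` term back in so the result matches aten.baddbmm semantics. A small eager-mode identity check of those semantics (plain torch, independent of the inductor helpers; shapes and scalars are arbitrary):

```python
import torch

inp = torch.randn(4, 8, 8)
mat1 = torch.randn(4, 8, 16)
mat2 = torch.randn(4, 16, 8)
alpha, beta = 2.0, 0.5

# baddbmm(inp, mat1, mat2) == beta * inp + alpha * bmm(mat1, mat2) -- the
# scale-and-add that the addmm-style epilogue applies on top of the matmul.
ref = beta * inp + alpha * torch.bmm(mat1, mat2)
out = torch.baddbmm(inp, mat1, mat2, alpha=alpha, beta=beta)
torch.testing.assert_close(out, ref)
```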