
Commit 64de1a9

pytorchmergebot authored and yangw-dev committed
Revert "[inductor] consolidate common GEMM triton param retrieval (#158015)"
This reverts commit 9faef3d. Reverted #158015 on behalf of https://github.com/henrylhtsang due to breaking tests ([comment](#158015 (comment)))
1 parent 9c917a9 · commit 64de1a9

File tree

10 files changed: +454 -1228 lines

test/inductor/test_max_autotune.py

Lines changed: 3 additions & 6 deletions
@@ -35,10 +35,7 @@
     TritonTemplate,
     TritonTemplateCaller,
 )
-from torch._inductor.template_heuristics import (
-    CUDAMMTemplateConfigHeuristic,
-    GemmConfig,
-)
+from torch._inductor.template_heuristics import CUDAConfigHeuristic, GemmConfig
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -1564,9 +1561,9 @@ def f(a, b):
         b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True)

         with mock.patch(
-            "torch._inductor.template_registry.get_template_heuristic"
+            "torch._inductor.kernel.mm.V.choices.get_config_heuristics"
         ) as config_mock:
-            config_heuristics = CUDAMMTemplateConfigHeuristic()
+            config_heuristics = CUDAConfigHeuristic()

             # Traditionally, this would be set of all possible configs
             # We mock out the code path for the sake of the unit test

torch/_inductor/choices.py

Lines changed: 52 additions & 34 deletions
@@ -9,7 +9,6 @@

 from . import config
 from .codecache import write_text
-from .kernel_inputs import KernelInputs  # noqa: TC001
 from .metrics import get_metric_table, is_metric_table_enabled
 from .runtime.hints import DeviceProperties, ReductionHint
 from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
@@ -20,7 +19,6 @@
     ROCmConfigHeuristic,
     XPUConfigHeuristic,
 )
-from .template_registry import get_template_heuristic
 from .virtualized import V


@@ -70,6 +68,58 @@ def get_config_heuristics(
         else:
             return BaseConfigHeuristic()

+    # GEMM configs
+    def get_base_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
+            return mm_heuristics.get_mm_configs()
+        else:
+            return mm_heuristics.get_exhaustive_mm_configs()
+
+    def get_extra_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_extra_mm_configs()
+
+    def get_int8_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_int8_mm_configs()
+
+    def get_mixed_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_mixed_mm_configs()
+
+    def get_persistent_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_persistent_mm_configs()
+
+    def get_scaled_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_scaled_mm_configs()
+
+    def get_scaled_persistent_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_scaled_persistent_mm_configs()
+
+    def get_mm_plus_mm_configs(
+        self, device_type: Optional[str] = "cuda"
+    ) -> partial[Generator[TritonConfig, None, None]]:
+        mm_heuristics = self.get_config_heuristics(device_type)
+        return mm_heuristics.get_mm_plus_mm_configs()
+
     # Conv configs
     def get_conv_configs(
         self, device_type: Optional[str] = "cuda"
@@ -78,7 +128,6 @@ def get_conv_configs(
         return conv_heuristics.get_conv_configs()

     # Flex attention configs
-    # TODO(coconutruben): break out flexattention/decode configs into the new retrieval mechanism
     def get_flex_attention_fwd_configs(
         self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
     ) -> list[Any]:
@@ -97,37 +146,6 @@ def get_flex_decode_configs(
         flex_heuristics = self.get_config_heuristics(device_type)
         return flex_heuristics.get_flex_decode_configs(head_dim, dtype)

-    def get_mm_configs(
-        self,
-        kernel_inputs: KernelInputs,
-        layout: Any,
-        template_name: str,
-        op_name: str,
-    ) -> Generator[dict[str, Any], None, None]:
-        """
-        Get generator of template parameters for MM templates using template-specific heuristics.
-
-        Args:
-            kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices
-            layout: Output layout
-            template_name: Template name (e.g., "bmm", "mm", "mm_persistent_tma")
-            op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm")
-
-        Yields:
-            Template parameter dictionaries ready for maybe_append_choice
-        """
-        input_tensors = kernel_inputs.nodes()
-        if len(input_tensors) < 2:
-            raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}")
-
-        # Extract device_type from kernel_inputs
-        device_type = kernel_inputs.device_type
-        assert device_type is not None, "get_mm_configs requires a valid device type"
-        # Get the appropriate template-specific heuristic
-        heuristic = get_template_heuristic(template_name, device_type)
-
-        yield from heuristic.get_template_configs(kernel_inputs, layout, op_name)
-
     def triton_kernel_kwargs(
         self,
         kernel_cls: type[TritonKernel],
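
For contrast, here is a side-by-side sketch of the two call patterns this revert swaps. It assumes the surrounding kernel variables (choices, m, n, k, layout, mat1, mat2, kernel_inputs, device_type) are in scope as they are in torch/_inductor/kernel/bmm.py, and it only illustrates the shape of the loop; it is not code from the commit.

# Old path (removed in torch/_inductor/choices.py above): a template-registry
# heuristic yields finished kwarg dicts for maybe_append_choice.
for kwargs in V.choices.get_mm_configs(kernel_inputs, layout, bmm_template.name, "bmm"):
    bmm_template.maybe_append_choice(
        choices, input_nodes=kernel_inputs.nodes(), layout=layout, **kwargs
    )

# Restored path (re-added above): a per-op getter returns raw Triton configs,
# and the call site converts each one into template kwargs via mm_options.
bmm_configs = V.choices.get_base_mm_configs(device_type)
for config in bmm_configs(m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)):
    bmm_template.maybe_append_choice(
        choices, input_nodes=(mat1, mat2), layout=layout, **mm_options(config, m, n, k, layout)
    )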

torch/_inductor/kernel/bmm.py

Lines changed: 37 additions & 33 deletions
@@ -6,7 +6,6 @@
 from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate

 from .. import ir, lowering as L
-from ..kernel_inputs import MMKernelInputs
 from ..select_algorithm import (
     autotune_select_algorithm,
     ExternKernelChoice,
@@ -27,6 +26,8 @@
     addmm_epilogue,
     is_batch_stride_largest,
     mm_args,
+    mm_config_kwargs,
+    mm_options,
 )


@@ -39,6 +40,13 @@ def bmm_grid(b, m, n, meta, *, cdiv):
     return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)


+def _is_large_block_for_cpu(m, n, k):
+    # Thresholds are experimentally determined to reduce Triton CPU compile times
+    if m > 128 or n > 128 or k > 128:
+        return True
+    return m * n > 2**12
+
+
 bmm_template = TritonTemplate(
     name="bmm",
     grid=bmm_grid,
@@ -167,14 +175,9 @@ def may_require_contiguous(t, meta_t):
         meta_mat2 = V.graph.current_node.args[1]
         mat2 = may_require_contiguous(mat2, meta_mat2)

-    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
     m, n, k, layout, mat1, mat2 = mm_args(
         mat1, mat2, layout=layout, out_dtype=out_dtype
     )
-    name = "bmm"
-
-    # Create MMKernelInputs for BMM at the top
-    kernel_inputs = MMKernelInputs([mat1, mat2])

     # below is for getting an overview logging info of inductor mms
     batch_size = mat1.get_size()[0]  # Extract batch dimension
@@ -192,65 +195,63 @@ def may_require_contiguous(t, meta_t):

     if out_dtype:
         assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA"
-        aten_func = aten_bmm_dtype.bind(
-            kernel_inputs.nodes(), layout, out_dtype=out_dtype
-        )
+        aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype)
     else:
-        aten_func = aten_bmm.bind(kernel_inputs.nodes(), layout)
+        aten_func = aten_bmm.bind((mat1, mat2), layout)

     # options to tune from
     choices = [aten_func] if use_aten_gemm_kernels() else []

+    device_type = ir.get_device_type(mat1)
+    bmm_configs = V.choices.get_base_mm_configs(device_type)
+
+    dtype = mat1.get_dtype()
     if use_triton_template(layout):
         # TODO: add out_dtype support for Triton Template
         assert out_dtype is None, "out_dtype is not supported for Triton"
-
-        for kwargs in V.choices.get_mm_configs(
-            kernel_inputs, layout, bmm_template.name, name
+        for config in bmm_configs(
+            m,
+            n,
+            k,
+            **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize),
         ):
             bmm_template.maybe_append_choice(
                 choices,
-                input_nodes=kernel_inputs.nodes(),
+                input_nodes=(mat1, mat2),
                 layout=layout,
-                **kwargs,
+                **mm_options(config, m, n, k, layout),
             )
     _, is_nonzero = _is_static_problem(layout)
     batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout)
     if (
         batch_stride_largest
         and is_nonzero
         and use_cutlass_template(layout, m, n, k)
-        and _use_cutlass_for_op(name)
+        and _use_cutlass_for_op("bmm")
     ):
         from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate

-        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
-            choices, layout, kernel_inputs.nodes()
-        )  # type: ignore[arg-type]
+        CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])  # type: ignore[arg-type]

     if use_cpp_bmm_template(layout, mat1, mat2):
         from ..codegen.cpp_bmm_template import CppBmmTemplate

         CppBmmTemplate.add_choices(
             choices,
             layout,
-            kernel_inputs.nodes(),
+            [mat1, mat2],
         )

     if use_ck_gemm_template(layout, m, n, k):
-        CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
+        CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2])

-    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+    return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout)


 @L.register_lowering(aten.baddbmm)
 def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
-    # TODO(coconutruben): integrate into MMKernelInputs when all callsites use that
     m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout)

-    # Create MMKernelInputs for BadDBMM at the top
-    kernel_inputs = MMKernelInputs([inp, mat1, mat2])
-
     # below is for getting an overview logging info of inductor mms
     batch_size = mat1.get_size()[0]
     counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1
@@ -265,26 +266,29 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         inp.get_dtype(),
         layout,
     )
-    name = "baddbmm"
+
     # options to tune from
     choices = (
-        [aten_baddbmm.bind(kernel_inputs.nodes(), layout, alpha=alpha, beta=beta)]
+        [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)]
         if use_aten_gemm_kernels()
         else []
     )

+    device_type = ir.get_device_type(mat1)
+    bmm_configs = V.choices.get_base_mm_configs(device_type)
+
     if use_triton_template(layout):
-        for kwargs in V.choices.get_mm_configs(
-            kernel_inputs, layout, bmm_template.name, name
+        for config in bmm_configs(
+            m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu)
         ):
             bmm_template.maybe_append_choice(
                 choices,
-                input_nodes=kernel_inputs.nodes(),
+                input_nodes=(inp, mat1, mat2),
                 layout=layout,
-                **kwargs,
+                **mm_options(config, m, n, k, layout),
                 prefix_args=1,
                 epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
                 epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]),
             )

-    return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
+    return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout)

torch/_inductor/kernel/conv.py

Lines changed: 9 additions & 0 deletions
@@ -29,6 +29,7 @@
     use_triton_template,
 )
 from ..virtualized import V
+from .mm_common import mm_config_kwargs


 if TYPE_CHECKING:
@@ -60,6 +61,13 @@ def conv3d_grid(n, c, d, h, w, meta, *, cdiv):
     )


+def _is_large_block_for_cpu(m, n, k):
+    # Thresholds are experimentally determined to reduce Triton CPU compile times
+    if m > 256 or n > 256 or k > 256:
+        return True
+    return m * n * k > 2**17
+
+
 LOOP_BODY_2D = """
     idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
     idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W
@@ -595,6 +603,7 @@ def channels_last_conv():
             sympy_product([x.get_size()[0], *x.get_size()[2:]]),
             out_chan,
             in_chan,
+            **mm_config_kwargs(device_type, _is_large_block_for_cpu),
         ):
             if ndim == 2:
                 conv2d_template.maybe_append_choice(
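
The conv variant of the predicate uses larger thresholds (any dimension above 256, or a full m * n * k volume above 2**17). The same kind of standalone, illustrative check applies:

def _is_large_block_for_cpu(m, n, k):
    # conv.py variant of the predicate from the hunk above
    if m > 256 or n > 256 or k > 256:
        return True
    return m * n * k > 2**17

# Boundary checks (illustrative only, not part of the diff):
assert _is_large_block_for_cpu(512, 8, 8)        # a single dimension over 256 trips the guard
assert _is_large_block_for_cpu(64, 64, 33)       # 64 * 64 * 33 = 135168 > 131072
assert not _is_large_block_for_cpu(64, 64, 32)   # 131072 is not > 131072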
