Skip to content

Commit 83b4126

Browse files
committed
[kernacle] add support for addmm and bmm
Title says it all. Differential Revision: [D79940195](https://our.internmc.facebook.com/intern/diff/D79940195/). ghstack-source-id: 301861447. Pull Request resolved: #160239
1 parent 24257f5 commit 83b4126

File tree

4 files changed

+64
-31
lines changed

4 files changed

+64
-31
lines changed

torch/_inductor/kernel/bmm.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch._dynamo.utils import counters
66
from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
7+
from torch._inductor.remote_gemm_autotune_cache import gen_best_config
78

89
from .. import ir, lowering as L
910
from ..kernel_inputs import MMKernelInputs
@@ -240,7 +241,17 @@ def may_require_contiguous(t, meta_t):
240241
if use_ck_gemm_template(layout, m, n, k):
241242
CKGemmTemplate.add_ck_gemm_choices(choices, layout, kernel_inputs.nodes())
242243

243-
return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)
244+
best_config_future = None
245+
if torch._inductor.config.remote_gemm_autotune_cache:
246+
best_config_future = gen_best_config("bmm", [mat1, mat2])
247+
248+
return autotune_select_algorithm(
249+
name,
250+
choices,
251+
kernel_inputs.nodes(),
252+
layout,
253+
best_config_future=best_config_future,
254+
)
244255

245256

246257
@L.register_lowering(aten.baddbmm)

torch/_inductor/kernel/mm.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
841841
# Purposely not awaiting the future here - this kicks off the best config lookup at lowering time
842842
# The future will be awaited at scheduling time in select_algorithm.py
843843
if torch._inductor.config.remote_gemm_autotune_cache:
844-
best_config_future = gen_best_config(mat1, mat2)
844+
best_config_future = gen_best_config("mm", [mat1, mat2])
845845

846846
return autotune_select_algorithm(
847847
name,
@@ -946,13 +946,19 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
946946
if use_aten_gemm_kernels()
947947
else []
948948
)
949+
950+
best_config_future = None
951+
if torch._inductor.config.remote_gemm_autotune_cache:
952+
best_config_future = gen_best_config("addmm", [mat1, mat2, inp])
953+
954+
# TODO(coconutruben): replace with kernel_inputs.nodes()
955+
# once that supports the unexpanded nodes as well
949956
return autotune_select_algorithm(
950-
# TODO(coconutruben): replace with kernel_inputs.nodes()
951-
# once that supports the unexpanded nodes as well
952957
"addmm",
953958
choices,
954959
[inp, mat1, mat2],
955960
layout,
961+
best_config_future=best_config_future,
956962
)
957963

958964
choices = (
@@ -1055,7 +1061,19 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
10551061
has_bias=True,
10561062
)
10571063

1058-
return autotune_select_algorithm("addmm", choices, kernel_inputs.nodes(), layout)
1064+
best_config_future = None
1065+
if torch._inductor.config.remote_gemm_autotune_cache:
1066+
best_config_future = gen_best_config(
1067+
"addmm", [mat1, mat2, inp], alpha=alpha, beta=beta
1068+
)
1069+
1070+
return autotune_select_algorithm(
1071+
"addmm",
1072+
choices,
1073+
kernel_inputs.nodes(),
1074+
layout,
1075+
best_config_future=best_config_future,
1076+
)
10591077

10601078

10611079
@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None)
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from typing import TypeVar
2+
from typing import Any, TypeVar
33

44
import torch._inductor.config as config
55
from torch._inductor import ir
@@ -8,13 +8,15 @@
88
_T = TypeVar("_T")
99

1010

11-
def gen_best_config(mat1: ir.StorageBox, mat2: ir.StorageBox) -> asyncio.Task[_T]:
11+
def gen_best_config(
    mm_type: str, mats: list[ir.StorageBox], **kwargs: Any
) -> asyncio.Task[_T]:
    """
    Generate the best GEMM autotune config for the given matrices.

    Args:
        mm_type: kind of GEMM being tuned (e.g. "mm", "addmm", "bmm").
        mats: input matrices for the GEMM, as inductor StorageBox IR nodes.
        **kwargs: extra per-op parameters forwarded to the internal
            implementation (e.g. alpha/beta for addmm).

    Returns:
        An asyncio task resolving to the best autotune config.

    Raises:
        NotImplementedError: outside fbcode, where no remote cache
            implementation is available.
    """
    # Guard clause: only the internal (fbcode) build ships an implementation.
    if not config.is_fbcode():
        raise NotImplementedError("Function gen_best_config is not yet implemented")

    # Deferred import: the fb-internal module does not exist in OSS builds.
    from torch._inductor.fb.remote_gemm_autotune_cache import gen_best_config

    return gen_best_config(mm_type, mats, **kwargs)

torch/_inductor/select_algorithm.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,31 +2392,33 @@ def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None):
23922392

23932393
if best_config_future is not None:
23942394
best_config = await_sync(best_config_future)
2395-
2396-
important_keys = [
2397-
"ACC_TYPE",
2398-
"ALLOW_TF32",
2399-
"BLOCK_K",
2400-
"BLOCK_M",
2401-
"BLOCK_N",
2402-
"EVEN_K",
2403-
"GROUP_M",
2404-
"USE_FAST_ACCUM",
2405-
"num_stages",
2406-
"num_warps",
2407-
"num_consumer_groups",
2408-
"num_buffers_warp_spec",
2409-
]
2410-
choices = [
2411-
choice
2412-
for choice in choices
2413-
if all(
2414-
f"{k}={best_config[k]}" in choice.description
2395+
if best_config:
2396+
important_keys = [
2397+
"ACC_TYPE",
2398+
"ALLOW_TF32",
2399+
"BLOCK_K",
2400+
"BLOCK_M",
2401+
"BLOCK_N",
2402+
"EVEN_K",
2403+
"GROUP_M",
2404+
"USE_FAST_ACCUM",
2405+
"num_stages",
2406+
"num_warps",
2407+
"num_consumer_groups",
2408+
"num_buffers_warp_spec",
2409+
]
2410+
choices = [
2411+
choice
2412+
for choice in choices
2413+
if all(
2414+
f"{k}={best_config[k]}" in choice.description
2415+
for k in important_keys
2416+
)
24152417
for k in important_keys
2418+
]
2419+
log.info(
2420+
"Filtered to %d choices based on best_config", len(choices)
24162421
)
2417-
for k in important_keys
2418-
]
2419-
log.info("Filtered to %d choices based on best_config", len(choices))
24202422

24212423
timings = self.lookup(
24222424
choices,

0 commit comments

Comments (0)