
Commit cf0a0dc

PaulZhang12 authored and pytorchmergebot committed
Make user defined Triton kernels serializable for fx_graph_runnable (#160002)
Resolves issue #153475, where `fx_graph_runnable` didn't work with user-defined Triton kernels.

Pull Request resolved: #160002
Approved by: https://github.com/eellison
1 parent b149c72 commit cf0a0dc
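
For context, `fx_graph_runnable` is the debugging artifact that dumps a standalone, runnable Python script reproducing a compiled FX graph. A minimal sketch of triggering it, assuming the standard logging-artifact name (setting TORCH_LOGS="fx_graph_runnable" in the environment should be equivalent):

# Hedged sketch: enable the fx_graph_runnable artifact so torch.compile
# emits a standalone repro script for each compiled graph; before this
# commit the emitted script broke when the graph called a user-defined
# Triton kernel.
import torch

torch._logging.set_logs(fx_graph_runnable=True)

@torch.compile
def f(x):
    return x.relu()

f(torch.randn(4))  # the runnable repro script is emitted via the debug logs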

File tree: 2 files changed, +165 −0 lines changed


test/dynamo/test_fx_graph_runnable.py

Lines changed: 88 additions & 0 deletions
@@ -11,12 +11,65 @@
 from torch._inductor.codecache import WritableTempFile
 from torch._inductor.test_case import TestCase
 from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
+from torch.utils._triton import has_triton


 if torch.distributed.is_available():
     from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
     from torch.testing._internal.distributed.fake_pg import FakeStore

+if has_triton():
+    import triton
+    import triton.language as tl
+
+    def init_to_zero(name):
+        return lambda nargs: nargs[name].zero_()
+
+    @triton.jit
+    def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(axis=0)
+
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+        tl.atomic_add(output_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config(
+                {"BLOCK_SIZE": 1024},
+                num_warps=4,
+                num_stages=2,
+                pre_hook=init_to_zero("output_ptr"),
+            )
+        ],
+        pre_hook=init_to_zero("output_ptr"),
+        post_hook=init_to_zero("output_ptr"),
+        key=["n_elements"],
+    )
+    @triton.jit
+    def add_kernel_autotune(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+        tl.atomic_add(output_ptr + offsets, output, mask=mask)
+
+
+from torch.testing._internal.inductor_utils import GPU_TYPE
+from torch.testing._internal.triton_utils import requires_gpu

 class FxGraphRunnableArtifactFilter(logging.Filter):
     def filter(self, record):
@@ -100,6 +153,41 @@ def f(x):
         torch.compile(f)(torch.randn(4))
         self._exec_and_verify_payload()

+    @unittest.skipUnless(has_triton(), "Triton not available")
+    def test_user_defined_triton_kernel_autotune(self):
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
+            n_elements = output.numel()
+
+            def grid(
+                meta,
+            ):
+                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+            add_kernel_autotune[grid](x, y, output, n_elements)
+            return output
+
+        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+
+        torch.compile(add)(x, y)
+        self._exec_and_verify_payload()
+
+    @unittest.skipUnless(has_triton(), "Triton not available")
+    @requires_gpu
+    def test_user_defined_triton_kernel(self):
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
+            n_elements = x.numel()
+            add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4)
+            return output
+
+        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+
+        torch.compile(add)(x, y)
+        self._exec_and_verify_payload()
+
     def test_two_inputs_matmul(self):
         def f(a, b):
             return (a @ b).relu()
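
A subtlety in the kernels above: both write their result with `tl.atomic_add`, so the autotuned variant attaches `pre_hook`/`post_hook` configs that zero `output_ptr` around each benchmarking run, preventing repeated autotune launches from compounding into the buffer. A small, hedged illustration of the hook contract as used here (the dict stands in for the real launch arguments Triton passes):

# Triton invokes these hooks with the kernel's launch arguments keyed by
# parameter name, so init_to_zero("output_ptr") resets the accumulation
# buffer in place before/after a timed run.
import torch

def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()

hook = init_to_zero("output_ptr")
args = {"output_ptr": torch.ones(8)}  # stand-in for real launch args
hook(args)
assert args["output_ptr"].abs().sum() == 0  # buffer was zeroed in place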

torch/_dynamo/repro/after_aot.py

Lines changed: 77 additions & 0 deletions
@@ -34,6 +34,24 @@
 from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
 from typing_extensions import Unpack

+from torch.utils._triton import has_triton
+
+
+if has_triton():
+    from triton.runtime.autotuner import Autotuner, Heuristics
+    from triton.runtime.jit import JITFunction
+else:
+
+    class Autotuner:  # type: ignore[no-redef]
+        pass
+
+    class JITFunction:  # type: ignore[no-redef]
+        pass
+
+    class Heuristics:  # type: ignore[no-redef]
+        pass
+
+
 import torch
 import torch.fx as fx
 import torch.nn as nn
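
One detail worth calling out: when Triton is unavailable, the stub `Autotuner`/`JITFunction`/`Heuristics` classes keep the `isinstance` checks later in this file well-defined while matching nothing, so the kernel-serialization path is silently skipped. A tiny illustration of the pattern (hypothetical object, not the real code path):

# With triton absent, a stub class can never match a real kernel object,
# so code guarded by isinstance(kernel, Autotuner) simply does not run.
class Autotuner:  # stand-in mirroring the stub above
    pass

kernel = object()  # hypothetical entry from the side table
print(isinstance(kernel, Autotuner))  # False; the repro path is skipped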
@@ -58,6 +76,7 @@
 )
 from torch._dynamo.utils import clone_inputs, counters, same
 from torch._environment import is_fbcode
+from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
 from torch._inductor.cpp_builder import normalize_path_separator
 from torch._inductor.output_code import OutputCode
 from torch._library.fake_class_registry import FakeScriptObject
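
`kernel_side_table` is the registry the triton_kernel_wrap higher-order op uses to map integer ids back to user-defined kernel objects; the generator below walks it via `id_to_kernel`/`get_kernel`, and the script it emits re-populates it with `add_kernel`. A rough sketch of inspecting the table (attribute names taken from this diff; with nothing compiled the loop is simply empty):

# Hedged sketch: enumerate whatever user-defined Triton kernels have been
# captured; entries are triton JITFunctions or Autotuners wrapping one.
from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

for kernel_id in kernel_side_table.id_to_kernel:
    kernel = kernel_side_table.get_kernel(kernel_id)
    print(kernel_id, type(kernel).__name__)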
@@ -302,6 +321,16 @@ def generate_compiler_repro_string(
         """
     ).strip()

+    triton_imports = ""
+
+    if len(kernel_side_table.id_to_kernel) > 0:
+        triton_imports = textwrap.dedent(
+            """
+            import triton
+            import triton.language as tl
+            """
+        ).strip()
+
     model_str = textwrap.dedent(
         f"""
         {generate_env_vars_string(stable_output=stable_output)}
@@ -312,6 +341,7 @@ def generate_compiler_repro_string(
         from math import inf
         import torch._inductor.inductor_prims
         {distributed_imports}
+        {triton_imports}

         {generate_config_string(stable_output=stable_output)}
@@ -330,6 +360,53 @@
     model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
     model_str += _cuda_system_info_comment()

+    kernel_side_table_prefix = (
+        "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table"
+    )
+    # Track which grid entry corresponds to the best config
+    for id in kernel_side_table.id_to_kernel:
+        kernel = kernel_side_table.get_kernel(id)
+
+        if isinstance(kernel, Autotuner):
+            if isinstance(kernel.fn, Heuristics):
+                model_str += "ERROR: Repro will not work as intended, "
+                model_str += (
+                    "triton.runtime.autotuner.Heuristics is not currently supported\n"
+                )
+                break
+
+            config_strs = []
+            for kernel_config in kernel.configs:
+                config_strs.append(f"""triton.Config(
+                    {str(kernel_config.kwargs)},
+                    num_warps={kernel_config.num_warps},
+                    num_stages={kernel_config.num_stages},
+                )""")
+
+            config_str = ",".join(config_strs)
+            model_str += textwrap.dedent(f"""
+            @triton.autotune(
+                configs=[
+                    {config_str}
+                ],
+                key=[]
+            )
+            """).strip()
+
+        model_str += "\n@triton.jit\n"
+        src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src
+        fn_name = (
+            kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name
+        )
+        fn_name = fn_name.split(".")[-1]
+
+        model_str += src_code
+        model_str += "\n"
+        model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n"
+
+    if len(kernel_side_table.constant_args) > 0:
+        model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n"
+
     model_str += NNModuleToString.convert(gm)

     writer = InputWriter(save_dir, stable_hash=stable_hash)
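
Putting the pieces together, for the autotuned test kernel above the generated repro script should now contain a block roughly like the following (a hedged reconstruction from the string-building code, not captured output). Note what the serialization deliberately drops: `pre_hook`/`post_hook` are not emitted, the config kwargs are rendered via `str()` (hence the single quotes), and `key=[]` pins the saved configs instead of re-keying on runtime values:

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE': 1024},
            num_warps=4,
            num_stages=2,
        )
    ],
    key=[]
)
@triton.jit
def add_kernel_autotune(
    x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.atomic_add(output_ptr + offsets, output, mask=mask)

torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(add_kernel_autotune)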
