
Commit 0e2013a

oulgen authored and pytorchmergebot committed
Add helion x pt2 test (#155513)
This kinda just worked out of the box, shocking. PT2 traced into helion and emitted it as a user-defined triton kernel: P1836496774. In the long run we do not actually want this; instead we want to create a helion HOP so we can do fusions etc.

Pull Request resolved: #155513
Approved by: https://github.com/zou3519, https://github.com/jansel
1 parent 5b9db43 commit 0e2013a
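
The pattern the new test exercises can be condensed as follows. This is an illustrative sketch only (the shapes, block sizes, and the local names `add` and `f` are arbitrary); the authoritative version is the test file in the diff below.

import torch
import helion
import helion.language as hl

# A trivial elementwise Helion kernel; the block sizes are illustrative.
@helion.kernel(config=helion.Config(block_sizes=[1, 2]))
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(out.size()):
        out[tile] = x[tile] + y[tile]
    return out

def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return add(x, y)

# PT2 traces into the Helion kernel and, for now, emits it as a
# user-defined Triton kernel rather than a dedicated Helion HOP.
compiled = torch.compile(f, fullgraph=True, backend="inductor")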

File tree

4 files changed: +73 -0 lines changed

.ci/docker/common/install_triton.sh

Lines changed: 3 additions & 0 deletions
@@ -98,3 +98,6 @@ fi
 if [ -n "${NUMPY_VERSION}" ]; then
   pip_install "numpy==${NUMPY_VERSION}"
 fi
+if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then
+  pip_install helion
+fi
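
The 3.9 gate suggests helion needs a newer interpreter; that is an inference from the CI script, not something the commit states. A hypothetical Python-side mirror of the same guard:

import sys

# Hypothetical mirror of the shell gate above: the CI script skips
# installing helion on Python 3.9, so guard the import the same way.
HELION_AVAILABLE = False
if sys.version_info >= (3, 10):
    try:
        import helion  # noqa: F401
        HELION_AVAILABLE = True
    except ImportError:
        pass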

test/inductor/test_helion_kernels.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Owner(s): ["module: inductor"]
+import torch
+from torch._inductor.test_case import run_tests, TestCase
+from torch.testing._internal.common_utils import instantiate_parametrized_tests
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_HELION, requires_helion
+
+
+if HAS_HELION:
+    import helion
+    import helion.language as hl
+
+
+class HelionTests(TestCase):
+    @requires_helion()
+    def test_add_kernel(self):
+        @helion.kernel(config=helion.Config(block_sizes=[1, 2]))
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            # match pytorch broadcasting rules
+            x, y = torch.broadcast_tensors(x, y)
+            out = torch.empty(
+                x.shape,
+                # match type promotion of torch.add
+                dtype=torch.promote_types(x.dtype, y.dtype),
+                device=x.device,
+            )
+            # tile will be a tuple of blocks
+            for tile in hl.tile(out.size()):
+                out[tile] = x[tile] + y[tile]
+            return out
+
+        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return add(x, y)
+
+        x = torch.randn(4, 8, device=GPU_TYPE, dtype=torch.float16)
+        y = torch.randn(4, 8, device=GPU_TYPE, dtype=torch.float16)
+
+        out = add(x, y)
+        compiled_add = torch.compile(f, fullgraph=True, backend="inductor")
+        compiled_out = compiled_add(x, y)
+
+        self.assertEqual(out, x + y)
+        self.assertEqual(compiled_out, x + y)
+
+
+instantiate_parametrized_tests(HelionTests)
+
+
+if __name__ == "__main__":
+    run_tests()
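
As a debugging aid beyond the equality asserts (not part of the commit), torch._dynamo.explain can confirm that tracing produced a single graph. Illustrative only, reusing the names f, x, and y from the test body above:

# Hypothetical check: with fullgraph=True a graph break would raise,
# so explain() is expected to report one graph and zero breaks.
explanation = torch._dynamo.explain(f)(x, y)
print(explanation.graph_count, explanation.graph_break_count)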

torch/testing/_internal/inductor_utils.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,7 @@
 from torch._inductor.codecache import CppCodeCache
 from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
 from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu
+from torch.utils._helion import has_helion
 from torch.utils._triton import has_triton
 from torch.testing._internal.common_device_type import (
     get_desired_device_type_test_bases,
@@ -48,6 +49,8 @@ def test_cpu():
 
 HAS_TRITON = has_triton()
 
+HAS_HELION = has_helion()
+
 if HAS_TRITON:
     import triton
     TRITON_HAS_CPU = "cpu" in triton.backends.backends
@@ -133,6 +136,7 @@ def skip_windows_ci(name: str, file: str) -> None:
 # TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS
 requires_gpu = functools.partial(unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu")
 requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")
+requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion")
 
 def requires_cuda_with_enough_memory(min_mem_required):
     def inner(fn):
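
requires_helion follows the existing requires_triton pattern: a partially applied unittest.skipIf, so @requires_helion() yields a decorator that skips the test when helion is absent. A self-contained sketch of that mechanism, with HAS_HELION hard-coded for illustration:

import functools
import unittest

HAS_HELION = False  # stand-in; the real value comes from has_helion()

# Partially apply skipIf: calling requires_helion() produces a decorator
# that skips the test whenever helion is unavailable.
requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion")

class Demo(unittest.TestCase):
    @requires_helion()
    def test_uses_helion(self):
        import helion  # noqa: F401  # only reached when helion is installed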

torch/utils/_helion.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+import functools
+
+from torch.utils._triton import has_triton
+
+
+@functools.lru_cache(None)
+def has_helion_package() -> bool:
+    try:
+        import helion  # type: ignore[import-untyped, import-not-found]  # noqa: F401
+    except ImportError:
+        return False
+    return True
+
+
+@functools.lru_cache(None)
+def has_helion() -> bool:
+    return has_helion_package() and has_triton()
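
Both helpers are cached with functools.lru_cache, so repeated capability checks are cheap. A minimal usage sketch mirroring the inductor_utils change above:

from torch.utils._helion import has_helion

# Evaluated once and cached; has_helion() also requires a working triton.
HAS_HELION = has_helion()

if HAS_HELION:
    import helion
    import helion.language as hl  # noqa: F401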
