11 | 11 | from torch._inductor.codecache import WritableTempFile
12 | 12 | from torch._inductor.test_case import TestCase
13 | 13 | from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
   | 14 | +from torch.utils._triton import has_triton
14 | 15 |
15 | 16 |
16 | 17 | if torch.distributed.is_available():
17 | 18 |     from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
18 | 19 |     from torch.testing._internal.distributed.fake_pg import FakeStore
19 | 20 |
   | 21 | +if has_triton():
   | 22 | +    import triton
   | 23 | +    import triton.language as tl
   | 24 | +
   | 25 | +    def init_to_zero(name):
   | 26 | +        return lambda nargs: nargs[name].zero_()
   | 27 | +
   | 28 | +    @triton.jit
   | 29 | +    def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
   | 30 | +        pid = tl.program_id(axis=0)
   | 31 | +
   | 32 | +        block_start = pid * BLOCK_SIZE
   | 33 | +        offsets = block_start + tl.arange(0, BLOCK_SIZE)
   | 34 | +        mask = offsets < n_elements
   | 35 | +
   | 36 | +        x = tl.load(x_ptr + offsets, mask=mask)
   | 37 | +        y = tl.load(y_ptr + offsets, mask=mask)
   | 38 | +        output = x + y
   | 39 | +        tl.atomic_add(output_ptr + offsets, output, mask=mask)
   | 40 | +
   | 41 | +    @triton.autotune(
   | 42 | +        configs=[
   | 43 | +            triton.Config(
   | 44 | +                {"BLOCK_SIZE": 1024},
   | 45 | +                num_warps=4,
   | 46 | +                num_stages=2,
   | 47 | +                pre_hook=init_to_zero("output_ptr"),
   | 48 | +            )
   | 49 | +        ],
   | 50 | +        pre_hook=init_to_zero("output_ptr"),
   | 51 | +        post_hook=init_to_zero("output_ptr"),
   | 52 | +        key=["n_elements"],
   | 53 | +    )
   | 54 | +    @triton.jit
   | 55 | +    def add_kernel_autotune(
   | 56 | +        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
   | 57 | +    ):
   | 58 | +        pid = tl.program_id(axis=0)
   | 59 | +
   | 60 | +        block_start = pid * BLOCK_SIZE
   | 61 | +        offsets = block_start + tl.arange(0, BLOCK_SIZE)
   | 62 | +        mask = offsets < n_elements
   | 63 | +
   | 64 | +        x = tl.load(x_ptr + offsets, mask=mask)
   | 65 | +        y = tl.load(y_ptr + offsets, mask=mask)
   | 66 | +        output = x + y
   | 67 | +        tl.atomic_add(output_ptr + offsets, output, mask=mask)
   | 68 | +
   | 69 | +
   | 70 | +from torch.testing._internal.inductor_utils import GPU_TYPE
   | 71 | +from torch.testing._internal.triton_utils import requires_gpu
   | 72 | +
20 | 73 |
21 | 74 | class FxGraphRunnableArtifactFilter(logging.Filter):
22 | 75 |     def filter(self, record):
@@ -100,6 +153,41 @@ def f(x):
100 | 153 |         torch.compile(f)(torch.randn(4))
101 | 154 |         self._exec_and_verify_payload()
102 | 155 |
    | 156 | +    @unittest.skipUnless(has_triton(), "Triton not available")
    | 157 | +    def test_user_defined_triton_kernel_autotune(self):
    | 158 | +        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    | 159 | +            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
    | 160 | +            n_elements = output.numel()
    | 161 | +
    | 162 | +            def grid(
    | 163 | +                meta,
    | 164 | +            ):
    | 165 | +                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    | 166 | +
    | 167 | +            add_kernel_autotune[grid](x, y, output, n_elements)
    | 168 | +            return output
    | 169 | +
    | 170 | +        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
    | 171 | +        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
    | 172 | +
    | 173 | +        torch.compile(add)(x, y)
    | 174 | +        self._exec_and_verify_payload()
    | 175 | +
    | 176 | +    @unittest.skipUnless(has_triton(), "Triton not available")
    | 177 | +    @requires_gpu
    | 178 | +    def test_user_defined_triton_kernel(self):
    | 179 | +        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    | 180 | +            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
    | 181 | +            n_elements = x.numel()
    | 182 | +            add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4)
    | 183 | +            return output
    | 184 | +
    | 185 | +        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
    | 186 | +        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
    | 187 | +
    | 188 | +        torch.compile(add)(x, y)
    | 189 | +        self._exec_and_verify_payload()
    | 190 | +
103 | 191 |     def test_two_inputs_matmul(self):
104 | 192 |         def f(a, b):
105 | 193 |             return (a @ b).relu()