Skip to content

Commit 9625356

Browse files
committed
[inductor] add lowering for repeat_interleave.Tensor with output size specified (#147160)
ghstack-source-id: fdeeea1 Pull Request resolved: #158462
1 parent da05b7f commit 9625356

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

test/inductor/test_torchinductor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13626,6 +13626,18 @@ def forward(self, x):
1362613626
FileCheck().check("cpp_fused_add_0").run(code)
1362713627
self.assertEqual(refe_out, test_out)
1362813628

def test_repeat_interleave_pass(self):
    # Regression test: https://github.com/pytorch/pytorch/issues/147160
    # Verifies that torch.compile lowers repeat_interleave.Tensor when an
    # explicit output_size is given, instead of falling back to the aten op.
    def f(input, repeats):
        # output_size=3 matches repeats.sum() for the inputs below; the
        # decomposition only fires when output_size is a concrete int.
        return torch.repeat_interleave(input, repeats, dim=0, output_size=3) + 1

    input = torch.tensor([[1, 2], [3, 4]], device="cuda")
    repeat = torch.tensor([1, 2], device="cuda")
    f_compiled = torch.compile(f)
    test, (code,) = run_and_get_code(f_compiled, input, repeat)
    # Compiled result must match eager execution.
    self.assertEqual(test, f(input, repeat))
    # The aten op must not appear in the generated code: the decomposition
    # should have replaced it entirely.
    self.assertFalse("repeat_interleave.Tensor" in code)
1362913641

1363013642
@dataclasses.dataclass
1363113643
class TestFailure:

torch/_inductor/decomposition.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,3 +1150,15 @@ def rrelu_with_noise_functional(
11501150
else:
11511151
negative_slope = (lower + upper) / 2
11521152
return aten.leaky_relu(self, negative_slope), torch.Tensor()
1153+
1154+
1155+
@register_decomposition(aten.repeat_interleave.Tensor)
1156+
def repeast_interleave_Tensor(
1157+
repeat: torch.Tensor,
1158+
output_size: Optional[int] = None,
1159+
) -> torch.Tensor:
1160+
if output_size is None or type(output_size) is not int:
1161+
return NotImplemented
1162+
cumsum = repeat.cumsum(0)
1163+
pos = torch.arange(output_size, device=repeat.device)
1164+
return torch.searchsorted(cumsum, pos, right=True)

0 commit comments

Comments
 (0)