Skip to content

Commit 6c37575

Browse files
committed
Update on "[inductor] add lowering for repeat_interleave.Tensor with output size specified (#147160)"
[ghstack-poisoned]
1 parent cc2189f commit 6c37575

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

test/inductor/test_torchinductor.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13626,13 +13626,19 @@ def forward(self, x):
1362613626
FileCheck().check("cpp_fused_add_0").run(code)
1362713627
self.assertEqual(refe_out, test_out)
1362813628

13629-
def test_repeat_interleave_pass(self):
13629+
13630+
@parametrize("dtype", [torch.int32, torch.int64])
13631+
def test_repeat_interleave_Tensor_decomp(self, dtype):
13632+
device = "cpu"
13633+
if self.device.lower() == "cuda":
13634+
device = "cuda"
13635+
1363013636
# https://github.com/pytorch/pytorch/issues/147160
1363113637
def f(input, repeats):
1363213638
return torch.repeat_interleave(input, repeats, dim=0, output_size=3) + 1
1363313639

13634-
input = torch.tensor([[1, 2], [3, 4]], device="cuda")
13635-
repeat = torch.tensor([1, 2], device="cuda")
13640+
input = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=device)
13641+
repeat = torch.tensor([1, 2], device=device)
1363613642
f_compiled = torch.compile(f)
1363713643
test, (code,) = run_and_get_code(f_compiled, input, repeat)
1363813644
self.assertEqual(test, f(input, repeat))

torch/_inductor/decomposition.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1159,6 +1159,8 @@ def repeast_interleave_Tensor(
11591159
) -> torch.Tensor:
11601160
if output_size is None or type(output_size) is not int:
11611161
return NotImplemented
1162+
if repeat.dtype not in [torch.int32, torch.int64]:
1163+
return NotImplemented
11621164
cumsum = repeat.cumsum(0)
11631165
pos = torch.arange(output_size, device=repeat.device)
1164-
return torch.searchsorted(cumsum, pos, right=True)
1166+
return torch.searchsorted(cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True)

0 commit comments

Comments
 (0)