Skip to content

Commit 22103e4

Browse files
committed
[inductor] add lowering for repeat_interleave.Tensor with output size specified (#147160)
ghstack-source-id: 7c98b17 Pull Request resolved: #158462
1 parent 6efbbd6 commit 22103e4

File tree

4 files changed

+53
-1
lines changed

4 files changed

+53
-1
lines changed

test/inductor/test_torchinductor.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13670,6 +13670,35 @@ def fn(inp, repeats, output_size):
1367013670
args = (inp, repeats, output_size)
1367113671
self.assertEqual(fn(*args), torch.compile(fn)(*args))
1367213672

13673+
@parametrize("dtype", [torch.int32, torch.int64])
@parametrize("nd", [1, 2])
def test_repeat_interleave_Tensor_decomp(self, dtype, nd):
    """Check that repeat_interleave.Tensor with an explicit ``output_size``
    is lowered by inductor (no aten fallback left in the generated code),
    except when ``cpp_wrapper`` is enabled or the device is mps.

    Regression test for https://github.com/pytorch/pytorch/issues/147160
    """

    def f(input, repeats):
        return torch.repeat_interleave(input, repeats, dim=0, output_size=3) + 1

    # 1-d or 2-d input with 2 entries along dim 0, values 1..2**nd.
    # (A previous revision assigned a hard-coded tensor literal here that was
    # immediately overwritten — dead code, removed.)
    input = torch.arange(1, 2**nd + 1, dtype=dtype, device=self.device).reshape(
        [2] * nd
    )
    repeat = torch.tensor([1, 2], device=self.device)

    if input.device.type == "mps" and dtype == torch.int64:
        raise unittest.SkipTest(
            "torch.compile fails this test with mps & int64, "
            "see https://github.com/pytorch/pytorch/issues/159408"
        )

    f_compiled = torch.compile(f)
    output, (code,) = run_and_get_code(f_compiled, input, repeat)
    reference = f(input, repeat)
    self.assertEqual(output, reference)
    # We don't lower when the cpp_wrapper is used because it cannot generate
    # proper examples during autotune; mps is excluded as well.
    can_lower = (not config.cpp_wrapper) and (input.device.type != "mps")
    # Escape the dot so the regex matches the literal op name
    # "repeat_interleave.Tensor" rather than any character in that position.
    has_lowered = not re.search(r"repeat_interleave\.Tensor", code)
    self.assertEqual(has_lowered, can_lower)
1367313702
# end of class CommonTemplate - add new tests here
1367413703

1367513704

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def run(*ex, **kwargs):
348348
"test_rand_like_deterministic_dynamic_shapes": TestFailure(
349349
("cpu", "cuda", "xpu"), is_skip=True
350350
),
351-
"test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
351+
"test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "xpu")),
352352
"test_slice_mutation2_dynamic_shapes": TestFailure(
353353
("cpu", "cuda", "xpu"), is_skip=True
354354
),

torch/_inductor/decomposition.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,3 +1154,25 @@ def rrelu_with_noise_functional(
11541154
else:
11551155
negative_slope = (lower + upper) / 2
11561156
return aten.leaky_relu(self, negative_slope), torch.Tensor()
1157+
1158+
1159+
@register_decomposition(aten.repeat_interleave.Tensor)
def repeat_interleave_Tensor(
    repeat: torch.Tensor,
    output_size: Optional[int] = None,
) -> torch.Tensor:
    """Decompose ``aten.repeat_interleave.Tensor`` into cumsum + searchsorted.

    Only applies when ``output_size`` is a concrete ``int``; in every
    unsupported configuration ``NotImplemented`` is returned so the op
    falls back to the aten implementation.
    """
    # Compile-time autotuning can't be used here: building autotune
    # examples would require the actual values stored in `repeat`.
    if config.triton.autotune_at_compile_time:
        return NotImplemented
    # A concrete int output size is required (note: None fails this check
    # too, so the Optional case is covered by the same branch).
    if type(output_size) is not int:
        return NotImplemented
    if repeat.device.type == "mps":
        return NotImplemented
    assert repeat.dtype in [torch.int32, torch.int64]
    assert repeat.ndim == 1
    # boundaries[i] is the total count of output elements produced by
    # repeats 0..i; searchsorted with right=True maps each output position
    # to the index of the half-open bucket it falls into.
    boundaries = repeat.cumsum(0)
    positions = torch.arange(output_size, device=repeat.device)
    return torch.searchsorted(
        boundaries, positions, out_int32=(repeat.dtype == torch.int32), right=True
    )

torch/_inductor/lowering.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2879,6 +2879,7 @@ def is_aligned(x):
28792879

28802880
# index_reduce requires fallback when use_scatter_fallback(...) returns True
28812881
make_fallback(aten.index_reduce)
2882+
make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)
28822883

28832884

28842885
# Register with type_promotion_kind None.

0 commit comments

Comments
 (0)