
[inductor] add lowering for repeat_interleave.Tensor with output size specified (#147160) #158462


Open
v0i0 wants to merge 34 commits into gh/v0i0/1/base
Commits (34)
ae99042  [inductor] add lowering for repeat_interleave.Tensor with output size… (v0i0, Jul 16, 2025)
217f847  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 16, 2025)
631d213  Update (v0i0, Jul 16, 2025)
e5a3aa1  Update (v0i0, Jul 16, 2025)
9c22766  Update (v0i0, Jul 16, 2025)
cc2189f  Update (v0i0, Jul 16, 2025)
6c37575  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 16, 2025)
ef23607  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 17, 2025)
6141da8  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 17, 2025)
5da407f  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 17, 2025)
0aaf849  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 17, 2025)
e1b0bdd  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 18, 2025)
f03d66f  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 21, 2025)
90f72c0  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 22, 2025)
4310217  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 22, 2025)
80f52ce  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 22, 2025)
2467517  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 22, 2025)
e7be1a6  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 23, 2025)
8592e4c  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 24, 2025)
83d9c4a  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 24, 2025)
dae7e73  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 24, 2025)
591691b  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 25, 2025)
52c08d9  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 28, 2025)
5fc1179  Update (v0i0, Jul 28, 2025)
a56b5d7  Update (v0i0, Jul 28, 2025)
eccfd1e  Update (v0i0, Jul 29, 2025)
c6dc919  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 29, 2025)
71ba5cf  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Jul 29, 2025)
03d6a22  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Aug 6, 2025)
a3bb872  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Aug 6, 2025)
b2965b9  Update (v0i0, Aug 7, 2025)
10ce9c5  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Aug 7, 2025)
ec9e2fb  Update on "[inductor] add lowering for repeat_interleave.Tensor with … (v0i0, Aug 7, 2025)
66109f5  Update (v0i0, Aug 8, 2025)
28 changes: 28 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -13670,6 +13670,34 @@ def fn(inp, repeats, output_size):
args = (inp, repeats, output_size)
self.assertEqual(fn(*args), torch.compile(fn)(*args))

@parametrize("dtype", [torch.int32, torch.int64])
@parametrize("nd", [1, 2])
def test_repeat_interleave_Tensor_decomp(self, dtype, nd):
# https://github.com/pytorch/pytorch/issues/147160
def f(input, repeats):
return torch.repeat_interleave(input, repeats, dim=0, output_size=3) + 1

input = torch.arange(1, 2**nd + 1, dtype=dtype, device=self.device).reshape(
[2] * nd
)
repeat = torch.tensor([1, 2], device=self.device)

if input.device.type == "mps" and dtype == torch.int64:
raise unittest.SkipTest(
"torch.compile fails this test with mps & int64, "
"see https://github.com/pytorch/pytorch/issues/159408"
)

f_compiled = torch.compile(f)
output, (code,) = run_and_get_code(f_compiled, input, repeat)
reference = f(input, repeat)
self.assertEqual(output, reference)
# we don't lower when cpp_wrapper is used, because it cannot generate
# proper example inputs during autotuning
can_lower = (not config.cpp_wrapper) and (input.device.type != "mps")
has_lowered = not re.search(r"repeat_interleave.Tensor", code)
self.assertEqual(has_lowered, can_lower)

# end of class CommonTemplate - add new tests here


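Not part of the diff: a minimal standalone sketch of what the new test exercises, namely compiling repeat_interleave.Tensor with an explicit output_size (the device and example values here are illustrative assumptions):

import torch

def f(x, repeats):
    # output_size fixes the length of the result along dim 0 up front,
    # which is what allows inductor to lower the op instead of falling
    # back to the ATen kernel
    return torch.repeat_interleave(x, repeats, dim=0, output_size=3) + 1

x = torch.arange(1, 5).reshape(2, 2)  # the nd=2 case from the test above
repeats = torch.tensor([1, 2])        # repeat row 0 once, row 1 twice
print(torch.compile(f)(x, repeats))
# tensor([[2, 3],
#         [4, 5],
#         [4, 5]])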
2 changes: 1 addition & 1 deletion test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -348,7 +348,7 @@ def run(*ex, **kwargs):
"test_rand_like_deterministic_dynamic_shapes": TestFailure(
("cpu", "cuda", "xpu"), is_skip=True
),
"test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "xpu")),
"test_slice_mutation2_dynamic_shapes": TestFailure(
("cpu", "cuda", "xpu"), is_skip=True
),
22 changes: 22 additions & 0 deletions torch/_inductor/decomposition.py
@@ -1154,3 +1154,25 @@ def rrelu_with_noise_functional(
else:
negative_slope = (lower + upper) / 2
return aten.leaky_relu(self, negative_slope), torch.Tensor()


@register_decomposition(aten.repeat_interleave.Tensor)
def repeat_interleave_Tensor(
repeat: torch.Tensor,
output_size: Optional[int] = None,
) -> torch.Tensor:
if config.triton.autotune_at_compile_time:
# We can't compile-time auto-tune this because
# it expects specific data in `repeat`
return NotImplemented
if output_size is None or type(output_size) is not int:
return NotImplemented
if repeat.device.type == "mps":
return NotImplemented
assert repeat.dtype in [torch.int32, torch.int64]
assert repeat.ndim == 1
cumsum = repeat.cumsum(0)
pos = torch.arange(output_size, device=repeat.device)
return torch.searchsorted(
cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
)
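The decomposition relies on a standard identity: binary-searching each output position in the running sum of the repeats recovers, for every output element, the index of the source element it comes from. A quick eager-mode sanity check of the trick (illustrative code, not from the PR):

import torch

repeats = torch.tensor([1, 2, 3])
cumsum = repeats.cumsum(0)                         # tensor([1, 3, 6])
pos = torch.arange(int(cumsum[-1]))                # output positions 0..5
idx = torch.searchsorted(cumsum, pos, right=True)  # tensor([0, 1, 1, 2, 2, 2])
# same indices the ATen kernel produces for these repeats
assert torch.equal(idx, torch.repeat_interleave(repeats))

right=True is what makes boundary positions (where pos equals a cumsum value) map to the next source element; it also skips sources with repeat 0, which show up as duplicate values in the cumsum.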
1 change: 1 addition & 0 deletions torch/_inductor/lowering.py
@@ -2879,6 +2879,7 @@ def is_aligned(x):

# index_reduce requires fallback when use_scatter_fallback(...) returns True
make_fallback(aten.index_reduce)
make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)


# Register with type_promotion_kind None.
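Not part of the diff: with the fallback and the decomposition both registered, whether the lowering actually fired can be checked outside the test harness by inspecting the generated code, the same way the new test does. A minimal sketch, assuming a CUDA device and the default inductor config (cpp_wrapper off):

import re
import torch
from torch._inductor.utils import run_and_get_code

def f(x, repeats):
    return torch.repeat_interleave(x, repeats, dim=0, output_size=3)

x = torch.arange(4.0, device="cuda").reshape(2, 2)
repeats = torch.tensor([1, 2], device="cuda")
out, (code,) = run_and_get_code(torch.compile(f), x, repeats)
# when the decomposition applies, the generated code no longer references
# the ATen fallback for repeat_interleave.Tensor
assert re.search(r"repeat_interleave.Tensor", code) is None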