diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 0e76ca489284..9298d251a098 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -13717,6 +13717,34 @@ def fn(inp, repeats, output_size):
         args = (inp, repeats, output_size)
         self.assertEqual(fn(*args), torch.compile(fn)(*args))
 
+    @parametrize("dtype", [torch.int32, torch.int64])
+    @parametrize("nd", [1, 2])
+    def test_repeat_interleave_Tensor_decomp(self, dtype, nd):
+        # https://github.com/pytorch/pytorch/issues/147160
+        def f(input, repeats):
+            return torch.repeat_interleave(input, repeats, dim=0, output_size=3) + 1
+
+        input = torch.arange(1, 2**nd + 1, dtype=dtype, device=self.device).reshape(
+            [2] * nd
+        )
+        repeat = torch.tensor([1, 2], device=self.device)
+
+        if input.device.type == "mps" and dtype == torch.int64:
+            raise unittest.SkipTest(
+                "torch.compile fails this test with mps & int64, "
+                "see https://github.com/pytorch/pytorch/issues/159408"
+            )
+
+        f_compiled = torch.compile(f)
+        output, (code,) = run_and_get_code(f_compiled, input, repeat)
+        reference = f(input, repeat)
+        self.assertEqual(output, reference)
+        # We don't lower when cpp_wrapper is used because it cannot
+        # generate proper example inputs during autotuning.
+        can_lower = (not config.cpp_wrapper) and (input.device.type != "mps")
+        has_lowered = not re.search(r"repeat_interleave.Tensor", code)
+        self.assertEqual(has_lowered, can_lower)
+
     # end of class CommonTemplate - add new tests here
 
 
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index cdf76772b936..62aeaf5e99c8 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -348,7 +348,7 @@ def run(*ex, **kwargs):
     "test_rand_like_deterministic_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu"), is_skip=True
     ),
-    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
+    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "xpu")),
     "test_slice_mutation2_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu"), is_skip=True
     ),
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index d903d851ee87..c38265abe336 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -1154,3 +1154,28 @@ def rrelu_with_noise_functional(
     else:
         negative_slope = (lower + upper) / 2
     return aten.leaky_relu(self, negative_slope), torch.Tensor()
+
+
+@register_decomposition(aten.repeat_interleave.Tensor)
+def repeat_interleave_Tensor(
+    repeat: torch.Tensor,
+    output_size: Optional[int] = None,
+) -> torch.Tensor:
+    if config.triton.autotune_at_compile_time:
+        # We can't autotune this at compile time: autotuning runs on
+        # placeholder inputs, but this decomposition needs real `repeat` data.
+        return NotImplemented
+    # output_size must be a static Python int (not None or a SymInt).
+    if output_size is None or type(output_size) is not int:
+        return NotImplemented
+    if repeat.device.type == "mps":
+        return NotImplemented
+    assert repeat.dtype in [torch.int32, torch.int64]
+    assert repeat.ndim == 1
+    # Output position `pos` maps to the first source index whose inclusive
+    # prefix sum exceeds `pos`, i.e. searchsorted(cumsum, pos, right=True).
+    cumsum = repeat.cumsum(0)
+    pos = torch.arange(output_size, device=repeat.device)
+    return torch.searchsorted(
+        cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
+    )
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 74a562365b69..efcbc97ac7d0 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2879,6 +2879,7 @@ def is_aligned(x):
 
 # index_reduce requires fallback when use_scatter_fallback(...) returns True
 make_fallback(aten.index_reduce)
+make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)
 
 
 # Register with type_promotion_kind None.
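
Note on the decomposition: repeat_interleave over an index tensor reduces to a
binary search on the inclusive prefix sum of `repeat`. Output position `pos`
belongs to source index `i` exactly when `cumsum[i-1] <= pos < cumsum[i]`
(with `cumsum[-1]` taken as 0), which is precisely what
`torch.searchsorted(cumsum, pos, right=True)` computes. Below is a minimal
standalone sketch of that equivalence; it is not part of the patch, and the
helper name `repeat_interleave_via_searchsorted` is illustrative only:

import torch

def repeat_interleave_via_searchsorted(
    repeat: torch.Tensor, output_size: int
) -> torch.Tensor:
    # Illustrative helper mirroring the decomposition added above.
    # cumsum[i] is the number of output slots consumed by sources 0..i.
    cumsum = repeat.cumsum(0)
    pos = torch.arange(output_size, device=repeat.device)
    # right=True: pos maps to the first index i with cumsum[i] > pos.
    return torch.searchsorted(cumsum, pos, right=True)

repeat = torch.tensor([1, 2, 0, 3])
expected = torch.repeat_interleave(repeat)  # tensor([0, 1, 1, 3, 3, 3])
actual = repeat_interleave_via_searchsorted(repeat, output_size=int(repeat.sum()))
assert torch.equal(actual, expected)

Sources with `repeat == 0` are skipped automatically: equal consecutive cumsum
entries leave no positions strictly below the next boundary, so the
decomposition needs no special casing for empty rows.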