diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py
index 7e35c93ee0b7..c40206197249 100644
--- a/test/inductor/test_cpu_select_algorithm.py
+++ b/test/inductor/test_cpu_select_algorithm.py
@@ -2769,6 +2769,31 @@ def forward(self, x, w):
         self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
         self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
 
+    @patches
+    @torch.no_grad
+    @parametrize("bs", (1, 50))
+    @parametrize("Mdim", (192,))
+    @parametrize("Kdim", (196,))
+    @parametrize("Ndim", (84, 385))
+    @dtypes(torch.float, torch.bfloat16, torch.half)
+    def test_bmm_with_y_storage_offset(self, dtype, bs, Mdim, Kdim, Ndim):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                # y_with_offset: contiguous, but has non-zero storage offset
+                y_with_offset = torch.empty((3, *y.shape), dtype=y.dtype, device=y.device)[2].copy_(y)
+                return x @ y_with_offset
+
+        counters.clear()
+        u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
+        v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
+        mod = M().to(dtype=dtype).eval()
+        with verify(dtype) as (atol, rtol):
+            self.common(mod, (u, v), atol=atol, rtol=rtol)
+        self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
+
     @patches
     @torch.no_grad
     @dtypes(torch.float)
diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py
index 8f04ac923613..4721e532c9be 100644
--- a/torch/_inductor/codegen/cpp_gemm_template.py
+++ b/torch/_inductor/codegen/cpp_gemm_template.py
@@ -1149,6 +1149,13 @@ def prep_weight(
     elif isinstance(W, ir.IRNode):
         # Require W layout to be fixed & contiguous, happens inplace.
         ir.ExternKernel.require_contiguous(W)
+        if W.layout.offset != 0:
+            # W may be contiguous but still have a non-zero storage offset.
+            # GEMM_TEMPLATE emits code like:
+            #     W.data_ptr[offset + ...]
+            # but the data_ptr already includes the offset.
+            # To avoid double-offsetting, we create a copy with storage offset = 0.
+            new_inputs[1] = ir.ExternKernel.copy_input(W)
 
     if not skip_int8_compensation and _is_int8_gemm(new_inputs):
         BCompensate = None
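
For context on what the new test exercises, below is a minimal standalone sketch (not part of the patch) of how slicing into a larger freshly allocated buffer produces a tensor that is contiguous yet has a non-zero storage offset, which is exactly the kind of `y` operand the `prep_weight` fix guards against. The tensor shapes here are arbitrary illustration values.

```python
import torch

y = torch.randn(4, 5)

# Allocate a buffer three times as large and write y into its last slice.
# The slice is contiguous, but its data starts 2 * y.numel() elements into
# the underlying storage, so its storage offset is non-zero.
y_with_offset = torch.empty((3, *y.shape), dtype=y.dtype)[2].copy_(y)

print(y_with_offset.is_contiguous())   # True
print(y_with_offset.storage_offset())  # 40 (= 2 * 4 * 5), not 0
print(torch.equal(y_with_offset, y))   # True
```

Since `data_ptr()` already points past the storage offset, a template that also adds the offset would read the wrong elements; copying the input to a fresh buffer with offset 0, as the patch does, sidesteps the double-offsetting.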