Skip to content

[inductor][cpu] Fix double-offset issue in GEMM_TEMPLATE #159233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions test/inductor/test_cpu_select_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2769,6 +2769,31 @@ def forward(self, x, w):
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

@patches
@torch.no_grad
@parametrize("bs", (1, 50))
@parametrize("Mdim", (192,))
@parametrize("Kdim", (196,))
@parametrize("Ndim", (84, 385))
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_bmm_with_y_storage_offset(self, dtype, bs, Mdim, Kdim, Ndim):
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
# y_with_offset: contiguous, but has non-zero storage offset
y_with_offset = torch.empty((3, *y.shape), dtype=y.dtype, device=y.device)[2].copy_(y)
return x @ y_with_offset

counters.clear()
u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
mod = M().to(dtype=dtype).eval()
with verify(dtype) as (atol, rtol):
self.common(mod, (u, v), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1)

@patches
@torch.no_grad
@dtypes(torch.float)
Expand Down
7 changes: 7 additions & 0 deletions torch/_inductor/codegen/cpp_gemm_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,13 @@
elif isinstance(W, ir.IRNode):
# Require W layout to be fixed & contiguous, happens inplace.
ir.ExternKernel.require_contiguous(W)
if W.layout.offset != 0:

Check failure on line 1152 in torch/_inductor/codegen/cpp_gemm_template.py

View workflow job for this annotation

GitHub Actions / lintrunner-mypy / linux-job

MYPY [attr-defined]

"IRNode" has no attribute "layout"
# W may be contiguous but still have a non-zero storage offset.
# GEMM_TEMPLATE emits code like:
# W.data_ptr[offset + ...]
# but the data_ptr already includes the offset.
Comment on lines +1154 to +1156
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes it sound like the correct fix is to remove the offset from the index calculation in the emitted code rather than copy. I assume that would break something else?

# To avoid double-offsetting, we create a copy with storage offset = 0.
new_inputs[1] = ir.ExternKernel.copy_input(W)

if not skip_int8_compensation and _is_int8_gemm(new_inputs):
BCompensate = None
Expand Down
Loading