
[inductor][cpu] Fix double-offset issue in GEMM_TEMPLATE #159233


Open
wants to merge 1 commit into main

Conversation

@Phoslight commented Jul 27, 2025

Fixes #158076

Basically, the GEMM template generates code like this:

cpp_CppMicroGemmRef_micro_gemm<static_cast<bool>(false), static_cast<bool>(false)>(
            &(X[static_cast<int64_t>(k_start + 196LL*m_start + 38416LL*ks_b_index)]),
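            // 200704000LL here is W's storage offset, baked into the emitted index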
            &(W[static_cast<int64_t>(200704000LL + n_start + 80LL*k_start + 15680LL*ks_b_index)]),
            &(local_acc_buf[static_cast<int64_t>(Nr*nci + ((-1LL)*Nr*nc))]),
            static_cast<int64_t>(m_end + ((-1LL)*m_start)),
            static_cast<int64_t>(Nr),
            static_cast<int64_t>(k_end + ((-1LL)*k_start)),
            static_cast<int64_t>(196LL),
            static_cast<int64_t>(80LL),
            static_cast<int64_t>(Nc_blocks*Nr)
        );

However, when the input tensor W has a storage offset, this results in a double-offset issue. That is, the resulting pointer is 2 * 200704000LL elements away from W.storage().data_ptr(), which causes an out-of-bounds access.
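The effect in miniature (a standalone sketch with made-up sizes; only the doubling matters):

```python
import torch

base = torch.zeros(16)                 # 16-element backing storage
w = base[8:]                           # a view with storage_offset() == 8

storage = base.untyped_storage().data_ptr()
# data_ptr() already accounts for the storage offset:
assert w.data_ptr() == storage + 8 * w.element_size()

# An emitted index of the form W[offset + ...] then adds the offset again:
doubled = w.data_ptr() + w.storage_offset() * w.element_size()
assert doubled == storage + 2 * 8 * w.element_size()  # beyond the data w views
```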

The storage offset of W is introduced by this patch, which I think is a reasonable change on its own; cpp_gemm_template.py should therefore handle input matrices with storage offsets properly.

I think a good way to fix this issue is to create a new matrix that has no storage offset.

When should_block_weights is true, block_weight() creates a clean new matrix, so that branch is not affected by this issue.
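At the tensor level the idea looks roughly like this (an illustrative sketch only; the real change operates on Inductor IR in cpp_gemm_template.py, and the helper name is made up):

```python
import torch

def drop_storage_offset(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: return an equivalent tensor whose storage
    # starts at offset 0, so the emitted index needs no base offset.
    if w.storage_offset() == 0:
        return w
    return w.clone()  # clone() allocates fresh storage with offset 0

w = torch.randn(100, 80)[10:]          # a view with storage_offset() == 800
fixed = drop_storage_offset(w)
assert fixed.storage_offset() == 0 and torch.equal(fixed, w)
```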

BTW, I've also examined the FX IR generated by torch.compile(), as well as the generated Python module, and both are correct.

The newly added test in test_cpu_select_algorithm.py reproduces the issue; with this patch, the crash is fixed, including the crash reported in #158076.
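For reference, the reproduction has roughly this shape (a hypothetical sketch, not the exact test; max-autotune is assumed to steer CPU matmuls toward the GEMM template):

```python
import torch

def run(x, w):
    return x @ w

base = torch.randn(2, 64, 64)
w = base[1]                            # a view with storage_offset() == 64 * 64
x = torch.randn(8, 64)

compiled = torch.compile(run, mode="max-autotune")
torch.testing.assert_close(compiled(x, w), run(x, w))
```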

I ran the CPU tests in test_cpu_select_algorithm.py, but many of them are skipped due to MKL and AMX requirements. I'd appreciate it if someone could help verify the test.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

pytorch-bot bot commented Jul 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159233

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job, 2 Unrelated Failures

As of commit 674a976 with merge base f636736:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Jul 27, 2025

CLA Signed


The committers listed above are authorized under a signed CLA.

@Phoslight (Author) commented:

@pytorchbot label "topic: not user facing"

@pytorch-bot bot added the "topic: not user facing" label Jul 27, 2025
@HDCharles added the "triaged" label Jul 28, 2025
@leslie-fang-intel requested a review from CaoE July 29, 2025 01:38
@leslie-fang-intel (Collaborator) commented:

> I ran the CPU tests in test_cpu_select_algorithm.py, but many of them are skipped due to MKL and AMX requirements. I'd appreciate it if someone could help verify the test.

@CaoE could you help take a look at this fix?

@leslie-fang-intel (Collaborator) commented:

> That is, the resulting pointer is 2 * 200704000LL elements away from W.storage().data_ptr(), which causes an out-of-bounds access.

Hi @Phoslight, thanks for the fix. I'd like to understand more about why the offset causes an out-of-bounds access; AFAIK, the offset should come from the original view node.

@Phoslight (Author) commented Jul 29, 2025

> > That is, the resulting pointer is 2 * 200704000LL elements away from W.storage().data_ptr(), which causes an out-of-bounds access.
>
> Hi @Phoslight, thanks for the fix. I'd like to understand more about why the offset causes an out-of-bounds access; AFAIK, the offset should come from the original view node.

+----------------------------------+
| contiguous W with storage offset |  // should_block_weight == False
+-----+----------------------------+
      |
      |
      v
 compile_fx()
      +
      |
      |
      |  // (1) generates the cpp template:
      +------> run_node()
      |              +
      |              | ...
      |              v
      |        tuned_bmm()
      |              +
      |              | ...
      |              v
      |        CppBmmTemplate::render()   // generates template with W's storage offset
      |                                   // e.g. W[static_cast<int64_t>(200704000LL + n_start + 80LL*k_start + 15680LL*ks_b_index)]
      |                                   //   (in this patch I dropped the offset 200704000LL
      |                                   //    to align with should_block_weight branch in prep_weight)
      |
      |
      |  // (2) generates the example inputs
      +------> do_autotuning()
                     +
                     | ...
                     v
               AlgorithmSelectorCache::benchmark()
               +->benchmark_in_current_process()
                  +> get_inputs()
                     +> benchmark_example_value()
                     |  +> unwrap_view()   // drops the offset from W.data_ptr
                     |
                     | ...
                     |
                     +> DataProcessorChoiceCallerWrapper::benchmark()
                        +> CppGemmTemplate::preprocessor()
                           +> normalize_shapes()   // adds back the offset to W.data_ptr (your fix)

Thank you for your reply, Leslie.

Without this patch, the cpp template and the example inputs both add an offset, which causes the double-offset issue.
My proposed fix is to remove the offset from the cpp template to align with the behavior of the should_block_weight branch.

Hope the above chart helps clarify the issue.
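
At the tensor level, step (2) behaves roughly like this (illustrative only; unwrap_view itself operates on Inductor IR, and the helper name below is made up):

```python
import torch

def example_input_like(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for benchmark_example_value()/unwrap_view():
    # same shape/stride/dtype, but fresh storage with offset 0.
    return torch.empty_strided(w.size(), w.stride(), dtype=w.dtype)

w = torch.randn(1000, 80)[100:]        # a view with storage_offset() == 8000
example = example_input_like(w)
assert example.storage_offset() == 0   # the offset is gone

# A kernel that still adds W's original offset to this pointer would read
# 8000 elements past the data it was handed: the double-offset crash.
```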

@Phoslight (Author) commented:

Gentle ping. Any other reviews? Thanks in advance.

@swolchok (Contributor) left a comment

not familiar with this code, but approving workflows to run

Comment on lines +1154 to +1156
# GEMM_TEMPLATE emits code like:
# W.data_ptr[offset + ...]
# but the data_ptr already includes the offset.

This makes it sound like the correct fix is to remove the offset from the index calculation in the emitted code rather than copy. I assume that would break something else?

Labels
module: inductor, open source, topic: not user facing, triaged
Development

Successfully merging this pull request may close these issues.

torch.compile on BFloat16 Segment Anything segfaults in cpp_CppMicroGemmRef_micro_gemm<false, false> on Mac
5 participants