Fuse matmul #157743


Open · wants to merge 31 commits into main from fuse_matmul

Conversation

nullplay (Collaborator) commented Jul 7, 2025

Implementation of #151705

This PR introduces the initial implementation of native tl.dot support in Inductor, with the goal of generating Triton matmul kernels directly—without relying on predefined templates.

To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:

  1. Basic support (this PR)
  2. Lazy broadcasting for optimal performance (future PR)

Summary of This PR

This PR implements the basic functionality. It does not include lazy broadcasting, so the generated kernels may involve explicit tl.reshape and tl.trans operations before calling tl.dot, which introduce some overhead.

Notable Changes

  1. Adds a new config flag: config.triton.enable_native_matmul (see the usage sketch after this list)
  2. Introduces a new ops.dot IR node in Inductor and lowers aten.mm and aten.bmm to it when native matmul is enabled
  3. Enforces tiling suitable for matmul when the native matmul flag is enabled
  4. Implements code generation for ops.dot
  5. Adds Triton autotuning heuristics: for now, I’ve copied the configuration from the existing matmul templates. However, this may not be optimal—it currently takes a long time to tune, and I think there must be a better way to tackle this.
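
For reviewers who want to try this out, here is a minimal usage sketch. It assumes the flag is exposed as torch._inductor.config.triton.enable_native_matmul, matching the config diff below; nothing here is a released PyTorch API.

import torch
import torch._inductor.config as inductor_config

# Assumption: this PR's flag; not available in released PyTorch.
inductor_config.triton.enable_native_matmul = True

@torch.compile
def mm(a, b):
    # With the flag on, aten.mm lowers to the new ops.dot IR node and
    # Inductor emits a Triton kernel calling tl.dot directly, instead of
    # going through the predefined matmul templates.
    return a @ b

a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 512, device="cuda", dtype=torch.float16)
out = mm(a, b)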

@eellison @jansel @PaulZhang12 @shunting314

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos


pytorch-bot bot commented Jul 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157743

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures

As of commit 68cfd8e with merge base ca7315c:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Jul 7, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@@ -1197,6 +1197,9 @@ class triton:
# For best results, this should be used with prefer_nd_tiling.
tile_reductions: bool = False

# Codegen matmul natively with tl.dot without calling template.
enable_native_matmul: bool = False
Contributor

Can you update the PR to turn this on so we can do a full CI run with it enabled to check for bugs?

(After CI is passing we can turn it off again)

Collaborator Author

I’ve just enabled it, but I’m not fully confident about the performance due to the potential overhead from the reshape and transpose operations. Back in March, the Triton compiler didn’t handle these operations efficiently, which resulted in slower performance. To work around this, I had to modify Inductor to emit alternative code, which I had originally planned to include in a follow-up PR.

# Tiles of the two matmul inputs, loaded with Inductor's usual pointwise indexing.
tmp0 = tl.load(in_ptr0 + (r0_2 + 128 * y0), r0_mask & ymask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x1 + 128 * r0_2), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
# Explicit reshape/transpose into 2D blocks before tl.dot -- the overhead mentioned above.
tmp2 = tl.dot(tl.reshape(tmp0, [YBLOCK, R0_BLOCK]), tl.trans(tl.reshape(tmp1, [XBLOCK, R0_BLOCK])), allow_tf32=False)

jansel (Contributor) commented Jul 8, 2025

I haven't looked at this super carefully yet, but I kicked off a benchmark run with it enabled here:
https://github.com/pytorch/pytorch/actions/runs/16134785066

It should show up in the dropdown (nullplay_fuse_matmul) here once the job finishes:
https://hud.pytorch.org/benchmark/compilers

jerryzh168 added the triaged label on Jul 8, 2025
nullplay (Collaborator Author) commented:

I noticed that when doing torch.float16 matmuls, it was automatically upcasting to float32. Disabling config.triton.codegen_upcast_to_fp32 made things faster. I'm not sure what effect this might have on other parts of the code, but I’ve set config.triton.codegen_upcast_to_fp32 = False for now.
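
For anyone who wants to reproduce the fp16 comparison, a minimal sketch follows; it assumes this PR's enable_native_matmul flag and uses triton.testing.do_bench for timing, so treat it as illustrative rather than the exact benchmark setup.

import torch
import torch._inductor.config as inductor_config
from triton.testing import do_bench

# Both settings discussed in this thread; enable_native_matmul is from this PR.
inductor_config.triton.enable_native_matmul = True
inductor_config.triton.codegen_upcast_to_fp32 = False  # keep fp16 operands for tl.dot

a = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)
b = torch.randn(2048, 2048, device="cuda", dtype=torch.float16)

compiled_mm = torch.compile(torch.mm)
compiled_mm(a, b)  # trigger compilation once before timing

ms = do_bench(lambda: compiled_mm(a, b))
print(f"fp16 mm via native tl.dot: {ms:.3f} ms")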

I fixed a few bugs and pushed the changes again. Could you re-run the CI and performance benchmarks?

Just to confirm—there’s no way for me to trigger the CI myself, right? Or is there a way to run the tests locally on my end?


pytorch-bot bot commented Jul 12, 2025

To add the ciflow label ciflow/inductor please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

jansel (Contributor) commented Jul 12, 2025

> I fixed a few bugs and pushed the changes again. Could you re-run the CI and performance benchmarks?

Something is odd with CI (in this PR and a few others). I don't see any jobs to approve.

There is also a merge conflict. Can you rebase? That will hopefully fix the CI issue.

> I noticed that when doing torch.float16 matmuls, it was automatically upcasting to float32. Disabling config.triton.codegen_upcast_to_fp32 made things faster. I'm not sure what effect this might have on other parts of the code, but I’ve set config.triton.codegen_upcast_to_fp32 = False for now.

This is to match what eager pytorch does for pointwise ops. Most of those ops are memory bound so the upcast to fp32 doesn't matter for performance. For matmuls that won't work. We should modify the upcast logic to not apply to matmuls.
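
Purely to illustrate that suggestion, a hypothetical sketch of a per-op upcast decision (the names here are made up; the real Inductor codegen hooks are different):

# Hypothetical sketch only -- not the actual Inductor code path.
def should_upcast_to_fp32(op_name: str, upcast_enabled: bool) -> bool:
    # Match eager for memory-bound pointwise ops, where the fp32 upcast is
    # essentially free, but keep fp16/bf16 operands for ops feeding tl.dot,
    # where upcasting costs real matmul throughput.
    if not upcast_enabled:
        return False
    return op_name != "dot"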

> Just to confirm—there’s no way for me to trigger the CI myself, right? Or is there a way to run the tests locally on my end?

I just asked to add permissions for you to trigger CI yourself.

You should be able to run tests locally. Failing tests should print out the repro command, and the benchmarks are all in the pytorch/benchmarks folder.

jansel (Contributor) commented Jul 12, 2025

You should have access to start CI now. I kicked off another benchmark run here: https://github.com/pytorch/pytorch/actions/runs/16242184585

nullplay force-pushed the fuse_matmul branch 2 times, most recently from cfed28d to 1793aec on August 1, 2025
pytorch-bot added the oncall: distributed label on Aug 7, 2025
Labels: ciflow/inductor · module: inductor · oncall: distributed · open source · release notes: inductor · triaged