Fuse matmul #157743
base: main
Conversation
torch/_inductor/config.py
Outdated
@@ -1197,6 +1197,9 @@ class triton:
    # For best results, this should be used with prefer_nd_tiling.
    tile_reductions: bool = False

    # Codegen matmul natively with tl.dot without calling template.
    enable_native_matmul: bool = False
Can you update the PR to turn this on so we can do a full CI run with it enabled to check for bugs?
(After CI is passing we can turn it off again)
I’ve just enabled it, but I’m not fully confident about the performance due to the potential overhead from the `reshape` and `transpose` operations. Back in March, the Triton compiler didn’t handle these operations efficiently, which resulted in slower performance. To work around this, I had to modify Inductor to emit alternative code, which I had originally planned to include in a follow-up PR.
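# The generated code below loads one tile of each operand (tmp0, tmp1), then uses
# tl.reshape / tl.trans to produce the [YBLOCK, R0_BLOCK] and [R0_BLOCK, XBLOCK] shapes
# that tl.dot expects; these reshape/transpose ops are the overhead mentioned above.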
tmp0 = tl.load(in_ptr0 + (r0_2 + 128 * y0), r0_mask & ymask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x1 + 128 * r0_2), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
tmp2 = tl.dot(tl.reshape(tmp0, [YBLOCK, R0_BLOCK]), tl.trans(tl.reshape(tmp1, [XBLOCK, R0_BLOCK])), allow_tf32=False)
I haven't looked at this super carefully yet, but I kicked off a benchmark run with it enabled here: It should show up in the dropdown (nullplay_fuse_matmul) here once the jobs finish:
I approved CI. The benchmark run is done; it looks like there are a number of models that are failing: The other benchmark suites can be viewed by selecting "nullplay_fuse_matmul" in the branch dropdown (+ training/inference).
I noticed an issue, fixed a few bugs, and pushed the changes again. Could you re-run the CI and performance benchmarks? Just to confirm: there's no way for me to trigger the CI myself, right? Or is there a way to run the tests locally on my end?
To add the ciflow label, please first approve the workflows that are awaiting approval. This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
Something is odd with CI (in this PR and a few others). I don't see any jobs to approve. There is also a merge conflict. Can you rebase? That will hopefully fix the CI issue.
This is to match what eager PyTorch does for pointwise ops. Most of those ops are memory bound, so the upcast to fp32 doesn't matter for performance. For matmuls that won't work; we should modify the upcast logic to not apply to matmuls.
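For illustration, a minimal sketch of that idea (not Inductor's actual upcast code; `choose_compute_dtype` and its arguments are hypothetical):

```python
import torch

def choose_compute_dtype(op_name: str, input_dtype: torch.dtype) -> torch.dtype:
    """Hypothetical helper: pick the dtype a generated kernel computes in."""
    if op_name == "dot":
        # Matmuls are compute bound; keeping fp16/bf16 lets tl.dot use tensor cores.
        return input_dtype
    if input_dtype in (torch.float16, torch.bfloat16):
        # Pointwise ops are memory bound, so upcasting to fp32 is effectively free
        # and matches eager-mode numerics.
        return torch.float32
    return input_dtype
```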
I just asked to add permissions for you to trigger CI yourself. You should be able to run tests locally. Failing tests should print out the repro command, and the benchmarks are all in the pytorch/benchmarks folder.
You should have access to start CI now. I kicked off another benchmark run here: https://github.com/pytorch/pytorch/actions/runs/16242184585
cfed28d to 1793aec (Compare)
Implementation of #151705
This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly, without relying on predefined templates. To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705.
Summary of This PR
This PR implements the basic functionality. It does not include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.
Notable Changes
- Adds a new config flag: `config.triton.enable_native_matmul`
- Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
- Implements Triton codegen for `ops.dot` (see the usage sketch below)
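A rough usage sketch, assuming the flag and lowerings land as described above (the exact generated code may differ):

```python
import torch
import torch._inductor.config as inductor_config

# Flag added by this PR; when enabled, aten.mm / aten.bmm are lowered to ops.dot
# and Inductor emits tl.dot directly instead of calling a matmul template.
inductor_config.triton.enable_native_matmul = True

@torch.compile
def mm(a, b):
    return a @ b            # 2D @ 2D -> aten.mm

@torch.compile
def bmm(a, b):
    return torch.bmm(a, b)  # batched matmul -> aten.bmm

a = torch.randn(256, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 128, device="cuda", dtype=torch.float16)
mm(a, b)

a3 = torch.randn(8, 256, 64, device="cuda", dtype=torch.float16)
b3 = torch.randn(8, 64, 128, device="cuda", dtype=torch.float16)
bmm(a3, b3)
```

Running with `TORCH_LOGS=output_code` prints the generated Triton kernels, which makes it easy to check whether `tl.dot` appears in the output.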
@eellison @jansel @PaulZhang12 @shunting314
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos