[Inductor] addmm + activation function fusion #158137
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158137
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure: as of commit 4708fd0 with merge base 195b5c2, the following job has failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
needs some benchmarks comparing against existing Triton fusions
Thanks for the PR. A couple of comments.
```python
def addmm_gelu_pattern(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(input, output)
    return aten.gelu(output)
```
Is it also worth adding a pattern that targets addmm, instead of mm? I imagine most of the addmms will not be decomposed.
The addmms that have an activation after them will be decomposed, which is what we are interested in pattern matching.
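To illustrate why the patterns above target `mm` + `add` rather than `addmm` directly: after decomposition, `addmm(input, mat1, mat2)` becomes an `add` over an `mm`, so that is the shape the matcher sees. A toy plain-Python sketch of the equivalence (no torch, just nested lists):

```python
# Toy illustration: addmm(input, mat1, mat2) == input + mat1 @ mat2,
# so a decomposed addmm shows up in the graph as an mm node feeding an add node.
def mm(a, b):
    # Naive matrix multiply over lists of lists.
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def add(x, y):
    # Elementwise add of two equally-shaped matrices.
    return [[x[i][j] + y[i][j] for j in range(len(x[0]))] for i in range(len(x))]

def addmm(inp, mat1, mat2):
    # The fused op is exactly the composition of the two decomposed nodes.
    return add(inp, mm(mat1, mat2))
```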
@AaronWang04 re-request when ready.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased bbb1083 to 0f9acc5
Did you have a chance to run any dashboards? Just curious. Anyway, looks good, but it would be nice if you could use the gen_register_replacement API to avoid compile-time overhead. It's more required for training patterns but still nice either way.
```python
args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]

for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
    register_replacement(
```
Now that we are parameterizing this across 4 total patterns, could I trouble you to pre-register the patterns? See gen_register_replacement:
## Precompiled Patterns
New patterns are added using register_replacement(). Patterns added in this way
can have a compile-time overhead because they need to be traced before
use. Patterns can be precompiled and added using gen_register_replacement()
instead. To do this you call gen_register_replacement() instead of
register_replacement(). The arguments are the same except for an additional
unique name which is used as a lookup key.
And https://github.com/pytorch/pytorch/blob/main/torchgen/fuse/gen_patterns.py.
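The idea the docs describe, reduced to a toy sketch (the names below are hypothetical illustrations, not the actual torch pattern-matcher API): the pattern is traced once ahead of time and stored under a unique name, so at compile time registration is just a lookup instead of a re-trace.

```python
# Toy illustration of the "unique name as a lookup key" idea behind
# gen_register_replacement. Hypothetical names; not the real torch API.
_PRECOMPILED = {}

def precompile_pattern(unique_name, trace_fn):
    # Build ("trace") the pattern exactly once, ahead of time.
    _PRECOMPILED[unique_name] = trace_fn()

def lookup_pattern(unique_name):
    # Compile-time path: a cheap dict lookup, no tracing overhead.
    return _PRECOMPILED[unique_name]
```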
```python
def addmm_gelu_pattern_2(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(input, output)
    return aten.gelu(output, approximate="tanh")
```
It's kind of unfortunate that cuBLAS only supports "tanh"... which we don't use by default in PyTorch.
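For context on how far apart the two gelu variants are: PyTorch's default is the exact erf formulation, while the tanh variant is a close polynomial approximation. A quick stdlib-only comparison (standard formulas, not torch code):

```python
import math

def gelu_exact(x):
    # PyTorch's default gelu: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # The "tanh" approximation (the variant cuBLAS supports):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
```

The two agree to roughly three decimal places over typical activation ranges, which is why the tanh variant is a common substitute, but they are not bitwise identical, so the pattern has to match the `approximate="tanh"` case explicitly.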
@eellison For my edification, what's the dashboard, is it
@eellison the dashboard ran, results are under branch "AaronWang04_addmmfusion_perftest". Not sure where the slowdowns are from, whether it is weird addmm_activation shapes that are slower than Triton or the benchmarks having high variance.
@eellison Ran the dashboard again with gen_register_replacement. Pasted an image of it in the top comment.
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased 7ec007a to 4708fd0
@pytorchmergebot label ciflow/trunk
@eellison This PR is failing a newly added runtime estimation test from #159730. From our initial investigation, the functions for estimating flops don't work well with extern kernels. Is this an issue with the register_replacement machinery or something else? (pytorch/torch/_inductor/scheduler.py, lines 794 to 798 in 731ee31)
cc @skarjala |
@AaronWang04 what do you mean? What is erroring exactly? Does adding a formula for addmm_activation in pytorch/torch/utils/flop_counter.py (line 4 in 731ee31) fix it?
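For reference, a hedged sketch of what such a flop formula could look like, modeled on the usual 2*M*K*N convention for matrix multiplies (the bias add and activation are O(M*N) and typically ignored). This is only an illustration of the counting convention, not the actual flop_counter.py registration:

```python
# Hypothetical flop formula for a fused addmm_activation, following the
# 2*M*K*N convention used for plain mm. Illustration only.
def addmm_activation_flop(input_shape, mat1_shape, mat2_shape):
    m, k = mat1_shape
    k2, n = mat2_shape
    assert k == k2, "inner dimensions must match"
    # Each of the M*N outputs needs K multiplies and K adds.
    return 2 * m * k * n
```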
I'll fix the test |
@eellison No, adding a formula does not fix it because it doesn't even reach there. The estimate-flop function gets the "origin_node" to count the flops. In the Triton case, get_origin_node returns an mm node, whereas with this PR it incorrectly returns a wait_tensor node. So here it's trying to count flops from a wait_tensor node instead of mm, which returns 0 and fails the test.
I think it's not really the test's problem, more so the flop-counting infrastructure. However, the easy way to patch this test for now is simply adding another pointwise op so the addmm_activation fusion does not trigger:

```python
def forward(self, x):
    h = self.linear(x)
    h = torch.relu(h)
    h = torch.relu(h)
    h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0")
    h = torch.ops._c10d_functional.wait_tensor.default(h)
```

Maybe a good solution right now is to patch this current test with this ^ and add another test that xfails for the extern fusion case, so this issue gets circled back to later?
This PR implements a pass in post_grad to fuse activation(add + mm).
This was previously done similarly in #106912 but was reverted for performance reasons. It was replaced with a pass that unfuses the activation and add from addmm/addmm_activation and lets Inductor handle the fusion.
However, since then the cuBLAS team has made a lot of perf improvements here. Will update this post with more benchmarks, but preliminary benchmarks show good results.
perf dash board

Relu works with both training and inference, but gelu only works in inference mode due to a fundamental limitation: gelu's derivative depends on its input while relu's doesn't. I don't think this is fixable with the current addmm_activation API.
Graph module before and after this pass
Relu (addmm)
Gelu (addmm)
Benchmark setup:
NGC pytorch 25.06 container
cublas version: 12.9.1.4
torch.compile ran with dynamic = False and max_autotune
H100
A100
Script
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben