
[Inductor] addmm + activation function fusion #158137





Open · AaronWang04 wants to merge 17 commits into pytorch:main from AaronWang04:addmm_activation_fusion

Conversation

@AaronWang04 (Contributor) commented Jul 11, 2025

This PR implements a post_grad pass that fuses activation(add + mm) into addmm_activation.

This was previously done in a similar way in #106912, but that change was reverted for performance reasons and replaced with a pass that unfuses the activation and add from addmm/addmm_activation and lets Inductor handle the fusion.

Since then, however, the cuBLAS team has made substantial performance improvements here. I will update this post with more benchmarks, but preliminary results look good.

Perf dashboard (screenshot from 2025-08-07):

ReLU works with both training and inference, but GELU only works in inference mode due to a fundamental limitation: GELU's derivative depends on the pre-activation input, while ReLU's does not. I don't think this is fixable with the current addmm_activation API.
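To make the asymmetry concrete, here is a minimal sketch (illustration only, not code from this PR) of why the backward pass can be recovered from the fused ReLU output but not from the fused GELU output:

import torch

x = torch.randn(8, requires_grad=True)
y = torch.relu(x)
# relu'(x) is 1 where x > 0 and 0 elsewhere, and (y > 0) == (x > 0),
# so the backward mask is recoverable from the fused output alone:
mask = y > 0
# gelu'(x) depends on x itself (through its tanh(...) terms), but
# _addmm_activation returns only activation(addmm(...)), so the
# pre-activation values needed for GELU's backward are not available.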

Graph module before and after this pass

ReLU (addmm)

Before:

graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%addmm,), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (relu, primals_2, le, permute_1)
After:

graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=2] = placeholder[target=primals_2]
    %primals_3 : [num_users=2] = placeholder[target=primals_3]
    %_addmm_activation_default : [num_users=2] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%primals_1, %primals_3, %primals_2), kwargs = {})
    %le : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%_addmm_activation_default, 0), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%primals_3, [1, 0]), kwargs = {})
    return (_addmm_activation_default, primals_2, le, permute_1)

GELU (addmm)

Before:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %addmm : [num_users=4] = call_function[target=torch.ops.aten.addmm.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %addmm), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul, %addmm), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_1, 0.044715), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%addmm, %mul_2), kwargs = {})
    %mul_3 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 0.7978845608028654), kwargs = {})
    %mul_4 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, 0.5), kwargs = {})
    %tanh : [num_users=1] = call_function[target=torch.ops.aten.tanh.default](args = (%mul_3,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%tanh, 1), kwargs = {})
    %mul_5 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_4, %add_1), kwargs = {})
    return (mul_5,)
After:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_addmm_activation_default : [num_users=1] = call_function[target=torch.ops.aten._addmm_activation.default](args = (%arg0_1, %arg2_1, %arg1_1), kwargs = {use_gelu: True})
    return (_addmm_activation_default,)

Benchmark setup:
NGC pytorch 25.06 container
cublas version: 12.9.1.4
torch.compile was run with dynamic=False and max_autotune

H100

Testing with M=1024, N=1024, K=1024, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 0.0107 ms
Average Time per Iteration (torch compile):	 0.0296 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 0.0262 ms
Average Time per Iteration (torch compile):	 0.0327 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 0.1763 ms
Average Time per Iteration (torch compile):	 0.2457 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=bfloat16
============================================================
Average Time per Iteration (cublas):	 1.5280 ms
Average Time per Iteration (torch compile):	 1.9437 ms

A100

############################################################
Testing with dtype: float16
############################################################

============================================================
Testing with M=1024, N=1024, K=1024, dtype=float16
============================================================
Average Time per Iteration (cublas):	 0.0313 ms
Average Time per Iteration (torch compile):	 0.0643 ms

============================================================
Testing with M=2048, N=2048, K=2048, dtype=float16
============================================================
Average Time per Iteration (cublas):	 0.1149 ms
Average Time per Iteration (torch compile):	 0.1255 ms

============================================================
Testing with M=4096, N=4096, K=4096, dtype=float16
============================================================
Average Time per Iteration (cublas):	 0.6297 ms
Average Time per Iteration (torch compile):	 0.7547 ms

============================================================
Testing with M=8192, N=8192, K=8192, dtype=float16
============================================================
Average Time per Iteration (cublas):	 4.3821 ms
Average Time per Iteration (torch compile):	 5.0740 ms

Script

import torch
torch.manual_seed(0)

warmup, numrun = 10, 100

sizes = [1024, 2048, 4096, 8192]
dtypes = [torch.float16, torch.bfloat16, torch.float32]

device = torch.device("cuda")

for dtype in dtypes:
    dtype_name = str(dtype).split('.')[-1] 
    print(f"\n{'#'*60}")
    print(f"Testing with dtype: {dtype_name}")
    print(f"{'#'*60}")
    
    for size in sizes:
        M, N, K = size, size, size
        print(f"\n{'='*60}")
        print(f"Testing with M={M}, N={N}, K={K}, dtype={dtype_name}")
        print(f"{'='*60}")
        
        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)
        C = torch.randn(M, device=device, dtype=dtype)

        def func1():
            return torch._addmm_activation(C, A, B, use_gelu=True)

        def func2():
            return torch.nn.functional.gelu(torch.add(C, torch.mm(A, B)), approximate="tanh")

        func2_compiled = torch.compile(
            func2,
            dynamic=False, 
            options={
                "force_disable_caches": True,
                "max_autotune": True,
                "max_autotune_gemm": True,
                "max_autotune_gemm_backends": "TRITON",
                "autotune_fallback_to_aten": False,
            }
        )

        for _ in range(warmup): func1()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun): func1()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun

        print(f"Average Time per Iteration (cublas):\t {avg_time_ms:.4f} ms")

        for _ in range(warmup): func2_compiled()
        torch.cuda.synchronize(device=device)

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        total_time_ms = 0.0
        start_event.record()
        for _ in range(numrun): func2_compiled()
        end_event.record()
        torch.cuda.synchronize(device=device)
        total_time_ms += start_event.elapsed_time(end_event)
        avg_time_ms = total_time_ms / numrun

        print(f"Average Time per Iteration (torch compile):\t {avg_time_ms:.4f} ms")

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

pytorch-bot commented Jul 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158137

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 4708fd0 with merge base 195b5c2:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@AaronWang04 (Author):

@pytorchbot label "topic: not user facing"

pytorch-bot added the "topic: not user facing" label Jul 11, 2025
@eqy requested a review from @eellison July 11, 2025
@eqy (Collaborator) commented Jul 11, 2025:

Needs some benchmarks comparing against existing Triton fusions.

@AaronWang04 marked this pull request as ready for review July 12, 2025
@mikaylagawarecki added the "triaged" label Jul 14, 2025
@eellison (Contributor) left a comment:

Thanks for the PR. A couple of comments.

Comment on lines 672 to 675
def addmm_gelu_pattern(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(input, output)
    return aten.gelu(output)
@eellison:

Is it also worth adding a pattern that targets addmm, instead of mm? I imagine most of the addmms will not be decomposed.

@AaronWang04 (Author):

The addmms that are followed by an activation will be decomposed, and those are exactly the ones we're interested in pattern-matching.
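For context, a minimal sketch of the replacement side of such a pattern (the function name is illustrative, not the PR's actual code):

def addmm_gelu_replacement(input, mat1, mat2):
    # emit the fused cuBLAS epilogue; use_gelu selects the tanh-approximate GELU
    return aten._addmm_activation(input, mat1, mat2, use_gelu=True)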

@eellison self-requested a review July 17, 2025
@eellison (Contributor):

@AaronWang04 re-request when ready.

@eellison removed their request for review July 22, 2025
@AaronWang04 (Author):

@pytorchbot rebase

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator):

Successfully rebased addmm_activation_fusion onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout addmm_activation_fusion && git pull --rebase)

@pytorchmergebot force-pushed the addmm_activation_fusion branch from bbb1083 to 0f9acc5 (July 22, 2025)
@AaronWang04 requested a review from @eellison July 23, 2025
@eellison (Contributor) left a comment:

Did you have a chance to run any dashboards? Just curious. Anyway, this looks good, but it would be nice if you could use the gen_register_replacement API to avoid compile-time overhead. It's more important for training patterns, but still nice either way.

args_bf16 = [torch.empty(shape, dtype=torch.bfloat16) for shape in shapes]

for pattern in [addmm_relu_pattern, addmm_relu_pattern_2]:
    register_replacement(
@eellison:

Now that we are parameterizing this across 4 total patterns, could I trouble you to pre-register the patterns? See gen_register_replacement:

## Precompiled Patterns

New patterns are added using register_replacement(). Patterns added in this way
can have a compile-time overhead because they need to be traced before
use. Patterns can be precompiled and added using gen_register_replacement()
instead. To do this you call gen_register_replacement() instead of
register_replacement(). The arguments are the same except for an additional
unique name which is used as a lookup key.

And https://github.com/pytorch/pytorch/blob/main/torchgen/fuse/gen_patterns.py.
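A rough sketch of what that pre-registration could look like for this pass (names such as addmm_relu_replacement and the exact argument order are assumptions, not the PR's code):

from torch._inductor.pattern_matcher import fwd_only, gen_register_replacement

for i, pattern in enumerate([addmm_relu_pattern, addmm_relu_pattern_2]):
    gen_register_replacement(
        f"addmm_relu_pattern_{i}",  # unique name used as the lookup key
        pattern,                    # search function to match
        addmm_relu_replacement,     # replacement emitting _addmm_activation
        args_bf16,                  # example inputs used for tracing
        fwd_only,                   # inference-only trace function
        patterns,                   # post_grad pass dict
    )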

def addmm_gelu_pattern_2(input, mat1, mat2):
    output = aten.mm(mat1, mat2)
    output = aten.add(input, output)
    return aten.gelu(output, approximate="tanh")
@eellison:

It's kind of unfortunate that cuBLAS only supports "tanh"... which we don't use by default in PyTorch.

@eqy (Collaborator) commented Jul 25, 2025:

@eellison For my edification, what's the dashboard, is it ciflow/inductor-perf-compare?


@AaronWang04 (Author):

@eellison the dashboard ran; results are under the branch "AaronWang04_addmmfusion_perftest".

Not sure where the slowdowns come from: some addmm_activation shapes may be slower than Triton, or the benchmarks may just have high variance.

@AaronWang04 (Author) commented Aug 6, 2025:

@eellison Ran the dashboard again with gen_register_replacement.

I pasted an image of the results in the top comment.

@AaronWang04 (Author):

@pytorchbot rebase

@pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator):

Successfully rebased addmm_activation_fusion onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout addmm_activation_fusion && git pull --rebase)

@pytorchmergebot force-pushed the addmm_activation_fusion branch from 7ec007a to 4708fd0 (August 8, 2025)
@eqy (Collaborator) commented Aug 8, 2025:

@pytorchmergebot label ciflow/trunk

pytorch-bot added the "ciflow/trunk" label Aug 8, 2025
@AaronWang04 (Author) commented Aug 8, 2025:

@eellison This PR is failing a newly added run-time estimation test from #159730. From our initial investigation, the flop-estimation functions don't work well with extern kernels.

For this test, self.node.get_origin_node() returns mm_default in the Triton-fusion case but wait_tensor in the extern-kernel (addmm_activation) case.

Is this an issue with the register_replacement machinery or something else?

@cache_on_self
def estimate_flops(self) -> int | None:
    if self.node is None:
        return None
    fx_node = self.node.get_origin_node()

cc @skarjala

@eellison (Contributor) commented Aug 8, 2025:

@AaronWang04 what do you mean by "wait_tensor in the case of an extern kernel (addmm_activation)" for this test?

What is erroring exactly? Does adding a formula for addmm_activation in torch/utils/flop_counter.py fix it?
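For reference, a hedged sketch of such a formula, mirroring the existing addmm entry in torch/utils/flop_counter.py (the argument convention is assumed from the addmm formula, not verified against this PR):

import torch
from torch.utils.flop_counter import register_flop_formula

aten = torch.ops.aten

@register_flop_formula(aten._addmm_activation)
def addmm_activation_flop(self_shape, a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    # same count as addmm: 2 * M * K * N; the pointwise epilogue is ignored
    m, k = a_shape
    k2, n = b_shape
    return m * n * 2 * k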

@skarjala (Contributor) commented Aug 8, 2025:

> (quoting @AaronWang04's comment above)

I'll fix the test

@AaronWang04 (Author) commented Aug 8, 2025:

@eellison no, adding a formula does not fix it, because execution never reaches it.

The flop-estimation function gets the "origin_node" to count flops. In the Triton case, get_origin_node returns an mm node, whereas with this PR it incorrectly returns a wait_tensor node.

So here it is trying to count flops from a wait_tensor node instead of mm, which returns 0 and fails the test.

@AaronWang04 (Author):

> (quoting the exchange above)

I think it's not really the test's problem; it's more an issue with the flop-counting infrastructure.

However, the easy way to patch this test for now is simply adding another pointwise op so the addmm_activation fusion does not trigger.

def forward(self, x):
    h = self.linear(x)
    h = torch.relu(h)
    h = torch.relu(h)

    h = torch.ops._c10d_functional.all_reduce.default(h, "sum", "0")
    h = torch.ops._c10d_functional.wait_tensor.default(h)

Maybe a good solution for now is to patch the current test with this ^ and add another test that xfails for the extern-fusion case, so this issue gets circled back to later?

@eellison (Contributor) commented Aug 8, 2025:

cc @exclamaforte

Labels: ciflow/trunk · matrix multiplication · module: inductor · open source · topic: not user facing · triaged

8 participants