
Update the heuristic for AArch64 bmm/baddbmm #149122


Open · wants to merge 2 commits into main
Conversation

@michalowski-arm (Contributor) commented Mar 13, 2025

Updates the heuristic for bmm/baddbmm and consolidates all heuristic logic in a single location.

  • The goal of the consolidation is to improve the maintainability and readability of the heuristic logic. Instead of having different parts scattered across two files, this patch centralizes everything inside Matmul.cpp, where heuristic-based selection for mkldnn already exists.
  • The logic of the check itself doesn't change (existing code is reused where possible), but a separate heuristic threshold for bmm/baddbmm is introduced based on newer benchmarking data. Use the script below to see the performance improvement for bmm from the new heuristic:
import torch
import time

# Set to True to run only the cases where the two heuristics disagree.
USE_ONLY_DIVERGENT_TEST_CASES = True
BATCH_SIZES = [1, 8, 32, 64, 128, 256]
M_DIMS = [4, 8, 16, 32, 64, 256, 512]
N_DIMS = [4, 8, 16, 32, 64, 256, 512]
K_DIMS = [4, 8, 16, 32, 64, 256, 512]
ITERS = 50

def old_heuristic(m, n, k):
    is_above_min_dims = m > 8 and n > 8 and k > 8
    is_above_min_size = m * n * k > 8_192
    return is_above_min_dims and is_above_min_size

def new_heuristic(b, m, n, k):
    return b * b * m * n * k >= 4_194_304

def generate_test_cases():
    test_cases = []
    for b in BATCH_SIZES:
        for m in M_DIMS:
            for n in N_DIMS:
                for k in K_DIMS:
                    if USE_ONLY_DIVERGENT_TEST_CASES:
                        if old_heuristic(m, n, k) != new_heuristic(b, m, n, k):
                            test_cases.append([b, m, n, k])
                    else:
                        test_cases.append([b, m, n, k])
    return test_cases

def test(x, y):
    # Warm up, then accumulate the mean time over ITERS runs.
    for _ in range(5):
        torch.bmm(x, y)
    mean_time = 0.0
    for _ in range(ITERS):
        start = time.time()
        torch.bmm(x, y)
        end = time.time()
        mean_time += (end - start) / ITERS
    return mean_time

def main():
    print(f"{'b':<10}{'m':<10}{'n':<10}{'k':<10}{'time (s)':<10}")
    cumulative_mean_time = 0.0
    for b, m, n, k in generate_test_cases():
        mean_time = test(torch.rand(b, m, n), torch.rand(b, n, k))
        cumulative_mean_time += mean_time
        print(f"{b:<10}{m:<10}{n:<10}{k:<10}{mean_time:10.3e}")
    print(f"Cumulative mean time = {cumulative_mean_time:.4f} s")

if __name__ == "__main__":
    main()

From the script we see that the cumulative mean time across all test cases (at 16 threads) is:

  • 1.6195 s for the old heuristic
  • 0.7012 s for the new heuristic
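
Both figures assume the benchmark is pinned to 16 threads. A minimal way to do that from the script itself is sketched below (an assumption about how the runs were configured; the original numbers may instead have been gathered with OMP_NUM_THREADS set):

import torch

# Pin intra-op parallelism to 16 threads before any torch.bmm calls,
# so timings are comparable to the 16-thread figures quoted above.
torch.set_num_threads(16)
print(torch.get_num_threads())  # -> 16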

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @malfet @snadampal @milpuz01 @aditew01 @nikhil-arm @fadara01

pytorch-bot (bot) commented Mar 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149122

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 41 Pending, 1 Unrelated Failure

As of commit 53a6da1 with merge base a53d14d:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the module: cpu and release notes: linalg_frontend labels Mar 13, 2025
@malfet added the triaged and module: arm labels Mar 13, 2025
@malfet (Contributor) left a comment


Can you provide a bit more description of what this consolidation is trying to do?
Or is it just a better-engineering change?
If it changes the semantics of the check, it would be nice to have some sort of unit test added; if it's a performance-only change, it would be nice to have a script in the PR description that one can run to observe those perf changes.

@michalowski-arm (Contributor, Author)

Can you provide a bit more description of what this consolidation is trying to do? Or is it just a better-engineering change? If it changes the semantics of the check, it would be nice to have some sort of unit test added; if it's a performance-only change, it would be nice to have a script in the PR description that one can run to observe those perf changes.

Added some extra information in the PR description; let me know if anything more is needed.

@aditew01 requested a review from malfet on March 19, 2025 09:31
#if defined(__aarch64__)
const int64_t mkldnn_acl_bmm_baddbmm_threshold = get_mkldnn_acl_bmm_baddbmm_threshold();
// BATCH_SIZE^2 * M * N * K >= THRESHOLD
return mat1.size(0) * mat1.size(0) * mat1.size(1) * mat1.size(2) * mat2.size(2) >= mkldnn_acl_bmm_baddbmm_threshold;
Collaborator

For Neoverse cores, mkldnn_acl_bmm_baddbmm_threshold == 1L<<22
Does this imply we don't want to dispatch to MKLDNN for aten::baddbmm or aten::bmm cases?

The script attached only tests for float32 cases. Did we test the performance for float16 / bfloat16?

@fadara01 (Collaborator) commented Apr 7, 2025

For Neoverse cores, mkldnn_acl_bmm_baddbmm_threshold == 1L<<22
Does this imply we don't want to dispatch to MKLDNN for aten::baddbmm or aten::bmm cases?

1L<<22 is 4,194,304, which is not as large as it seems.
For K = N = 1024 (which is not large by today's standards), we'll go to oneDNN/ACL with M = 4 and batch size = 1.
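
To make that concrete, a quick plain-Python check of the arithmetic (mirroring the b*b*m*n*k comparison in the diff above):

# Worked example for the shape mentioned above: batch = 1, M = 4, N = K = 1024.
# The new check dispatches to oneDNN/ACL when b*b*m*n*k >= 1 << 22.
threshold = 1 << 22                    # 4,194,304
b, m, n, k = 1, 4, 1024, 1024
print(b * b * m * n * k)               # 4194304
print(b * b * m * n * k >= threshold)  # True -> dispatch to oneDNN/ACL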

@@ -322,6 +322,42 @@ void mkldnn_matmul(

}

#if defined(__aarch64__)
Contributor

This should be guarded behind USE_MKLDNN_ACL or something like that

Contributor (Author)

Done

@fadara01 (Collaborator) commented Apr 7, 2025

@pytorchbot label "ciflow/linux-aarch64"

pytorch-bot added the ciflow/linux-aarch64 label Apr 7, 2025
@michalowski-arm (Contributor, Author)

Some more performance data

  1. Running the script from the description with dtype=torch.bfloat16 gives an 11.1x speedup (no change for torch.float16); a sketch of the dtype change is shown after the tables below.
  2. Running the t5-small, t5-base and t5-large models with dtype=torch.float32 at 16 threads:
=== Performance Comparison: Old vs New Heuristic (t5-small) ===
Batch Size | Model Speedup (%) | BMM Speedup (%) |   Old BMM CPU % |   New BMM CPU %
-------------------------------------------------------------------------------------
         1 |              0.70 |            0.00 |            3.21 |            3.23
         2 |             -1.47 |            1.05 |            5.03 |            4.91
         4 |              0.92 |            5.97 |            7.84 |            7.44
         8 |              4.28 |            9.75 |           11.63 |           10.97
        16 |             36.45 |           85.63 |           10.92 |            2.47
        32 |             48.33 |           90.33 |           14.85 |            2.78
        64 |             58.10 |           94.50 |           19.29 |            2.53
       128 |             69.77 |           96.93 |           21.09 |            2.14
       256 |             70.06 |           97.67 |           22.05 |            1.71
-------------------------------------------------------------------------------------
=== Performance Comparison: Old vs New Heuristic (t5-base) ===
Batch Size | Model Speedup (%) | BMM Speedup (%) |   Old BMM CPU % |   New BMM CPU %
-------------------------------------------------------------------------------------
         1 |             -3.39 |           -9.53 |            4.38 |            4.64
         2 |             -3.03 |           -5.84 |            6.00 |            6.17
         4 |             -0.41 |            3.05 |            9.41 |            9.09
         8 |             41.05 |           85.02 |           12.97 |            3.29
        16 |             44.94 |           88.77 |           13.24 |            2.70
        32 |             53.44 |           93.33 |           17.73 |            2.54
        64 |             65.31 |           96.36 |           19.65 |            2.06
       128 |             71.74 |           97.70 |           20.75 |            1.69
       256 |             71.65 |           98.22 |           21.30 |            1.34
-------------------------------------------------------------------------------------
=== Performance Comparison: Old vs New Heuristic (t5-large) ===
Batch Size | Model Speedup (%) | BMM Speedup (%) |   Old BMM CPU % |   New BMM CPU %
-------------------------------------------------------------------------------------
         1 |              0.51 |           -3.64 |            5.20 |            5.42
         2 |              2.66 |            4.22 |            6.74 |            6.63
         4 |              5.15 |            8.84 |            9.82 |            9.43
         8 |             44.21 |           85.09 |           13.58 |            3.63
        16 |             47.11 |           90.62 |           15.40 |            2.73
        32 |             57.76 |           94.74 |           17.58 |            2.19
        64 |             66.74 |           96.99 |           19.28 |            1.74
       128 |             70.75 |           97.94 |           20.13 |            1.42
       256 |             68.83 |           98.20 |           20.61 |            1.19
-------------------------------------------------------------------------------------
  3. Running the t5-small and t5-base models with ONEDNN_DEFAULT_FPMATH_MODE=BF16 at 16 threads:
=== Performance Comparison: Old vs New Heuristic (t5-small) ===
Batch Size | Model Speedup (%) | BMM Speedup (%) |   Old BMM CPU % |   New BMM CPU %
-------------------------------------------------------------------------------------
         1 |             -2.34 |           -3.67 |            3.69 |            3.74
         2 |             -1.30 |           -4.07 |            5.45 |            5.60
         4 |              0.82 |            9.57 |            9.01 |            8.22
         8 |              9.69 |           24.32 |           13.43 |           11.25
        16 |             56.74 |           91.30 |           17.83 |            3.58
        32 |             70.97 |           94.43 |           21.48 |            4.12
        64 |             80.46 |           96.38 |           23.81 |            4.41
       128 |             87.78 |           97.57 |           25.62 |            5.10
-------------------------------------------------------------------------------------
=== Performance Comparison: Old vs New Heuristic (t5-base) ===
Batch Size | Model Speedup (%) | BMM Speedup (%) |   Old BMM CPU % |   New BMM CPU %
-------------------------------------------------------------------------------------
         1 |              4.55 |           -2.32 |            3.82 |            4.10
         2 |              3.20 |            7.15 |            5.64 |            5.41
         4 |              3.72 |            7.05 |            9.20 |            8.88
         8 |             47.54 |           86.87 |           14.34 |            3.59
        16 |             63.61 |           93.27 |           18.58 |            3.44
        32 |             75.78 |           95.62 |           21.81 |            3.95
        64 |             84.57 |           97.02 |           24.09 |            4.66
       128 |             89.44 |           97.69 |           25.36 |            5.54
-------------------------------------------------------------------------------------
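
For item 1 above, a minimal sketch of the dtype change (an assumed adaptation of the script in the description, not the exact harness used for the numbers; the helper name time_bmm is illustrative and only the input dtype differs from the float32 runs):

import time
import torch

def time_bmm(b, m, n, k, dtype=torch.bfloat16, iters=50):
    # Same warm-up plus mean-time loop as the script in the description,
    # but with an explicit dtype so float32 and bfloat16 runs can be compared.
    x = torch.rand(b, m, n, dtype=dtype)
    y = torch.rand(b, n, k, dtype=dtype)
    for _ in range(5):
        torch.bmm(x, y)
    mean_time = 0.0
    for _ in range(iters):
        start = time.time()
        torch.bmm(x, y)
        mean_time += (time.time() - start) / iters
    return mean_time

print(time_bmm(64, 256, 256, 256))

For item 3, ONEDNN_DEFAULT_FPMATH_MODE=BF16 is a oneDNN environment variable, presumably exported before launching the benchmark, which keeps the tensors in float32 while allowing oneDNN to use bfloat16 arithmetic internally.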

pytorch-bot removed the ciflow/linux-aarch64 label Apr 8, 2025
@aditew01 added the ciflow/linux-aarch64 label Apr 10, 2025
pytorch-bot removed the ciflow/linux-aarch64 label Apr 22, 2025
@fadara01 (Collaborator)

@pytorchbot label "ciflow/linux-aarch64"

pytorch-bot added the ciflow/linux-aarch64 label Apr 29, 2025
@fadara01 (Collaborator)

@malfet @aditew01 can you guys give this another look?

fadara01 previously approved these changes Apr 29, 2025
aditew01 previously approved these changes May 6, 2025
@aditew01 (Collaborator) left a comment

LGTM!

pytorch-bot removed the ciflow/linux-aarch64 label May 7, 2025
malfet previously approved these changes May 13, 2025
@malfet (Contributor) left a comment

LGTM

@malfet added the ciflow/trunk label May 13, 2025
@malfet (Contributor) commented May 13, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@malfet (Contributor) commented May 23, 2025

FYI, this change is being reverted right now, as it is causing a 50% perf regression...

@jeanschmidt (Contributor)

@pytorchbot revert -m "breaking internal models, @malfet may you help merge this?" -c ghfirst

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request May 23, 2025
This reverts commit d759a51.

Reverted #149122 on behalf of https://github.com/jeanschmidt due to breaking internal models, @malfet may you help merge this? ([comment](#149122 (comment)))
@pytorchmergebot (Collaborator)

@michalowski-arm your PR has been successfully reverted.

pytorchmergebot added the Reverted and ci-no-td labels May 23, 2025
pytorch-bot dismissed stale reviews from fadara01, aditew01, and malfet May 23, 2025 14:55

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@aditew01 (Collaborator)

@malfet @jeanschmidt where can I look for the model break / regression suite, so we can fix the heuristics?

@michalowski-arm (Contributor, Author)

@malfet @jeanschmidt just pinging to ask again where we can see the perf results. In my tests there was a general speedup with this heuristic for all but the smallest shapes, so to fix it I'd need to see the shapes being run.

@michalowski-arm (Contributor, Author)

@malfet @jeanschmidt In this CI run we don't see any regressions: TorchInductor Performance DashBoard. This was for bf16 only; I don't think we can run f32. Are there other tests we are supposed to be running? Also, presumably the regression you saw was on AArch64, or did this change somehow impact x86 performance?

@nikhil-arm (Collaborator)

Waiting for Meta team to share details on internal regressions

@malfet (Contributor) commented Aug 6, 2025

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/149122/head returned non-zero exit code 1

Rebasing (1/2)
Auto-merging aten/src/ATen/native/LinearAlgebra.cpp
CONFLICT (content): Merge conflict in aten/src/ATen/native/LinearAlgebra.cpp
Auto-merging aten/src/ATen/native/mkldnn/Matmul.cpp
CONFLICT (content): Merge conflict in aten/src/ATen/native/mkldnn/Matmul.cpp
error: could not apply b382d29a47a... Update the heuristic for AArch64 bmm/baddbmm
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply b382d29a47a... # Update the heuristic for AArch64 bmm/baddbmm

Raised by https://github.com/pytorch/pytorch/actions/runs/16790309166

Replaced the defined(__aarch64__) guard with AT_MKLDNN_ACL_ENABLED() in the
parts of Matmul.cpp relevant to the bmm heuristic
pytorch-bot removed the ciflow/trunk label Aug 8, 2025
Labels
ci-no-td (Do not run TD on this PR) · Merged · module: arm (Related to ARM architecture builds of PyTorch; includes Apple M1) · module: cpu (CPU-specific problem, e.g. perf, algorithm) · open source · release notes: linalg_frontend (release notes category) · Reverted · triaged (looked at by a team member and prioritized into an appropriate module)