[inductor] Improve GEMM logging to display batch size for batched operations #155544


Closed · wants to merge 4 commits

Conversation

@penknife6153 (Contributor) commented Jun 10, 2025

Improves the GEMM overview logging in PyTorch Inductor to properly display batch size information for batched matrix operations like torch.bmm and torch.baddbmm.

Fixes #155307

```
# Repro
import os
os.environ["TORCH_LOGS"] = "inductor"
import torch

M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)

compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
```

Before:

```
Name                 | M                    | N                    | K                    | Count
----------------------------------------------------------------------------------------------------
aten.bmm             | 1024                 | 1024                 | 1024                 | 1
----------------------------------------------------------------------------------------------------
```

The batch size (10) is missing from the logs, making it unclear what the actual operation dimensions were.

After:

```
Name                 | B                    | M                    | N                    | K                    | Count
---------------------------------------------------------------------------------------------------------------------------
aten.bmm             | 10                   | 1024                 | 1024                 | 1024                 | 1
---------------------------------------------------------------------------------------------------------------------------
```

Changes Made

  • compile_fx.py:

    • Detects batched operations by checking whether the operation name ends with 'bmm' or 'baddbmm'
    • For batched operations: parses the last 4 underscore-separated parts of the counter key as batch, m, n, k
    • For non-batched operations: parses the last 3 parts as m, n, k
    • Added a separate column for the batch size
    • Shows the batch size for batched operations and "-" for non-batched operations
  • bmm.py:

    • Extract batch size from mat1.get_size()[0] for both tuned_bmm and tuned_baddbmm
    • Use positional counter keys: aten.bmm_{batch_size}_{m}_{n}_{k}
    • Enhanced log messages to include batch size information
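The parsing rules above can be sketched roughly as follows (the helper name and details here are illustrative, not the exact compile_fx.py code):

```python
# Hypothetical sketch of the counter-key parsing described above; the
# actual compile_fx.py implementation may differ in names and details.
def parse_gemm_counter_key(key: str):
    """Split a counter key such as 'aten.bmm_10_1024_1024_1024' into
    (name, batch, m, n, k); batch is '-' for non-batched ops."""
    parts = key.split("_")
    name = parts[0]
    if name.endswith("bmm") or name.endswith("baddbmm"):
        # Batched op: the last 4 parts are batch, m, n, k.
        batch, m, n, k = parts[-4:]
    else:
        # Non-batched op: the last 3 parts are m, n, k.
        batch = "-"
        m, n, k = parts[-3:]
    return (name, batch, m, n, k)
```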

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @henrylhtsang

penknife6153 and others added 2 commits June 6, 2025 10:15
Changes:
- Extract batch size from input tensor shape in tuned_bmm()
- Include batch size in log messages: "batch=B, m=M, n=N, k=K"  
- Update counter keys to include batch info for better tracking
- Apply same enhancement to tuned_baddbmm() for consistency
…rations

The GEMM overview table in inductor logs was missing batch size information
for batched matrix operations like torch.bmm and torch.baddbmm. This made it
difficult to distinguish between different batched operations with the same
M, N, K dimensions but different batch sizes.

Changes:
- Updated counter key format in kernel files to use prefixed values
  (e.g., "aten.bmm_b10_m1024_n1024_k1024" instead of "aten.bmm_10_1024_1024_1024")
- Enhanced parsing logic in compile_fx.py to handle both new prefixed format
  and legacy format for backward compatibility
- Added batch size display in overview table for batched operations
  (e.g., "aten.bmm (B=10)" instead of just "aten.bmm")
- Increased table width to accommodate batch size information

Before:
```
Name                 | M    | N    | K    | Count
aten.bmm             | 1024 | 1024 | 1024 | 1
```

After:
```
Name                 | M    | N    | K    | Count
aten.bmm (B=10)      | 1024 | 1024 | 1024 | 1
```

This provides clearer visibility into batched GEMM operations while maintaining
backward compatibility with existing counter formats.
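The backward-compatible parsing described in this commit message could look roughly like the following sketch (function and pattern names are hypothetical; the real compile_fx.py code may differ):

```python
import re

# Hypothetical sketch of backward-compatible counter-key parsing; the
# real compile_fx.py implementation may differ in names and details.
_PREFIXED = re.compile(
    r"(?P<name>.+?)_b(?P<b>\d+)_m(?P<m>\d+)_n(?P<n>\d+)_k(?P<k>\d+)$"
)

def parse_counter_key(key: str):
    """Return (name, batch, m, n, k) for a GEMM counter key.

    Handles both the new prefixed format ('aten.bmm_b10_m1024_n1024_k1024')
    and the legacy positional format ('aten.bmm_10_1024_1024_1024').
    """
    match = _PREFIXED.match(key)
    if match:
        # New prefixed format: fields are labeled explicitly.
        return tuple(match.group("name", "b", "m", "n", "k"))
    # Legacy fallback: positional underscore-separated fields.
    parts = key.split("_")
    name = parts[0]
    if name.endswith("bmm") or name.endswith("baddbmm"):
        b, m, n, k = parts[-4:]
    else:
        b = None
        m, n, k = parts[-3:]
    return (name, b, m, n, k)
```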

Fixes pytorch#155307

pytorch-bot bot commented Jun 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155544

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 1c3bfb3 with merge base b0fbbef:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Jun 10, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@penknife6153 (Contributor, Author):

@pytorchbot label "topic: not user facing"

@zou3519 zou3519 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Jun 10, 2025
@zou3519 zou3519 requested review from jansel and BoyuanFeng June 10, 2025 13:33
@jansel jansel added ciflow/trunk Trigger trunk jobs on your pull request ciflow/inductor labels Jun 10, 2025
@henrylhtsang henrylhtsang self-requested a review June 10, 2025 16:36
@pytorch-bot pytorch-bot bot removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/inductor labels Jun 10, 2025
@BoyuanFeng BoyuanFeng added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 10, 2025
@henrylhtsang (Contributor):

@penknife6153 can you merge with @ pytorchbot merge

@penknife6153 (Contributor, Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

Labels: ciflow/inductor, ciflow/trunk, Merged, module: inductor, open source, topic: not user facing, triaged
Successfully merging this pull request may close these issues.

[inductor] Improve GEMM loggings for torch.bmm
7 participants