[inductor] Improve GEMM logging to display batch size for batched operations #155544
Conversation
Changes:
- Extract batch size from input tensor shape in tuned_bmm()
- Include batch size in log messages: "batch=B, m=M, n=N, k=K"
- Update counter keys to include batch info for better tracking
- Apply same enhancement to tuned_baddbmm() for consistency
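As a rough illustration of the logging change this commit describes (the helper below and its names are made up for the example, not the actual `tuned_bmm` implementation), the batch size comes off the leading dimension of the first input:

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

def log_bmm_dims(mat1_size, mat2_size):
    # For torch.bmm, mat1 is (B, M, K) and mat2 is (B, K, N),
    # so the batch size is the leading dimension of either input.
    batch, m, k = mat1_size
    n = mat2_size[-1]
    log.info("Tuned aten.bmm: batch=%d, m=%d, n=%d, k=%d", batch, m, n, k)
    return batch, m, n, k

log_bmm_dims((10, 1024, 1024), (10, 1024, 1024))
```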
[inductor] Improve GEMM logging to display batch size for batched operations

The GEMM overview table in inductor logs was missing batch size information for batched matrix operations like torch.bmm and torch.baddbmm. This made it difficult to distinguish between different batched operations with the same M, N, K dimensions but different batch sizes.

Changes:
- Updated counter key format in kernel files to use prefixed values (e.g., "aten.bmm_b10_m1024_n1024_k1024" instead of "aten.bmm_10_1024_1024_1024")
- Enhanced parsing logic in compile_fx.py to handle both the new prefixed format and the legacy format for backward compatibility
- Added batch size display in the overview table for batched operations (e.g., "aten.bmm (B=10)" instead of just "aten.bmm")
- Increased table width to accommodate the batch size information

Before:
```
Name      | M    | N    | K    | Count
aten.bmm  | 1024 | 1024 | 1024 | 1
```

After:
```
Name             | M    | N    | K    | Count
aten.bmm (B=10)  | 1024 | 1024 | 1024 | 1
```

This provides clearer visibility into batched GEMM operations while maintaining backward compatibility with existing counter formats.

Fixes pytorch#155307
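The commit doesn't show the parsing itself, so here is a minimal sketch of the dual-format parsing it describes; `parse_gemm_counter_key` is a hypothetical helper for illustration, not the actual compile_fx.py code:

```python
import re

def parse_gemm_counter_key(key):
    """Return (op_name, batch, m, n, k); batch is None for non-batched ops.

    Accepts both the prefixed format this commit introduces
    ("aten.bmm_b10_m1024_n1024_k1024") and the legacy positional
    format ("aten.bmm_10_1024_1024_1024").
    """
    match = re.fullmatch(r"(.+?)(?:_b(\d+))?_m(\d+)_n(\d+)_k(\d+)", key)
    if match:
        name, b, m, n, k = match.groups()
        return name, (int(b) if b else None), int(m), int(n), int(k)
    # Legacy format: the trailing underscore-separated parts are all digits.
    parts = key.split("_")
    dims = [int(p) for p in parts if p.isdigit()]
    name = "_".join(parts[: len(parts) - len(dims)])
    if name.endswith(("bmm", "baddbmm")) and len(dims) == 4:
        b, m, n, k = dims
        return name, b, m, n, k
    m, n, k = dims[-3:]
    return name, None, m, n, k

assert parse_gemm_counter_key("aten.bmm_b10_m1024_n1024_k1024") == ("aten.bmm", 10, 1024, 1024, 1024)
assert parse_gemm_counter_key("aten.bmm_10_1024_1024_1024") == ("aten.bmm", 10, 1024, 1024, 1024)
assert parse_gemm_counter_key("aten.mm_1024_1024_1024") == ("aten.mm", None, 1024, 1024, 1024)
```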
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155544
Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 1c3bfb3 with merge base b0fbbef:

BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
@penknife6153 can you merge with
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[inductor] Improve GEMM logging to display batch size for batched operations (pytorch#155544)

Improves the GEMM overview logging in PyTorch Inductor to properly display batch size information for batched matrix operations like `torch.bmm` and `torch.baddbmm`.

**Fixes pytorch#155307**

## Problem

The current GEMM logging for `torch.bmm` shows:

```python
# Repro
import os
os.environ["TORCH_LOGS"] = "inductor"
import torch

M, N, K = 1024, 1024, 1024
dtype = torch.bfloat16
A = torch.randn(10, M, K, device="cuda", dtype=dtype)
B = torch.randn(10, K, N, device="cuda", dtype=dtype)

compiled_model = torch.compile(torch.bmm, fullgraph=True)
_ = compiled_model(A, B)
```

**Before:**
```
Name      | M    | N    | K    | Count
---------------------------------------
aten.bmm  | 1024 | 1024 | 1024 | 1
---------------------------------------
```

The batch size (10) is missing from the logs, making it unclear what the actual operation dimensions were.

## Solution

**After:**
```
Name      | B  | M    | N    | K    | Count
--------------------------------------------
aten.bmm  | 10 | 1024 | 1024 | 1024 | 1
aten.mm   | -  | 1024 | 1024 | 1024 | 2
--------------------------------------------
```

## Changes Made

### 1. Enhanced Parsing Logic in compile_fx.py
- Detects batched operations by checking whether the operation name ends with `'bmm'` or `'baddbmm'`
- For batched operations: takes the last 4 parts as `batch, m, n, k`
- For non-batched operations: takes the last 3 parts as `m, n, k`
- **Dedicated "B" column**: added a separate column for batch size instead of embedding it in the operation name
- Shows the batch size for batched operations and "-" for non-batched operations

### 2. Updated All MM Operations for Consistency
- **bmm.py**:
  - Extract batch size from `mat1.get_size()[0]` for both `tuned_bmm` and `tuned_baddbmm`
  - Use positional counter keys: `aten.bmm_{batch_size}_{m}_{n}_{k}`
  - Enhanced log messages to include batch size information
- **mm.py**: updated counter keys for consistency:
  - `aten.mm_{m}_{n}_{k}` (no batch dimension)
  - `aten.addmm_{m}_{n}_{k}` (no batch dimension)
  - `aten._int_mm_{m}_{n}_{k}` (no batch dimension)
  - `aten._scaled_mm.default_{m}_{n}_{k}` (no batch dimension)

Pull Request resolved: pytorch#155544
Approved by: https://github.com/jansel, https://github.com/BoyuanFeng
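To make the final key scheme concrete, here is a self-contained sketch of how the positional counter keys and the overview parsing described above fit together; `record_gemm`, `parse_key`, and the `counters` object are illustrative stand-ins, not the actual Inductor code:

```python
from collections import Counter
from typing import Optional, Tuple

# Stand-in for Inductor's GEMM counters; keys follow the positional format
# from this PR: aten.bmm_{batch}_{m}_{n}_{k} for batched ops, aten.mm_{m}_{n}_{k} otherwise.
counters: Counter = Counter()

def record_gemm(op: str, m: int, n: int, k: int, batch: Optional[int] = None) -> None:
    # Batched ops put the batch size first in the dimension list.
    dims = [batch, m, n, k] if batch is not None else [m, n, k]
    counters["_".join([op, *map(str, dims)])] += 1

def parse_key(key: str) -> Tuple[str, str, int, int, int]:
    """Return (name, batch-or-'-', m, n, k) for one counter key."""
    parts = key.split("_")
    dims = []
    while parts and parts[-1].isdigit():  # dimensions trail the op name
        dims.append(int(parts.pop()))
    dims.reverse()
    name = "_".join(parts)
    if name.endswith(("bmm", "baddbmm")):  # batched: last 4 parts are batch, m, n, k
        b, m, n, k = dims
        return name, str(b), m, n, k
    m, n, k = dims  # non-batched: last 3 parts are m, n, k
    return name, "-", m, n, k

record_gemm("aten.bmm", 1024, 1024, 1024, batch=10)
record_gemm("aten.mm", 1024, 1024, 1024)

print(f"{'Name':<12} | {'B':>3} | {'M':>5} | {'N':>5} | {'K':>5} | Count")
for key, count in counters.items():
    name, b, m, n, k = parse_key(key)
    print(f"{name:<12} | {b:>3} | {m:>5} | {n:>5} | {k:>5} | {count}")
```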
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @henrylhtsang