
[inductor] consolidate common GEMM triton param retrieval #159383


Closed
wants to merge 2 commits

Conversation

coconutruben
Contributor

@coconutruben coconutruben commented Jul 29, 2025

Stack from ghstack (oldest at bottom):

# Why

  • Make loop iteration simpler
  • Have a common spot where to make modifications that affect
    all the GEMM Triton templates, avoiding missed spots

# What

  • Pull out the common logic of taking the BaseConfig objects
    and turning them into kwargs to feed into maybe_append_choice
    for Triton GEMM templates (see the sketch below)
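
As a rough illustration of the intent, here is a minimal sketch (hypothetical names such as GemmConfig, get_template_kwargs, and add_triton_choices; not the actual implementation) of a single conversion point feeding maybe_append_choice:

```python
from dataclasses import dataclass
from typing import Any, Iterable


@dataclass(frozen=True)
class GemmConfig:
    """Stand-in for the BaseConfig objects produced by the GEMM config heuristics."""
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int


def get_template_kwargs(cfg: GemmConfig, **extra: Any) -> dict[str, Any]:
    """The single shared spot where a config becomes template kwargs."""
    return {
        "BLOCK_M": cfg.block_m,
        "BLOCK_N": cfg.block_n,
        "BLOCK_K": cfg.block_k,
        "num_stages": cfg.num_stages,
        "num_warps": cfg.num_warps,
        **extra,
    }


def add_triton_choices(template, choices, input_nodes, layout,
                       configs: Iterable[GemmConfig]) -> None:
    """One loop shared by all GEMM templates instead of a per-template copy."""
    for cfg in configs:
        template.maybe_append_choice(
            choices,
            input_nodes=input_nodes,
            layout=layout,
            **get_template_kwargs(cfg),
        )
```

With one conversion point, a change that should affect all the GEMM Triton templates only needs to be made once.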

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Differential Revision: D79186962


pytorch-bot bot commented Jul 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159383

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

⏳ 8 Pending, 3 Unrelated Failures

As of commit 676e185 with merge base 5d89634:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@coconutruben
Contributor Author

@coconutruben has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@coconutruben coconutruben added the topic: not user facing topic category label Jul 29, 2025
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 29, 2025
@coconutruben
Contributor Author

coconutruben commented Jul 29, 2025

@jansel for context, this is the same as #158015 - there was an issue last week; we reverted externally before it landed internally and fixed the oversights. To avoid any weird sync behavior now, the suggestion was to make this a fresh PR/diff pair and just take it from here.

@coconutruben
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
Differential Revision: [D79186962](https://our.internmc.facebook.com/intern/diff/D79186962)
Pull Request resolved: #159383
Approved by: https://github.com/jansel
@jataylo
Collaborator

jataylo commented Aug 1, 2025

https://hud.pytorch.org/pytorch/pytorch/commit/e7cc42df58a86bee05944f6e80c535aa1d099443

This seems to have caused breakages on ROCm:

2025-07-31T14:12:45.0144510Z AUTOTUNE mm_plus_mm(10x40, 40x30, 10x40, 40x30)
2025-07-31T14:12:45.0144635Z strides: [40, 1], [30, 1], [40, 1], [30, 1]
2025-07-31T14:12:45.0144786Z dtypes: torch.float32, torch.float32, torch.float32, torch.float32
2025-07-31T14:12:45.0145143Z   triton_mm_plus_mm_8 0.0063 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=2
2025-07-31T14:12:45.0145703Z   triton_mm_plus_mm_9 0.0067 ms 94.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=32, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=False, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=2
2025-07-31T14:12:45.0146015Z   _mm_plus_mm 0.0296 ms 21.4% 
2025-07-31T14:12:45.0146205Z SingleProcess AUTOTUNE benchmarking takes 0.2306 seconds and 0.2253 seconds precompiling for 3 choices
2025-07-31T14:12:45.0146581Z - generated xml file: /var/lib/jenkins/pytorch/test/test-reports/python-pytest/inductor.test_max_autotune/inductor.test_max_autotune-3db95b7588c8a2e4.xml -
2025-07-31T14:12:45.0147920Z =========================== short test summary info ============================
2025-07-31T14:12:45.0148273Z FAILED [0.6824s] inductor/test_max_autotune.py::TestMaxAutotune::test_triton_template_generated_code_caching_mm_plus_mm - AssertionError: Scalars are not equal!
2025-07-31T14:12:45.0148505Z 
2025-07-31T14:12:45.0148549Z Expected 4 but got 2.
2025-07-31T14:12:45.0148646Z Absolute difference: 2
2025-07-31T14:12:45.0148742Z Relative difference: 0.5

Presumably this change is now making ROCm check only 2 configs instead of 4 for mm_plus_mm autotune. @huydhn @amdfaa we may need to issue a revert here to recover the ROCm CI.

Testing a revert here: #159642

jataylo added a commit to jataylo/pytorch that referenced this pull request Aug 1, 2025
@jataylo
Collaborator

jataylo commented Aug 1, 2025

https://hud.pytorch.org/pr/159642
My manual revert of this PR gets the ROCm CI back to green, so sorry, we will have to revert this. @coconutruben it looks like the autotune selection may have changed on ROCm after your PR; the way I interpret the error above is that it was expecting 4 choices but got 2, though I'm not 100% sure. I'll add ROCm labels to this PR in case you have an idea how it can be fixed.

@jataylo
Collaborator

jataylo commented Aug 1, 2025

@pytorchbot revert -c nosignal -m "sorry but rocm CI is broken due to this PR"

@jataylo jataylo added ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Aug 1, 2025
@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Aug 1, 2025
…59383)"

This reverts commit e7cc42d.

Reverted #159383 on behalf of https://github.com/jataylo due to sorry but rocm CI is broken due to this PR ([comment](#159383 (comment)))
@pytorchmergebot
Collaborator

@coconutruben your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Aug 1, 2025
@coconutruben
Contributor Author

Thanks for reverting, and sorry for breaking the ROCm CI. I'll take a look.

@jataylo
Collaborator

jataylo commented Aug 1, 2025

Thanks @coconutruben, let us know if any help is needed at all from our side to take a look.

@coconutruben
Contributor Author

coconutruben commented Aug 2, 2025

> Thanks @coconutruben, let us know if any help is needed at all from our side to take a look.

Thank you @jataylo! I think I overlooked two things in the refactor; they should be fixed now.

  1. mm_plus_mm configs did not previously go through preprocessing, including scaling. I was aware of this, but didn't see issues on H100 and figured it was fine and working as intended.
  2. filter_configs for mm_plus_mm on ROCm used to run with num_stages = 1 rather than default_num_stages; that's rectified now (see the sketch below).
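
A minimal sketch of point 2, using hypothetical names (MMConfig, preprocess_mm_plus_mm_configs) rather than the real inductor helpers:

```python
from collections import namedtuple

# Hypothetical stand-in for the config objects the heuristics produce.
MMConfig = namedtuple("MMConfig", "block_m block_n block_k num_stages num_warps")


def preprocess_mm_plus_mm_configs(configs, is_rocm):
    """Restore the pre-refactor behavior: mm_plus_mm on ROCm keeps
    num_stages = 1 instead of being bumped to default_num_stages."""
    if not is_rocm:
        return list(configs)
    return [cfg._replace(num_stages=1) for cfg in configs]
```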

What would be helpful

  1. Can you confirm on your end whether this fixes the offending test? I wasn't able to build ROCm torch on my AMD GPU, so I validated the fix with some local hacking around.
  2. As a follow-up, I think the test is failing because after scaling + filtering there are only 2 mm_plus_mm configs left for ROCm. Maybe that's fine? Can you help us identify whether we should just change the test and bring scaling back for ROCm, or whether the lack of configs is an actual issue?

I'll leave the logic as-is in the PR once the tests are green, since this should be a straight refactor, and we can follow up on (2) later.

@jataylo
Collaborator

jataylo commented Aug 2, 2025

@coconutruben I can take a look at this on Monday. The theory makes sense to me: if we filter out configs as duplicates because we now use a static num_stages, it would make sense that the config selection shrank (see the sketch below).
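
A small, self-contained illustration of that theory (plain tuples, not the real filter_configs): forcing a static num_stages collapses configs that differed only in num_stages into duplicates, which matches the "Expected 4 but got 2" assertion above.

```python
from collections import namedtuple

Cfg = namedtuple("Cfg", "block_m block_n block_k num_stages num_warps")

configs = [
    Cfg(16, 32, 16, 2, 2),
    Cfg(16, 32, 16, 4, 2),  # differs from the first only by num_stages
    Cfg(16, 32, 32, 2, 2),
    Cfg(16, 32, 32, 4, 2),  # differs from the third only by num_stages
]

# Force a static num_stages, then drop duplicates (order-preserving).
static = [c._replace(num_stages=1) for c in configs]
unique = list(dict.fromkeys(static))

print(len(configs), "->", len(unique))  # 4 -> 2
```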

@coconutruben
Contributor Author

@jataylo, do you think the results here are fine to just merge in now? I saw all the ROCm tests timed out; I assume there aren't enough runners available?

@jataylo jataylo added the keep-going Don't stop on first failure, keep running tests until the end label Aug 5, 2025
@jataylo
Collaborator

jataylo commented Aug 5, 2025

Hi @coconutruben, it looks like the MI300 runners are facing some issues at the moment, but the regular ROCm CI still ran:
https://github.com/pytorch/pytorch/actions/runs/16693264444

I've just added the keep-going label and am rerunning the failing shard; it failed a Helion test and exited, so I'm trying with keep-going to make sure we don't run into any other failures, and then it should be okay to merge. To get the MI300 CI running you may need to rebase the PR, but the failure happened on both MI300 and MI200, so we should be okay to go forward if the above run shows no suspicious failures.

@coconutruben
Contributor Author

Thanks @jataylo, I'm seeing the same Helion failure again, though it seems to be classified as a trunk failure. I'll merge with -f now to bypass the MI300 CI issues, and let's keep an eye out for anything else suspicious.

@coconutruben
Contributor Author

@pytorchbot merge -f "ignoring mi300 pending as per comment above"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D79186962

coconutruben added a commit that referenced this pull request Aug 6, 2025
Summary:

This reverts the scaled_mm part of #159383 so that, as before, we pass
through the normal input_nodes (not the triton_input_nodes)
to select_algorithm.

- #159383 refactored how kwargs are retrieved
- it introduced this notion of KernelInputs that wrap input_nodes
- scaled_mm uses unsqueezed input nodes for triton to retrieve params
- the issue: select_algorithm should receive the squeezed (regular) bias
  instead, but after the refactor it got the unsqueezed one

This fixes that by passing the original input nodes rather
than the triton input nodes.
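
A minimal, runnable sketch of that distinction, with hypothetical helper names and plain tensors standing in for inductor IR nodes:

```python
import torch


def unsqueeze_bias(bias):
    # What the Triton param retrieval works with internally (2D bias).
    return bias.unsqueeze(0)


def select_algorithm(choices, input_nodes):
    # Stand-in for the real select_algorithm: just report what it was given.
    bias = input_nodes[-1]
    print(f"select_algorithm sees bias of shape {tuple(bias.shape)} "
          f"with {len(choices)} choices")


mat_a, mat_b = torch.randn(16, 32), torch.randn(32, 64)
scale_a, scale_b = torch.ones(16, 1), torch.ones(1, 64)
bias = torch.randn(64)

input_nodes = [mat_a, mat_b, scale_a, scale_b, bias]             # original (squeezed bias)
triton_input_nodes = input_nodes[:-1] + [unsqueeze_bias(bias)]   # only for param retrieval

choices = ["aten", "triton_template_0"]  # placeholder choices

# Before the fix, triton_input_nodes was passed here; this change restores the
# pre-refactor behavior of passing the original nodes.
select_algorithm(choices, input_nodes)
```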

Test Plan:

```
buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_rowwise_scaling_shape_1024,1024,512_has_bias_True_use_fast_accum_True_persistent_matmul_False (caffe2.test.inductor.test_fp8.TestFP8Lowering)'
buck test '@fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_rowwise_scaling_shape_1024,1024,512_has_bias_True_use_fast_accum_True_persistent_matmul_True (caffe2.test.inductor.test_fp8.TestFP8Lowering)'
```

This set of tests was failing, and is passing now

Side note: I believe these tests were failing because the unsqueezed
bias made the ATEN choice no longer eligible, and there is some minor
numerical discrepancy between ATEN and Triton here. I'm not sure
the test should be written like that, as we're implicitly relying on
ATEN being the choice here.

[ghstack-poisoned]
coconutruben added a commit that referenced this pull request Aug 6, 2025
pytorchmergebot pushed a commit that referenced this pull request Aug 6, 2025
Differential Revision: [D79717654](https://our.internmc.facebook.com/intern/diff/D79717654)
Pull Request resolved: #159948
Approved by: https://github.com/izaitsevfb, https://github.com/eellison