
foreach_map_fn: add UX fallback for torch.mm to avoid Dynamo graph breaks #159757


Open · wants to merge 2 commits into main from foreach-map-matmul-fallback

Conversation


@Quantum-Kayak Quantum-Kayak commented Aug 4, 2025

This PR adds special handling to foreach_map_fn to support torch.mm (matrix multiplication) with list arguments. When torch.mm is passed as the op, foreach_map_fn now unpacks the two argument lists and applies torch.mm pairwise across them.

This prevents graph breaks in Dynamo when torch.mm is used in a foreach_map context.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela
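
For illustration, a minimal sketch of the semantics this enables (shapes and list length are arbitrary assumptions; foreach_map is an internal/prototype API, so this only mirrors in eager terms the computation the fallback performs):

import torch

# Two lists of conformable matrices; the fallback applies torch.mm pairwise.
a_list = [torch.randn(4, 8) for _ in range(3)]
b_list = [torch.randn(8, 2) for _ in range(3)]

# Eager-mode equivalent of what foreach_map_fn now returns for torch.mm:
result = [torch.mm(a, b) for a, b in zip(a_list, b_list)]
assert all(r.shape == (4, 2) for r in result)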


pytorch-bot bot commented Aug 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159757

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d1f887c with merge base fb887c3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Aug 4, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@Quantum-Kayak (Author)

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing label Aug 4, 2025
@Quantum-Kayak Quantum-Kayak reopened this Aug 4, 2025
@janeyx99 janeyx99 requested a review from mlazos August 4, 2025 23:01
@janeyx99 janeyx99 added the triaged label Aug 4, 2025
@Quantum-Kayak (Author)

Hey, just confirming this PR is ready for review. It's fully updated, tests are passing (with unrelated flakes), and I'm done making changes. Let me know if anything else is needed!

@mlazos (Contributor) commented Aug 6, 2025

This is cool! Could you add a test? Also does this work without the special handling? Your addition is very similar to the code below it, and perhaps this was already working? Either way adding a test for this scenario would be great and then I'll approve!

@Quantum-Kayak (Author) commented Aug 6, 2025

Hi @mlazos — just wanted to share a quick update. I’ve started working on the fix for handling matmuls within foreach_map. Currently validating the fallback logic and ensuring it integrates cleanly without affecting existing behavior. I’ll update the PR once things are stable.

Please let me know if there are any specific edge cases or requirements I should keep in mind.

Thanks!

@Quantum-Kayak (Author)

Added the test! Let me know if anything else is needed.

@albanD albanD removed their request for review August 7, 2025 13:59

@janeyx99 janeyx99 left a comment


Does this change actually improve perf? foreach_map should be a perf optimization, so I'd expect to see benchmarks.
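
For reference, a hedged benchmark sketch that could answer this (shapes, counts, and iteration numbers are illustrative assumptions; it compares an eager for-loop of mms against the same loop under torch.compile, which approximates what the fallback traces to):

import torch
from torch.utils import benchmark

# Illustrative shapes and list length; adjust for the hardware under test.
a_list = [torch.randn(128, 128) for _ in range(32)]
b_list = [torch.randn(128, 128) for _ in range(32)]

def loop_mm(xs, ys):
    return [torch.mm(x, y) for x, y in zip(xs, ys)]

compiled = torch.compile(loop_mm)
compiled(a_list, b_list)  # warm up once so compile time is excluded

for label, fn in [("eager for-loop", loop_mm), ("torch.compile", compiled)]:
    timer = benchmark.Timer(
        stmt="fn(a_list, b_list)",
        globals={"fn": fn, "a_list": a_list, "b_list": b_list},
    )
    print(label, timer.timeit(100))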

expected = [torch.mm(a, b) for a, b in zip(a_list, b_list)]

for r, e in zip(result, expected):
    self.assertTrue(torch.allclose(r, e), msg=f"Expected {e}, got {r}")

Suggested change
-    self.assertTrue(torch.allclose(r, e), msg=f"Expected {e}, got {r}")
+    self.assertEqual(r, e)

if op is torch.mm:
    if len(new_args) != 2:
        raise ValueError("torch.mm requires exactly two argument lists")
    return [torch.mm(a, b) for a, b in zip(new_args[0], new_args[1])]

wait... does this just for-loop over torch.mm with no optimization...


Yeah, this is just a UX improvement so that we no longer need to work around the matmul (we can pass a whole single-tensor optimizer implementation to foreach_map, vs. breaking it up into multiple foreach_map calls to work around the matmul).


looks irrelevant to this change, let's address foreach_where separately.


not related

@Quantum-Kayak (Author)

Hi @janeyx99, thanks for the detailed review!

I've removed the unrelated changes (_foreach_where.py, swa_utils.py) from this PR — the branch now contains only the relevant fallback logic and its test for torch.mm in foreach_map_fn.

Let me know if you'd like me to add benchmarks for this path. Based on your earlier comment, I understand performance is a key concern here — happy to provide numbers or additional tests if needed.

Thanks again for your time!


@janeyx99 janeyx99 left a comment


It does not appear that you've removed irrelevant changes. And yes, please run some benchmarks comparing against just a for loop of mms.

@Quantum-Kayak Quantum-Kayak force-pushed the foreach-map-matmul-fallback branch from 148a24e to d1f887c Compare August 9, 2025 08:18
@Quantum-Kayak (Author)

Thanks for your patience — I’ve cleaned up the branch so it only includes the torch.mm fallback and its test. Sorry about the earlier mess; should be all set for review now.


@janeyx99 janeyx99 left a comment


I'm still not sure this helps with any performance, but I'll let @mlazos finish the review as he's the expert with foreach_map!

@mlazos (Contributor) commented Aug 11, 2025

> I'm still not sure this helps with any performance, but I'll let @mlazos finish the review as he's the expert with foreach_map!

Yeah, this won't necessarily improve perf over generic torch.compile; it's a UX improvement so we can have cleaner code with a whole optimizer loop with GEMMs in it.

In the future we can change the lowering to use a grouped GEMM, or write a foreach_mm kernel if applicable.
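
As a concrete (hypothetical) illustration of that UX win: a whole single-tensor step containing a matmul can be written once and handed to foreach_map over lists of params, grads, and preconditioners, instead of splitting the loop around the mm. The function and hyperparameters below are invented for illustration; foreach_map itself is an internal/prototype API:

import torch

# Hypothetical single-tensor update containing a matmul; with the fallback,
# a step like this no longer has to be split into multiple foreach_map calls.
def step(param, grad, precond, lr=0.01):
    update = torch.mm(precond, grad)  # preconditioned gradient
    return param - lr * update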

@Quantum-Kayak Quantum-Kayak changed the title Add fallback support for torch.mm in foreach_map_fn foreach_map_fn: add UX fallback for torch.mm to avoid Dynamo graph breaks Aug 12, 2025
@@ -318,6 +318,12 @@ def foreach_map_fn(*args):
    if not at_least_one_list:
        return op(*args[1:])

    # Special handling for torch.mm

So before this PR it was an open question whether this was needed or not, can you check what happens in your test without the special handling? It's possible this was already working, but it just needed a test.

Labels: module: dynamo, module: inductor, open source, topic: not user facing, triaged
4 participants