foreach_map_fn: add UX fallback for torch.mm to avoid Dynamo graph breaks #159757
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159757
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit d1f887c with merge base fb887c3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Hey, just confirming this PR is ready for review. It's fully updated, tests are passing (with unrelated flakes), and I'm done making changes. Let me know if anything else is needed!
This is cool! Could you add a test? Also, does this work without the special handling? Your addition is very similar to the code below it, and perhaps this was already working. Either way, adding a test for this scenario would be great, and then I'll approve!
Hi @mlazos, just wanted to share a quick update. I've started working on the fix for handling matmuls within foreach_map. Currently validating the fallback logic and ensuring it integrates cleanly without affecting existing behavior. I'll update the PR once things are stable. Please let me know if there are any specific edge cases or requirements I should keep in mind. Thanks!
Added the test! Let me know if anything else is needed.
Does this change actually improve perf? foreach_map should be a perf optimization, so I'd expect to see benchmarks.
expected = [torch.mm(a, b) for a, b in zip(a_list, b_list)]

for r, e in zip(result, expected):
    self.assertTrue(torch.allclose(r, e), msg=f"Expected {e}, got {r}")
Suggested change:
-    self.assertTrue(torch.allclose(r, e), msg=f"Expected {e}, got {r}")
+    self.assertEqual(r, e)
if op is torch.mm:
    if len(new_args) != 2:
        raise ValueError("torch.mm requires exactly two argument lists")
    return [torch.mm(a, b) for a, b in zip(new_args[0], new_args[1])]
Wait... does this just for-loop over torch.mm with no optimization?
Yeah, this is just a UX improvement so that we no longer need to work around the matmul (we can basically pass a whole single-tensor optimizer implementation, vs. breaking it up into multiple foreach_map calls to work around the matmul).
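For illustration, a rough sketch of the kind of single-tensor step being referred to; the update rule, the names, and the shapes below are invented for this example and are not taken from this PR:

```python
import torch

def single_tensor_step(param, grad, precond):
    # A GEMM mixed in with pointwise ops. Without the fallback, the mm forces
    # the step to be split around it (foreach_map for the pointwise parts,
    # a plain Python loop for the matmuls).
    update = torch.mm(precond, grad)
    return param - 0.1 * update

# Dummy tensors just to show the step runs as ordinary eager code.
param = torch.randn(4, 2)
grad = torch.randn(4, 2)
precond = torch.randn(4, 4)
new_param = single_tensor_step(param, grad, precond)
```

With the special handling, a step like this could be passed to foreach_map over lists of params, grads, and preconditioners as one unit instead of being hand-split.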
torch/_foreach_where.py (Outdated)
looks irrelevant to this change, let's address foreach_where separately.
torch/optim/swa_utils.py (Outdated)
not related
Hi @janeyx99, thanks for the detailed review! I've removed the unrelated changes (torch/_foreach_where.py and torch/optim/swa_utils.py). Let me know if you'd like me to add benchmarks for this path. Based on your earlier comment, I understand performance is a key concern here; happy to provide numbers or additional tests if needed. Thanks again for your time!
It does not appear that you've removed irrelevant changes. And yes, please run some benchmarks comparing against just a for loop of mms.
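For reference, a minimal benchmark sketch along the lines requested; the shapes, list length, and the use of torch.utils.benchmark are assumptions, and the foreach_map-based variant from this PR would be timed the same way as the compiled loop below (its call is omitted since the exact entry point isn't shown in this thread):

```python
import torch
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
a_list = [torch.randn(128, 128, device=device) for _ in range(32)]
b_list = [torch.randn(128, 128, device=device) for _ in range(32)]

def loop_of_mms(xs, ys):
    # Baseline: a plain Python for loop of torch.mm calls.
    return [torch.mm(x, y) for x, y in zip(xs, ys)]

compiled_loop = torch.compile(loop_of_mms)
compiled_loop(a_list, b_list)  # warm-up call to trigger compilation

for label, fn in [("eager loop", loop_of_mms), ("compiled loop", compiled_loop)]:
    t = benchmark.Timer(
        stmt="fn(a_list, b_list)",
        globals={"fn": fn, "a_list": a_list, "b_list": b_list},
    ).blocked_autorange()
    print(label, t.median)
```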
Compare: 148a24e to d1f887c
Thanks for your patience. I've cleaned up the branch so it only includes the changes relevant to this PR.
I'm still not sure this helps with any performance, but I'll let @mlazos finish the review as he's the expert with foreach_map!
Yeah, this won't necessarily improve perf over generic torch.compile; this is a UX improvement so we can have cleaner code with a whole optimizer loop with GEMMs in it. In the future we can change the lowering to use grouped GEMM or write a foreach_mm kernel if applicable.
@@ -318,6 +318,12 @@ def foreach_map_fn(*args):
    if not at_least_one_list:
        return op(*args[1:])

    # Special handling for torch.mm
So before this PR it was an open question whether this was needed or not. Can you check what happens in your test without the special handling? It's possible this was already working and just needed a test.
This PR adds special handling to foreach_map_fn to support torch.mm (matrix multiplication) with list arguments. If torch.mm is passed, it now unpacks two argument lists and applies mm elementwise.
This prevents graph breaks in Dynamo when torch.mm is used in a foreach_map context.
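As a standalone sketch of the described semantics (mm_fallback and the example shapes are illustrative, not the PR's code):

```python
import torch

def mm_fallback(a_list, b_list):
    # Unpack the two argument lists and apply torch.mm elementwise,
    # returning one output tensor per input pair.
    if len(a_list) != len(b_list):
        raise ValueError("torch.mm requires two argument lists of equal length")
    return [torch.mm(a, b) for a, b in zip(a_list, b_list)]

a_list = [torch.randn(4, 3) for _ in range(2)]
b_list = [torch.randn(3, 5) for _ in range(2)]
out = mm_fallback(a_list, b_list)
assert all(o.shape == (4, 5) for o in out)
```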
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela