foreach_map_fn: add UX fallback for torch.mm to avoid Dynamo graph breaks #159757
New test file (@@ -0,0 +1,14 @@):

import unittest
import torch
from torch._foreach import foreach_map


class TestForeachMapMatmul(unittest.TestCase):
    def test_foreach_map_with_torch_mm(self):
        a_list = [torch.randn(3, 4) for _ in range(3)]
        b_list = [torch.randn(4, 2) for _ in range(3)]

        result = foreach_map(torch.mm, a_list, b_list)
        expected = [torch.mm(a, b) for a, b in zip(a_list, b_list)]

        for r, e in zip(result, expected):
            self.assertTrue(torch.allclose(r, e), msg=f"Expected {e}, got {r}")
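For reference, each pair multiplied in this test is a 3x4 tensor times a 4x2 tensor, so both result and expected are lists of three 3x2 tensors; the shapes are simply the ones the test happens to use.

    r = torch.mm(torch.randn(3, 4), torch.randn(4, 2))  # one entry of `expected`; r.shape == (3, 2)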
@@ -318,6 +318,12 @@ def foreach_map_fn(*args):
    if not at_least_one_list:
        return op(*args[1:])

    # Special handling for torch.mm
    if op is torch.mm:
        if len(new_args) != 2:
            raise ValueError("torch.mm requires exactly two argument lists")
        return [torch.mm(a, b) for a, b in zip(new_args[0], new_args[1])]

    out = []
    for unpacked in zip(*new_args):
        out.append(op(*unpacked))

Review comment (on the torch.mm special case): So before this PR it was an open question whether this was needed or not; can you check what happens in your test without the special handling? It's possible this was already working and it just needed a test.

Review comment: Wait... does this just for-loop over torch.mm with no optimization?

Reply: Yeah, this is just a UX improvement so that we no longer need to work around the matmul: we can basically pass a whole single-tensor optimizer implementation, rather than breaking it up into multiple foreach_map calls to work around the matmul.
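To make that reply concrete, here is a rough sketch (not code from this PR) of the kind of workaround the fallback removes. The preconditioner-style update, the tensor shapes, and the lr constant are all hypothetical; foreach_map is imported from the same path the PR's test uses, and torch._foreach_mul stands in for the rest of an optimizer step.

    import torch
    from torch._foreach import foreach_map  # same import path as the PR's test

    # Hypothetical per-parameter state; shapes chosen only for illustration.
    preconds = [torch.eye(4) for _ in range(3)]
    grads = [torch.randn(4, 2) for _ in range(3)]
    lr = 1e-2  # hypothetical step size

    # Workaround style: keep the matmul out of foreach_map and loop by hand,
    # splitting the update across a plain Python loop plus foreach_* calls.
    precond_grads = [torch.mm(p, g) for p, g in zip(preconds, grads)]
    step = torch._foreach_mul(precond_grads, -lr)

    # With the fallback: torch.mm itself can be the mapped op (the pattern the
    # new test exercises), so the matmul stays inside the foreach_map-based flow.
    precond_grads = foreach_map(torch.mm, preconds, grads)
    step = torch._foreach_mul(precond_grads, -lr)

The difference is purely in how the single-tensor logic is expressed; as the review comment above notes, the fallback itself is still a Python-level loop over torch.mm rather than a fused kernel.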