foreach_map_fn: add UX fallback for torch.mm to avoid Dynamo graph breaks #159757

Open
wants to merge 2 commits into main
14 changes: 14 additions & 0 deletions test/inductor/test_foreach_map.py
@@ -0,0 +1,14 @@
import unittest
import torch
from torch._foreach import foreach_map

class TestForeachMapMatmul(unittest.TestCase):
def test_foreach_map_with_torch_mm(self):
a_list = [torch.randn(3, 4) for _ in range(3)]
b_list = [torch.randn(4, 2) for _ in range(3)]

result = foreach_map(torch.mm, a_list, b_list)
expected = [torch.mm(a, b) for a, b in zip(a_list, b_list)]

for r, e in zip(result, expected):
self.assertTrue(torch.allclose(r, e), msg=f"Expected {e}, got {r}")
Contributor

Suggested change
self.assertTrue(torch.allclose(r, e), msg=f"Expected {e}, got {r}")
self.assertEqual(r, e)
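
A minimal sketch of why the plain assertEqual suggestion works on tensors, assuming the test subclasses PyTorch's own TestCase from torch.testing._internal.common_utils rather than unittest.TestCase: that assertEqual compares tensors element-wise with dtype-aware tolerances, so no manual torch.allclose or failure message is needed. (Illustrative only, not part of the diff; the class name is made up.)

import torch
from torch.testing._internal.common_utils import TestCase, run_tests

class AssertEqualSketch(TestCase):  # hypothetical name, for illustration
    def test_tensor_assert_equal(self):
        a = torch.randn(3, 2)
        b = a.clone()
        # TestCase.assertEqual checks shape, dtype, and values with tolerances.
        self.assertEqual(a, b)

if __name__ == "__main__":
    run_tests()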

6 changes: 6 additions & 0 deletions torch/_dynamo/polyfills/__init__.py
@@ -318,6 +318,12 @@ def foreach_map_fn(*args):
if not at_least_one_list:
return op(*args[1:])

# Special handling for torch.mm
Contributor

So before this PR it was an open question whether this was needed or not; can you check what happens in your test without the special handling? It's possible this was already working and just needed a test.

if op is torch.mm:
if len(new_args) != 2:
raise ValueError("torch.mm requires exactly two argument lists")
return [torch.mm(a, b) for a, b in zip(new_args[0], new_args[1])]
Contributor

Wait... does this just for-loop over torch.mm with no optimization?

Contributor

Yeah, this is just a UX improvement so that we no longer need to work around the matmul (we can basically just pass a whole single-tensor optim implementation, vs. breaking it up into multiple foreach_map calls to work around the matmul).


out = []
for unpacked in zip(*new_args):
out.append(op(*unpacked))
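
A short eager-mode sketch of the behavior discussed in the comments above, calling the polyfill directly as a plain Python function; it assumes the two input lists have equal length and that foreach_map_fn can be invoked outside of Dynamo tracing. With the special case, torch.mm is applied pairwise; the generic zip loop computes the same result, only without any foreach-style fusion.

import torch
from torch._dynamo.polyfills import foreach_map_fn  # function patched by this PR

a_list = [torch.randn(3, 4) for _ in range(3)]
b_list = [torch.randn(4, 2) for _ in range(3)]

# Pairwise matmul over the two lists via the polyfill's torch.mm path.
result = foreach_map_fn(torch.mm, a_list, b_list)

# Plain-Python reference: the same pairwise matmuls.
expected = [torch.mm(a, b) for a, b in zip(a_list, b_list)]
assert all(torch.equal(r, e) for r, e in zip(result, expected))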