diff --git a/test/inductor/test_foreach_map.py b/test/inductor/test_foreach_map.py new file mode 100644 index 000000000000..b789019829d1 --- /dev/null +++ b/test/inductor/test_foreach_map.py @@ -0,0 +1,14 @@ +import unittest +import torch +from torch._foreach import foreach_map + +class TestForeachMapMatmul(unittest.TestCase): + def test_foreach_map_with_torch_mm(self): + a_list = [torch.randn(3, 4) for _ in range(3)] + b_list = [torch.randn(4, 2) for _ in range(3)] + + result = foreach_map(torch.mm, a_list, b_list) + expected = [torch.mm(a, b) for a, b in zip(a_list, b_list)] + + for r, e in zip(result, expected): + self.assertTrue(torch.allclose(r, e), msg=f"Expected {e}, got {r}") diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 6d467b215797..85d67f1b685a 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -318,6 +318,12 @@ def foreach_map_fn(*args): if not at_least_one_list: return op(*args[1:]) + # Special handling for torch.mm + if op is torch.mm: + if len(new_args) != 2: + raise ValueError("torch.mm requires exactly two argument lists") + return [torch.mm(a, b) for a, b in zip(new_args[0], new_args[1])] + out = [] for unpacked in zip(*new_args): out.append(op(*unpacked))