From cf250e7cb5a25bb263a1cd7f374edd32a9a51b11 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 4 Jun 2025 19:03:12 -0700 Subject: [PATCH] init --- test/export/test_export.py | 18 ++++++++++++++++++ torch/fx/experimental/symbolic_shapes.py | 23 +++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/test/export/test_export.py b/test/export/test_export.py index ff638981bf4c..99f4178328b4 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -8963,6 +8963,24 @@ def forward(self, x): ep.module()(*copy.deepcopy(inputs)), M()(*copy.deepcopy(inputs)) ) + def test_nonzero_memo(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.mask = torch.zeros([4497], dtype=torch.bool) + self.mask[0:21] = True + + def forward(self, x, y): + a = x[:, self.mask] # Create u0 + b = y[:, self.mask] # Create another instance of u0 separately + return a, b + + mod = Model() + x = torch.rand([1,4497,2]) + y = torch.rand([1,4497,2]) + ep = export(mod, (x, y)) + ep = ep.run_decompositions(torch.export.default_decompositions()) + def test__scaled_dot_product_flash_attention(self): class Module(torch.nn.Module): def forward(self, q, k, v): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 8042515a7371..3dc03eaea731 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -7818,6 +7818,28 @@ def _is_dim_dynamic(t: torch.Tensor, d: int) -> bool: class PropagateUnbackedSymInts(torch.fx.Interpreter): + def run(self, *args, **kwargs) -> Any: + self.invalidated_tensor_ids = set() + return super().run(*args, **kwargs) + + def _invalidate_fake_tensor_memos(self, n: torch.fx.Node) -> None: + from torch._subclasses.fake_tensor import FakeTensor + + if ( + (val := n.meta.get("val")) is not None + and isinstance(t := val, FakeTensor) + and id(t) not in self.invalidated_tensor_ids + ): + t.nonzero_memo = None + t.item_memo = None + t.unique_memo = None + t.unique_consecutive_memo = None + self.invalidated_tensor_ids.add(id(t)) + + def _maybe_invalidate_fake_tensor_memos(self, n: torch.fx.Node) -> None: + pytree.tree_map_only(torch.fx.Node, self._invalidate_fake_tensor_memos, n.args) + pytree.tree_map_only(torch.fx.Node, self._invalidate_fake_tensor_memos, n.kwargs) + def run_node(self, n: torch.fx.Node) -> Result: """ Run an FX node, propagating unbacked Symbol bindings to the new fake tensor @@ -7825,6 +7847,7 @@ def run_node(self, n: torch.fx.Node) -> Result: from torch._guards import detect_fake_mode result = super().run_node(n) + self._maybe_invalidate_fake_tensor_memos(n) rebind_unbacked(detect_fake_mode().shape_env, n, result) return result