Skip to content

[WIP][fake tensor] invalidate memos for PropagateUnbackedSymInts #155187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -8963,6 +8963,24 @@ def forward(self, x):
ep.module()(*copy.deepcopy(inputs)), M()(*copy.deepcopy(inputs))
)

def test_nonzero_memo(self):
    # Boolean-mask indexing produces a data-dependent (unbacked) size.
    # Applying the same mask to two different inputs must yield two
    # independently-created unbacked symbols; this is a smoke test that
    # export + decomposition of such a model completes without error.
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mask = torch.zeros([4497], dtype=torch.bool)
            self.mask[:21] = True

        def forward(self, x, y):
            masked_x = x[:, self.mask]  # Create u0
            masked_y = y[:, self.mask]  # Create another instance of u0 separately
            return masked_x, masked_y

    example_inputs = (
        torch.rand([1, 4497, 2]),
        torch.rand([1, 4497, 2]),
    )
    ep = export(Model(), example_inputs)
    ep = ep.run_decompositions(torch.export.default_decompositions())

def test__scaled_dot_product_flash_attention(self):
class Module(torch.nn.Module):
def forward(self, q, k, v):
Expand Down
23 changes: 23 additions & 0 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7818,13 +7818,36 @@


class PropagateUnbackedSymInts(torch.fx.Interpreter):
def run(self, *args, **kwargs) -> Any:

Check failure on line 7821 in torch/fx/experimental/symbolic_shapes.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [no-untyped-def]

Function is missing a type annotation for one or more arguments
self.invalidated_tensor_ids = set()

Check failure on line 7822 in torch/fx/experimental/symbolic_shapes.py

View workflow job for this annotation

GitHub Actions / lintrunner-noclang / linux-job

MYPY [var-annotated]

Need type annotation for "invalidated_tensor_ids" (hint: "invalidated_tensor_ids: set[<type>] = ...")
return super().run(*args, **kwargs)

def _invalidate_fake_tensor_memos(self, n: torch.fx.Node) -> None:
from torch._subclasses.fake_tensor import FakeTensor

if (
(val := n.meta.get("val")) is not None
and isinstance(t := val, FakeTensor)
and id(t) not in self.invalidated_tensor_ids
):
t.nonzero_memo = None
t.item_memo = None
t.unique_memo = None
t.unique_consecutive_memo = None
self.invalidated_tensor_ids.add(id(t))

def _maybe_invalidate_fake_tensor_memos(self, n: torch.fx.Node) -> None:
pytree.tree_map_only(torch.fx.Node, self._invalidate_fake_tensor_memos, n.args)
pytree.tree_map_only(torch.fx.Node, self._invalidate_fake_tensor_memos, n.kwargs)

def run_node(self, n: torch.fx.Node) -> Result:
"""
Run an FX node, propagating unbacked Symbol bindings to the new fake tensor
"""
from torch._guards import detect_fake_mode

result = super().run_node(n)
self._maybe_invalidate_fake_tensor_memos(n)
rebind_unbacked(detect_fake_mode().shape_env, n, result)
return result

Expand Down
Loading