From 217eaff67d92a449f69b3840aceca850f8032e67 Mon Sep 17 00:00:00 2001 From: jianyizh Date: Wed, 30 Jul 2025 08:16:28 +0000 Subject: [PATCH] save --- torch/_inductor/graph.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index e2cc101533f28..f03b438a44ed8 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -800,12 +800,17 @@ def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]: With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies can be saved. """ + last_conv = None + nodes_cannot_propagate = [torch.ops.aten.bmm.default] output_set = OrderedSet[Node]() for n in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] if n.target == torch.ops.aten.convolution.default: output_set.add(n) + if last_conv is None: + last_conv = n + continue + if n.target in nodes_cannot_propagate: continue - for user in n.users: if user in output_set: output_set.add(n) @@ -826,8 +831,14 @@ def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]: # - res2net50_14w_8s # - sebotnet33ts_256 for n in self.module.graph.nodes: # type: ignore[union-attr] + # layout propagation ends at last conv node, which will benefit vision transformers. + if last_conv is not None and n == last_conv: + break if n in output_set: - output_set.update(n.users) + for user in n.users: + if user.target in nodes_cannot_propagate: + continue + output_set.add(user) return output_set