diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index e2cc101533f2..f03b438a44ed 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -800,12 +800,17 @@ def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]: With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies can be saved. """ + last_conv = None + nodes_cannot_propagate = [torch.ops.aten.bmm.default] output_set = OrderedSet[Node]() for n in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] if n.target == torch.ops.aten.convolution.default: output_set.add(n) + if last_conv is None: + last_conv = n + continue + if n.target in nodes_cannot_propagate: continue - for user in n.users: if user in output_set: output_set.add(n) @@ -826,8 +831,14 @@ def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]: # - res2net50_14w_8s # - sebotnet33ts_256 for n in self.module.graph.nodes: # type: ignore[union-attr] + # layout propagation ends at last conv node, which will benefit vison transformers. + if last_conv is not None and n == last_conv: + break if n in output_set: - output_set.update(n.users) + for user in n.users: + if user.target in nodes_cannot_propagate: + continue + output_set.add(user) return output_set