Skip to content

Commit c91a6c0

Browse files
committed
Fix retracing in FuseViewCopyTransform
Since the pass can change shapes of ops, the graph needs to be retraced to show this in node.meta["val"]. Signed-off-by: Erik Lundell <[email protected]> Change-Id: Ief24fe9d11384a2d0f64f0d91070eca7b0caf18e
1 parent 95d7cce commit c91a6c0

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

backends/transforms/fuse_view_copy.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -11,13 +12,14 @@
1112
from executorch.exir.pass_base import ExportPass, PassResult
1213

1314

14-
def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
15+
def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
1516
"""
1617
Find chains of view_copy nodes and merge them into one view_copy node.
1718
Only merges view_copy nodes that are not used by any other nodes.
1819
"""
1920
ops = exir_ops.edge
2021
view_op = ops.aten.view_copy.default
22+
modified = False
2123
for node in graph.nodes:
2224
if node.op == "call_function" and node.target == view_op:
2325
# find ending view_copy node in chain
@@ -35,29 +37,36 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
3537
new_args = (node.args[0], end_node.args[1])
3638
node.args = new_args
3739
end_node.replace_all_uses_with(node)
40+
modified = True
3841

3942
graph.eliminate_dead_code()
40-
return graph
43+
return graph, modified
4144

4245

43-
def remove_noop_view_copy(graph: torch.fx.Graph) -> torch.fx.Graph:
46+
def remove_noop_view_copy(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
4447
"""
4548
Remove view_copy nodes that are no-ops.
4649
"""
4750
ops = exir_ops.edge
4851
view_op = ops.aten.view_copy.default
52+
modified = False
4953
for node in graph.nodes:
5054
if node.op == "call_function" and node.target == view_op:
5155
input_shape = list(node.args[0].meta["val"].shape)
5256
target_shape = node.args[1]
5357
if input_shape == target_shape:
5458
node.replace_all_uses_with(node.args[0])
59+
modified = True
5560
graph.eliminate_dead_code()
56-
return graph
61+
return graph, modified
5762

5863

5964
class FuseViewCopyTransform(ExportPass):
6065
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
61-
graph_module.graph = merge_view_copy_chains(graph_module.graph)
62-
graph_module.graph = remove_noop_view_copy(graph_module.graph)
63-
return PassResult(graph_module, True)
66+
graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph)
67+
graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph)
68+
modified = merge_modified or noop_modified
69+
if modified:
70+
graph_module.recompile()
71+
graph_module = super().call(graph_module).graph_module
72+
return PassResult(graph_module, modified)

0 commit comments

Comments
 (0)