1
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
2
2
# All rights reserved.
3
+ # Copyright 2025 Arm Limited and/or its affiliates.
3
4
#
4
5
# This source code is licensed under the BSD-style license found in the
5
6
# LICENSE file in the root directory of this source tree.
11
12
from executorch .exir .pass_base import ExportPass , PassResult
12
13
13
14
14
- def merge_view_copy_chains (graph : torch .fx .Graph ) -> torch .fx .Graph :
15
+ def merge_view_copy_chains (graph : torch .fx .Graph ) -> tuple [ torch .fx .Graph , bool ] :
15
16
"""
16
17
Find chains of view_copy nodes and merge them into one view_copy node.
17
18
Only merges view_copy nodes that are not used by any other nodes.
18
19
"""
19
20
ops = exir_ops .edge
20
21
view_op = ops .aten .view_copy .default
22
+ modified = False
21
23
for node in graph .nodes :
22
24
if node .op == "call_function" and node .target == view_op :
23
25
# find ending view_copy node in chain
@@ -35,29 +37,36 @@ def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
35
37
new_args = (node .args [0 ], end_node .args [1 ])
36
38
node .args = new_args
37
39
end_node .replace_all_uses_with (node )
40
+ modified = True
38
41
39
42
graph .eliminate_dead_code ()
40
- return graph
43
+ return graph , modified
41
44
42
45
43
- def remove_noop_view_copy (graph : torch .fx .Graph ) -> torch .fx .Graph :
46
+ def remove_noop_view_copy (graph : torch .fx .Graph ) -> tuple [ torch .fx .Graph , bool ] :
44
47
"""
45
48
Remove view_copy nodes that are no-ops.
46
49
"""
47
50
ops = exir_ops .edge
48
51
view_op = ops .aten .view_copy .default
52
+ modified = False
49
53
for node in graph .nodes :
50
54
if node .op == "call_function" and node .target == view_op :
51
55
input_shape = list (node .args [0 ].meta ["val" ].shape )
52
56
target_shape = node .args [1 ]
53
57
if input_shape == target_shape :
54
58
node .replace_all_uses_with (node .args [0 ])
59
+ modified = True
55
60
graph .eliminate_dead_code ()
56
- return graph
61
+ return graph , modified
57
62
58
63
59
64
class FuseViewCopyTransform (ExportPass ):
60
65
def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
61
- graph_module .graph = merge_view_copy_chains (graph_module .graph )
62
- graph_module .graph = remove_noop_view_copy (graph_module .graph )
63
- return PassResult (graph_module , True )
66
+ graph_module .graph , merge_modified = merge_view_copy_chains (graph_module .graph )
67
+ graph_module .graph , noop_modified = remove_noop_view_copy (graph_module .graph )
68
+ modified = merge_modified or noop_modified
69
+ if modified :
70
+ graph_module .recompile ()
71
+ graph_module = super ().call (graph_module ).graph_module
72
+ return PassResult (graph_module , modified )
0 commit comments