Skip to content

Commit 26e921e

Browse files
authored
exir dialect view to squeeze/unsqueeze pass
Differential Revision: D61732548 Pull Request resolved: #4877
1 parent 6d29c1d commit 26e921e

File tree

2 files changed

+142
-0
lines changed

2 files changed

+142
-0
lines changed

backends/transforms/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ runtime.python_library(
8888
],
8989
)
9090

91+
runtime.python_library(
92+
name = "view_copy_to_squeeze_unsqueeze",
93+
srcs = ["view_copy_to_squeeze_unsqueeze.py"],
94+
visibility = [
95+
"//executorch/backends/...",
96+
],
97+
deps = [
98+
":utils",
99+
"//caffe2:torch",
100+
"//executorch/exir:pass_base",
101+
"//executorch/exir/dialects:lib",
102+
],
103+
)
104+
91105
runtime.python_library(
92106
name = "fuse_view_copy",
93107
srcs = ["fuse_view_copy.py"],
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import List, Optional, Union
10+
11+
import torch
12+
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import ExportPass, PassResult
15+
16+
17+
class ViewCopyToSqueezeUnsqueezePass(ExportPass):
18+
"""
19+
Replaces view_copy nodes with squeeze_copy.dims nodes if the view node reduces dims of size 1.
20+
Replaces view_copy nodes with unsqueeze_copy.default nodes if the view node adds a dim of size 1.
21+
"""
22+
23+
def __init__(self) -> None:
24+
super().__init__()
25+
self.view_copy_op: torch._ops.OpOverload = exir_ops.edge.aten.view_copy.default
26+
self.squeeze_op: torch._ops.OpOverload = exir_ops.edge.aten.squeeze_copy.dims
27+
self.unsqueeze_op: torch._ops.OpOverload = (
28+
exir_ops.edge.aten.unsqueeze_copy.default
29+
)
30+
31+
def is_node_target(
32+
self, node: torch.fx.Node, target: torch._ops.OperatorBase
33+
) -> bool:
34+
return node.op == "call_function" and node.target == target
35+
36+
def find_squeeze_dims(
37+
self,
38+
input_shape: List[int],
39+
view_shape: List[int],
40+
) -> Optional[List[int]]:
41+
# view_shape should be a subset of input_shape
42+
if len(input_shape) <= len(view_shape):
43+
return None
44+
45+
# check that all dims are equal except the removed dims
46+
i = 0
47+
j = 0
48+
idx = []
49+
while i < len(input_shape):
50+
if input_shape[i] != view_shape[j]:
51+
if input_shape[i] == 1:
52+
idx.append(i)
53+
j -= 1
54+
# continue to check remaining dims are equal
55+
else:
56+
return None
57+
i += 1
58+
j += 1
59+
return idx
60+
61+
def find_unsqueeze_dim(
62+
self,
63+
input_shape: List[int],
64+
view_shape: List[int],
65+
) -> Optional[int]:
66+
# unsqueeze should increase the length of input_shape by 1
67+
if len(view_shape) - len(input_shape) != 1:
68+
return None
69+
70+
# check that all dims are equal except the added dim
71+
i = 0
72+
j = 0
73+
idx = -1
74+
while j < len(view_shape):
75+
if input_shape[i] != view_shape[j]:
76+
if view_shape[j] == 1:
77+
idx = j
78+
i -= 1
79+
# continue to check remaining dims are equal
80+
else:
81+
return None
82+
i += 1
83+
j += 1
84+
return idx
85+
86+
def replace_view_copy_node(
87+
self,
88+
graph_module: torch.fx.GraphModule,
89+
view_node: torch.fx.Node,
90+
op: torch._ops.OpOverload,
91+
arg: Union[List[int], int],
92+
) -> None:
93+
with graph_module.graph.inserting_before(view_node):
94+
new_node = graph_module.graph.create_node(
95+
"call_function",
96+
op,
97+
(view_node.args[0], arg),
98+
)
99+
new_node.meta = view_node.meta
100+
view_node.replace_all_uses_with(new_node)
101+
graph_module.graph.erase_node(view_node)
102+
103+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
104+
modified = False
105+
for node in graph_module.graph.nodes:
106+
if self.is_node_target(node, self.view_copy_op):
107+
input_node = node.args[0]
108+
input_shape = input_node.meta["val"].shape
109+
view_shape = node.args[1]
110+
squeeze_dims = self.find_squeeze_dims(input_shape, view_shape)
111+
if squeeze_dims:
112+
self.replace_view_copy_node(
113+
graph_module, node, self.squeeze_op, squeeze_dims
114+
)
115+
modified = True
116+
continue
117+
unsqueeze_dim = self.find_unsqueeze_dim(input_shape, view_shape)
118+
if unsqueeze_dim:
119+
self.replace_view_copy_node(
120+
graph_module, node, self.unsqueeze_op, unsqueeze_dim
121+
)
122+
modified = True
123+
continue
124+
125+
if modified:
126+
graph_module.recompile()
127+
graph_module = super().call(graph_module).graph_module
128+
return PassResult(graph_module, modified)

0 commit comments

Comments
 (0)