Commit 77a0e35

AdrianLundell authored and freddan80 committed
Add stricter transpose condition for TOSA reshape lowering
Removes transposes in the lowered graph for reshapes that touch only the H and W dimensions, and clarifies the logic in the annotate_channels_last_dim_order_pass.

Change-Id: I87e8575d7da8ad56a1f4e937837d7549c05aa11e
1 parent 0b1c1e5 commit 77a0e35
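In rough terms, the stricter condition means a reshape is only bracketed by NHWC<->NCHW transposes when it actually moves data across the batch or channel dimension and the two layouts differ in memory. Below is a minimal standalone sketch of the rule for the 4D -> 4D view case (an illustrative paraphrase, not the pass itself; the helper names mirror those added in the diff further down):

def memory_format_differs(shape):
    """Sketch: NCHW and NHWC layouts only differ in memory when C > 1 and H or W > 1."""
    if len(shape) < 3:
        return False
    c, h, w = shape[-3], shape[-2], shape[-1]
    return c > 1 and (h > 1 or w > 1)

def view_4d_needs_transposes(input_shape, output_shape):
    """Simplified 4D -> 4D view rule: transpose input/output only when N or C changes."""
    channel_reshape = (input_shape[0], input_shape[1]) != (output_shape[0], output_shape[1])
    return (channel_reshape and memory_format_differs(input_shape),
            channel_reshape and memory_format_differs(output_shape))

# Reshape confined to H and W: no transposes are inserted any more.
assert view_4d_needs_transposes((2, 3, 2, 3), (2, 3, 3, 2)) == (False, False)
# Reshape that moves data across the channel dimension: still transposed on both sides.
assert view_4d_needs_transposes((2, 6, 2, 2), (2, 3, 4, 2)) == (True, True)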

2 files changed: +118 -37

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 117 additions & 37 deletions
@@ -12,6 +12,7 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    get_node_arg,
     insert_q_dq_pair,
 )
 from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
@@ -83,14 +84,48 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
 
         return False
 
-    def insert_input_transpose(self, node, input_node, graph_module):
+    @staticmethod
+    def memory_format_differs(shape):
+        """Returns true if the shape will have a different memory layout in NCHW and NHWC format"""
+        if len(shape) >= 4:
+            C = shape[1]
+            H = shape[2]
+            W = shape[3]
+        elif len(shape) == 3:
+            C = shape[0]
+            H = shape[1]
+            W = shape[2]
+        if len(shape) <= 2:
+            return False
+
+        return C > 1 and (H > 1 or W > 1)
+
+    @staticmethod
+    def is_channel_reshape(input_shape, output_shape):
+        """Returns true if the reshape changes the channel dimension"""
+        if not len(input_shape) == len(output_shape) == 4:
+            return False
+
+        C_old = input_shape[1]
+        C_new = output_shape[1]
+
+        N_new = output_shape[0]
+        N_old = input_shape[0]
+
+        return (N_old != N_new) or (C_old != C_new)
+
+    @staticmethod
+    def insert_input_transpose(node, input_node, graph_module):
         quantize = input_node.target == dq_op
         q_params = input_node.args[1:] if quantize else None
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(input_node, list(self.NHWC_inverse_order)),
+                args=(
+                    input_node,
+                    list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
+                ),
                 quantize=quantize,
                 q_params=q_params,
             )
@@ -100,14 +135,17 @@ def insert_input_transpose(self, node, input_node, graph_module):
             range(len(input_node.meta["val"].size()))
         )
 
-    def insert_output_transpose(self, node, graph_module):
+    @staticmethod
+    def insert_output_transpose(node, graph_module):
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(node, list(self.NHWC_order)),
+                args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
+            )
+            permute_node.meta["tosa_dim_order"] = (
+                AnnotateChannelsLastDimOrder.NHWC_order
             )
-            permute_node.meta["tosa_dim_order"] = self.NHWC_order
             node.meta["tosa_dim_order"] = (0, 1, 2, 3)
             users = [user for user in node.users if user != permute_node]
             for user in users:
@@ -118,54 +156,96 @@ def insert_output_transpose(self, node, graph_module):
                 q_params = node.args[0].args[1:]
                 insert_q_dq_pair(graph_module.graph, node, q_params)
 
+    @staticmethod
+    def _insert_squeeze_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3
+
+        if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            input_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+
+    @staticmethod
+    def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module):
+        nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4
+        if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            output_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
+    @staticmethod
+    def _insert_view_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
+        nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
+        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+            output_shape, input_shape
+        )
+
+        if (
+            channel_reshape or nhwc_to_nchw
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+        if (
+            channel_reshape or nchw_to_nhwc
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
-        Reshape operations are not equivalent in NCHW and NHWC.
-        To get around this, transposes need to be added if the previous or new shape
-        fulfil the following condition:
-            C > 1 and (H or W > 1)
-
-        This is relevant for the following operations;
-            squeeze: 4D -> 3D
-            unsqueeze: <4D -> 4D
-            view: <4D -> 4D
-            view: 4D -> <4D
-            view: 4D -> 4D
-        """
-
-        def transpose_condition(shape):
-            if len(shape) != 4:
-                return False
-            C = shape[1]
-            H = shape[2]
-            W = shape[3]
-            return C > 1 and (H > 1 or W > 1)
+        Transposes are needed for operators transforming the input to a different rank, as 4D tensors are assumed to be in NHWC format, whereas all others are in NCHW format.
+        This is relevant for the following cases:
+        - squeeze:  4D -> <4D
+        - unsqueeze: 3D -> 4D
+        - view: <4D -> 4D
+        - view: 4D -> <4D
+        Additionally, a 4D -> 4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leading to one extra input and output transpose for this case.
 
+        Transposes can be avoided for shapes where there is no difference in actual memory, e.g. for
+        - H == W == 1
+        - C == 1
+        - 1D/2D tensors
+        """
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
+
             if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
-                if transpose_condition(input_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
+                output_shape = node.meta["val"].shape
+
+                self._insert_squeeze_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )
 
             elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
+                input_node = get_node_arg(node.args, 0, default_value=False)
+                if input_node:
+                    input_shape = input_node.meta["val"].shape
+                else:
+                    input_shape = ()
                 output_shape = node.meta["val"].shape
-                if transpose_condition(output_shape):
-                    self.insert_output_transpose(node, graph_module)
+
+                self._insert_unsqueeze_transpose(
+                    input_shape, output_shape, node, graph_module
+                )
 
             elif node.target == exir_ops.edge.aten.view_copy.default:
                 input_node = node.args[0]
+                input_shape = input_node.meta["val"].shape
+                output_shape = node.meta["val"].shape
 
-                old_shape = input_node.meta["val"].shape
-                new_shape = node.meta["val"].shape
-
-                if transpose_condition(old_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
-
-                if transpose_condition(new_shape):
-                    self.insert_output_transpose(node, graph_module)
+                self._insert_view_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )
 
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
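The "no difference in actual memory" short-circuit described in the docstring above can be checked directly: when C == 1, or H == W == 1, the NCHW and NHWC layouts enumerate elements in the same order, so skipping the transposes is safe. A quick torch sanity check (illustrative only, using shapes from the existing tests):

import torch

# Shapes where memory_format_differs(...) is False: the contiguous NCHW buffer
# and the corresponding NHWC buffer hold the elements in identical order.
for shape in [(1, 1, 5, 10), (2, 50, 1, 1)]:
    x = torch.rand(shape)
    nchw_buffer = x.contiguous().view(-1)
    nhwc_buffer = x.permute(0, 2, 3, 1).contiguous().view(-1)
    assert torch.equal(nchw_buffer, nhwc_buffer)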

backends/arm/test/ops/test_view.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ class View(torch.nn.Module):
         (torch.rand(1, 1, 5, 10), (1, 1, 50, 1)),
         (torch.rand(5, 10, 1, 1), (1, 25, 2)),
         (torch.rand(2, 50, 1, 1), (1, 100)),
+        (torch.rand(2, 3, 2, 3), (2, 3, 3, 2)),
     ]
 
     def forward(self, x: torch.Tensor, new_shape):
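The added pair (torch.rand(2, 3, 2, 3), (2, 3, 3, 2)) exercises exactly the case the commit targets: N and C are unchanged and only H and W are reorganized, so under the stricter condition no transposes are inserted around the view. The premise can be verified with a small torch snippet (illustrative only): in NHWC memory both tensors enumerate elements as (n, flattened H*W index, c), so the underlying buffer is identical and a plain reshape suffices.

import torch

# Illustrative check: an H,W-only reshape leaves the NHWC buffer untouched.
x = torch.rand(2, 3, 2, 3)                   # logical NCHW input from the new test
y = x.reshape(2, 3, 3, 2)                    # view that only reorganizes H and W
x_nhwc = x.permute(0, 2, 3, 1).contiguous()  # physical NHWC buffer, shape (N, H, W, C)
y_nhwc = y.permute(0, 2, 3, 1).contiguous()
assert torch.equal(x_nhwc.view(-1), y_nhwc.view(-1))  # same data order, only the shape differs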
