 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    get_node_arg,
     insert_q_dq_pair,
 )
 from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
@@ -83,14 +84,48 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         return False

-    def insert_input_transpose(self, node, input_node, graph_module):
+    @staticmethod
+    def memory_format_differs(shape):
+        """Returns true if the shape will have a different memory layout in NCHW and NHWC format"""
+        if len(shape) >= 4:
+            C = shape[1]
+            H = shape[2]
+            W = shape[3]
+        elif len(shape) == 3:
+            C = shape[0]
+            H = shape[1]
+            W = shape[2]
+        if len(shape) <= 2:
+            return False
+
+        return C > 1 and (H > 1 or W > 1)
+
+    @staticmethod
+    def is_channel_reshape(input_shape, output_shape):
+        """Returns true if the reshape changes the channel dimension"""
+        if not len(input_shape) == len(output_shape) == 4:
+            return False
+
+        C_old = input_shape[1]
+        C_new = output_shape[1]
+
+        N_new = output_shape[0]
+        N_old = input_shape[0]
+
+        return (N_old != N_new) or (C_old != C_new)
+
+    @staticmethod
+    def insert_input_transpose(node, input_node, graph_module):
         quantize = input_node.target == dq_op
         q_params = input_node.args[1:] if quantize else None
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(input_node, list(self.NHWC_inverse_order)),
+                args=(
+                    input_node,
+                    list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
+                ),
                 quantize=quantize,
                 q_params=q_params,
             )
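The two predicates added above capture a layout-equivalence argument: an NCHW tensor and its NHWC permutation hold their elements in the same memory order unless the channel dimension and at least one spatial dimension are larger than 1, while is_channel_reshape flags the 4D -> 4D views that move data across the N or C dimension. A minimal standalone PyTorch sketch (not part of the patch; the helper name layouts_differ is invented for illustration) that checks the same condition empirically on 4D shapes:

# Standalone sketch: verifies the layout-equivalence condition that
# memory_format_differs encodes, using plain PyTorch on 4D shapes.
import math
import torch

def layouts_differ(shape):
    # Build a contiguous NCHW tensor, permute it to NHWC, and compare the
    # raw element order of the two contiguous buffers.
    x = torch.arange(math.prod(shape)).reshape(shape)  # contiguous NCHW
    nhwc = x.permute(0, 2, 3, 1).contiguous()          # same data, NHWC order
    return not torch.equal(x.flatten(), nhwc.flatten())

assert not layouts_differ((2, 1, 8, 8))   # C == 1: no transpose needed
assert not layouts_differ((2, 16, 1, 1))  # H == W == 1: no transpose needed
assert layouts_differ((2, 3, 4, 4))       # C > 1 and (H > 1 or W > 1): orders differ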
@@ -100,14 +135,17 @@ def insert_input_transpose(self, node, input_node, graph_module):
                 range(len(input_node.meta["val"].size()))
             )

-    def insert_output_transpose(self, node, graph_module):
+    @staticmethod
+    def insert_output_transpose(node, graph_module):
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(node, list(self.NHWC_order)),
+                args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
+            )
+            permute_node.meta["tosa_dim_order"] = (
+                AnnotateChannelsLastDimOrder.NHWC_order
             )
-            permute_node.meta["tosa_dim_order"] = self.NHWC_order
             node.meta["tosa_dim_order"] = (0, 1, 2, 3)
             users = [user for user in node.users if user != permute_node]
             for user in users:
@@ -118,54 +156,96 @@ def insert_output_transpose(self, node, graph_module):
                 q_params = node.args[0].args[1:]
                 insert_q_dq_pair(graph_module.graph, node, q_params)

+    @staticmethod
+    def _insert_squeeze_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3
+
+        if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            input_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+
+    @staticmethod
+    def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module):
+        nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4
+        if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            output_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
+    @staticmethod
+    def _insert_view_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
+        nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
+        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+            output_shape, input_shape
+        )
+
+        if (
+            channel_reshape or nhwc_to_nchw
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+        if (
+            channel_reshape or nchw_to_nhwc
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
-        Reshape operations are not equivalent in NCHW and NHWC.
-        To get around this, transposes need to be added if the previous or new shape
-        fulfil the following condition:
-            C > 1 and (H or W > 1)
-
-        This is relevant for the following operations;
-            squeeze: 4D -> 3D
-            unsqueeze: <4D -> 4D
-            view: <4D -> 4D
-            view: 4D -> <4D
-            view: 4D -> 4D
-        """
-
-        def transpose_condition(shape):
-            if len(shape) != 4:
-                return False
-            C = shape[1]
-            H = shape[2]
-            W = shape[3]
-            return C > 1 and (H > 1 or W > 1)
+        Transposes are needed for operators transforming the input to a different rank, as 4D tensors are assumed to be in NHWC format, whereas all others are in NCHW format.
+        This is relevant for the following cases:
+        - squeeze:    4D -> <4D
+        - unsqueeze:  3D ->  4D
+        - view:      <4D ->  4D
+        - view:       4D -> <4D
+        Additionally, a 4D -> 4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leading to one extra input and one extra output transpose for this case.

+        Transposes can be avoided for shapes where there is no difference in actual memory, e.g. for
+        - H == W == 1
+        - C == 1
+        - 1D/2D tensors
+        """
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
+
             if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
-                if transpose_condition(input_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
+                output_shape = node.meta["val"].shape
+
+                self._insert_squeeze_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )

             elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
+                input_node = get_node_arg(node.args, 0, default_value=False)
+                if input_node:
+                    input_shape = input_node.meta["val"].shape
+                else:
+                    input_shape = ()
                 output_shape = node.meta["val"].shape
-                if transpose_condition(output_shape):
-                    self.insert_output_transpose(node, graph_module)
+
+                self._insert_unsqueeze_transpose(
+                    input_shape, output_shape, node, graph_module
+                )

             elif node.target == exir_ops.edge.aten.view_copy.default:
                 input_node = node.args[0]
+                input_shape = input_node.meta["val"].shape
+                output_shape = node.meta["val"].shape

-                old_shape = input_node.meta["val"].shape
-                new_shape = node.meta["val"].shape
-
-                if transpose_condition(old_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
-
-                if transpose_condition(new_shape):
-                    self.insert_output_transpose(node, graph_module)
+                self._insert_view_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )

     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
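As the new insert_tosa_transposes docstring explains, 4D tensors are assumed to carry NHWC dim order while lower-rank tensors stay NCHW, so a squeeze/unsqueeze/view that crosses the 4D boundary (or a 4D -> 4D view that reshapes across N or C) is bracketed with passthrough_to_tosa._transpose nodes. A minimal standalone PyTorch sketch (not part of the patch) of why the reshape would otherwise act on the wrong element order:

# Standalone sketch: a reshape is defined on the logical NCHW element order, so
# applying it directly to data stored in NHWC order picks the wrong elements.
import torch

x = torch.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)     # logical NCHW tensor
want = x.reshape(2, 3, 16)                               # intended view_copy result

nhwc_data = x.permute(0, 2, 3, 1).contiguous()           # NHWC storage assumed by the pass
naive = nhwc_data.reshape(2, 3, 16)                      # reshaping the NHWC buffer directly
fixed = nhwc_data.permute(0, 3, 1, 2).reshape(2, 3, 16)  # NHWC -> NCHW transpose first

assert not torch.equal(naive, want)  # wrong without the inserted transpose
assert torch.equal(fixed, want)      # correct once the transpose is inserted

memory_format_differs skips these extra transposes exactly in the cases listed in the docstring (C == 1, H == W == 1, rank <= 2), where the NHWC and NCHW element orders coincide.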