4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ from enum import Enum
7
8
from typing import Optional , Tuple
8
9
9
10
import torch
14
15
from executorch .exir .pass_base import PassResult
15
16
16
17
18
+ class InputDimOrder (Enum ):
19
+ NCHW = 1
20
+ NHWC = 2
21
+
22
+
17
23
# TODO(T151254305) use subgraph_rewriter
18
24
class ChannelsLastTaggedReshapePass (XNNPACKPass ):
19
25
"""
@@ -84,11 +90,13 @@ def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
84
90
def mark_as_nchw_node (self , node : torch .fx .Node ) -> None :
85
91
node .meta [ChannelsLastTaggedReshapePass .XNN_NHWC_NODE ] = False
86
92
87
- def is_nhwc_node (self , node : torch .fx .Node ) -> bool :
93
+ @staticmethod
94
+ def is_nhwc_node (node : torch .fx .Node ) -> bool :
88
95
return node .meta .get (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE , False )
89
96
90
- def is_nchw_node (self , node : torch .fx .Node ) -> bool :
91
- return not self .is_nhwc_node (node )
97
+ @staticmethod
98
+ def is_nchw_node (node : torch .fx .Node ) -> bool :
99
+ return not ChannelsLastTaggedReshapePass .is_nhwc_node (node )
92
100
93
101
def requires_nhwc_input (self , node : torch .fx .Node ) -> bool :
94
102
return (
@@ -114,7 +122,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
114
122
is_nchw_constant = (
115
123
is_param_node (self .exported_program , node )
116
124
and (ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in node .meta )
117
- and (self .is_nchw_node (node ))
125
+ and (ChannelsLastTaggedReshapePass .is_nchw_node (node ))
118
126
)
119
127
return is_4d and not is_nchw_constant
120
128
@@ -257,6 +265,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
257
265
# in that case
258
266
self .make_partners (original_input , copy_node )
259
267
268
+ def input_dim_order (
269
+ self , input_node : torch .fx .Node , input_order : InputDimOrder
270
+ ) -> bool :
271
+ if input_node .name == "x" :
272
+ return (
273
+ input_node .meta ["val" ].is_contiguous ()
274
+ if input_order == InputDimOrder .NCHW
275
+ else not input_node .meta ["val" ].is_contiguous ()
276
+ )
277
+ else :
278
+ return (
279
+ ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
280
+ if input_order == InputDimOrder .NCHW
281
+ else ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
282
+ )
283
+
260
284
def input_to_nhwc (
261
285
self ,
262
286
graph_module : torch .fx .GraphModule ,
@@ -266,7 +290,7 @@ def input_to_nhwc(
266
290
if is_param_node (self .exported_program , input_node ):
267
291
if (
268
292
ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
269
- and self .is_nchw_node (input_node )
293
+ and ChannelsLastTaggedReshapePass .is_nchw_node (input_node )
270
294
):
271
295
# This constant data tensor has been used somewhere else
272
296
# in NCHW format so we can't use it here in NHWC format
@@ -283,6 +307,9 @@ def input_to_nhwc(
283
307
elif self .is_nhwc_node (input_node ):
284
308
return
285
309
310
+ if self .input_dim_order (input_node , InputDimOrder .NHWC ):
311
+ return
312
+
286
313
if not self .can_be_converted_to_nhwc (input_node ):
287
314
raise AssertionError (
288
315
"Attempting to convert non-NHWC compatible node to NHWC"
@@ -332,7 +359,7 @@ def input_to_nchw(
332
359
if is_param_node (self .exported_program , input_node ):
333
360
if (
334
361
ChannelsLastTaggedReshapePass .XNN_NHWC_NODE in input_node .meta
335
- and self .is_nhwc_node (input_node )
362
+ and ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
336
363
):
337
364
# This constant data tensor has been used somewhere else
338
365
# in NHWC format so we can't use it here in NCHW format
@@ -350,6 +377,9 @@ def input_to_nchw(
350
377
elif self .is_nchw_node (input_node ):
351
378
return
352
379
380
+ if self .input_dim_order (input_node , InputDimOrder .NCHW ):
381
+ return
382
+
353
383
if ChannelsLastTaggedReshapePass .PARTNER_NODE in input_node .meta :
354
384
# Already has an associated NCHW node
355
385
input_node_nchw = input_node .meta [
@@ -391,7 +421,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
391
421
self .input_to_nhwc (graph_module , node .args [0 ], node )
392
422
393
423
for input_node in node .all_input_nodes [1 :]:
394
- if self .is_nhwc_node (input_node ):
424
+ if ChannelsLastTaggedReshapePass .is_nhwc_node (input_node ):
395
425
raise AssertionError (
396
426
f"Expected { input_node } to be NCHW in channels last reshape pass"
397
427
)
@@ -409,7 +439,8 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
409
439
# The node can have inputs in any format (but all must be the
410
440
# same format)
411
441
is_or_isnt_nhwc_node = [
412
- self .is_nhwc_node (input_node ) for input_node in node .all_input_nodes
442
+ ChannelsLastTaggedReshapePass .is_nhwc_node (input_node )
443
+ for input_node in node .all_input_nodes
413
444
]
414
445
if all (is_or_isnt_nhwc_node ):
415
446
# All inputs are nhwc so this node's output is nhwc too
0 commit comments