
Commit 173e41f

Optimize transposes in XNNPACK partition

1 parent a8e4be4
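In short: this commit adds a RemoveRedundantOpsPass that deletes adjacent pairs of opposite NCHW/NHWC _to_copy conversions, turns the is_nhwc_node/is_nchw_node helpers into static methods so the new pass can reuse them, and adds an input_dim_order check so ChannelsLastTaggedReshapePass skips inputs that are already in the requested dim order.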

4 files changed: +201 −8 lines changed

backends/xnnpack/_passes/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -25,6 +25,10 @@
     FuseBatchNormWithConvPass,
 )
 from executorch.backends.xnnpack._passes.prelu_reshape_pass import PReLUReshapePass
+
+from executorch.backends.xnnpack._passes.remove_redundant_ops_pass import (
+    RemoveRedundantOpsPass,
+)
 from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
     TagImplicitQDqPass,
 )
@@ -70,6 +74,7 @@ def __init__(
                 Conv1dUnsqueezePass,
                 PReLUReshapePass,
                 ChannelsLastTaggedReshapePass,
+                RemoveRedundantOpsPass,
                 TagImplicitQDqPass,
             ]
         else:
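For orientation, a hedged sketch of how this default pass list gets applied; it assumes the manager defined in this __init__.py is XNNPACKPassManager with a transform() method, neither of which is shown in the diff:

from executorch.backends.xnnpack._passes import XNNPACKPassManager

def run_xnnpack_passes(exported_program):
    # Assumed API: with no explicit pass list, the manager falls back to the
    # default list above, where RemoveRedundantOpsPass now runs right after
    # ChannelsLastTaggedReshapePass so it can cancel the back-to-back
    # _to_copy pairs that pass leaves behind.
    return XNNPACKPassManager(exported_program).transform()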

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 39 additions & 8 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from enum import Enum
 from typing import Optional, Tuple
 
 import torch
@@ -14,6 +15,11 @@
 from executorch.exir.pass_base import PassResult
 
 
+class InputDimOrder(Enum):
+    NCHW = 1
+    NHWC = 2
+
+
 # TODO(T151254305) use subgraph_rewriter
 class ChannelsLastTaggedReshapePass(XNNPACKPass):
     """
@@ -84,11 +90,13 @@ def mark_as_nhwc_node(self, node: torch.fx.Node) -> None:
     def mark_as_nchw_node(self, node: torch.fx.Node) -> None:
         node.meta[ChannelsLastTaggedReshapePass.XNN_NHWC_NODE] = False
 
-    def is_nhwc_node(self, node: torch.fx.Node) -> bool:
+    @staticmethod
+    def is_nhwc_node(node: torch.fx.Node) -> bool:
         return node.meta.get(ChannelsLastTaggedReshapePass.XNN_NHWC_NODE, False)
 
-    def is_nchw_node(self, node: torch.fx.Node) -> bool:
-        return not self.is_nhwc_node(node)
+    @staticmethod
+    def is_nchw_node(node: torch.fx.Node) -> bool:
+        return not ChannelsLastTaggedReshapePass.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
         return (
@@ -114,7 +122,7 @@ def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         is_nchw_constant = (
             is_param_node(self.exported_program, node)
             and (ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in node.meta)
-            and (self.is_nchw_node(node))
+            and (ChannelsLastTaggedReshapePass.is_nchw_node(node))
         )
         return is_4d and not is_nchw_constant
 
@@ -257,6 +265,22 @@ def insert_copy_and_assign_partner_nodes_quantization_sensitive(
         # in that case
         self.make_partners(original_input, copy_node)
 
+    def input_dim_order(
+        self, input_node: torch.fx.Node, input_order: InputDimOrder
+    ) -> bool:
+        if input_node.name == "x":
+            return (
+                input_node.meta["val"].is_contiguous()
+                if input_order == InputDimOrder.NCHW
+                else not input_node.meta["val"].is_contiguous()
+            )
+        else:
+            return (
+                ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
+                if input_order == InputDimOrder.NCHW
+                else ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+            )
+
     def input_to_nhwc(
         self,
         graph_module: torch.fx.GraphModule,
@@ -266,7 +290,7 @@
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nchw_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nchw_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NCHW format so we can't use it here in NHWC format
@@ -283,6 +307,9 @@
         elif self.is_nhwc_node(input_node):
             return
 
+        if self.input_dim_order(input_node, InputDimOrder.NHWC):
+            return
+
         if not self.can_be_converted_to_nhwc(input_node):
             raise AssertionError(
                 "Attempting to convert non-NHWC compatible node to NHWC"
@@ -332,7 +359,7 @@ def input_to_nchw(
         if is_param_node(self.exported_program, input_node):
             if (
                 ChannelsLastTaggedReshapePass.XNN_NHWC_NODE in input_node.meta
-                and self.is_nhwc_node(input_node)
+                and ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
             ):
                 # This constant data tensor has been used somewhere else
                 # in NHWC format so we can't use it here in NCHW format
@@ -350,6 +377,9 @@ def input_to_nchw(
         elif self.is_nchw_node(input_node):
             return
 
+        if self.input_dim_order(input_node, InputDimOrder.NCHW):
+            return
+
         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
             # Already has an associated NCHW node
             input_node_nchw = input_node.meta[
@@ -391,7 +421,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 self.input_to_nhwc(graph_module, node.args[0], node)
 
                 for input_node in node.all_input_nodes[1:]:
-                    if self.is_nhwc_node(input_node):
+                    if ChannelsLastTaggedReshapePass.is_nhwc_node(input_node):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )
@@ -409,7 +439,8 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa: C901
                 # The node can have inputs in any format (but all must be the
                 # same format)
                 is_or_isnt_nhwc_node = [
-                    self.is_nhwc_node(input_node) for input_node in node.all_input_nodes
+                    ChannelsLastTaggedReshapePass.is_nhwc_node(input_node)
+                    for input_node in node.all_input_nodes
                 ]
                 if all(is_or_isnt_nhwc_node):
                     # All inputs are nhwc so this node's output is nhwc too
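The new input_dim_order helper classifies the graph input by calling Tensor.is_contiguous() on its meta["val"]: a plain 4D tensor is contiguous (NCHW dim order), while a channels-last tensor is not. A minimal standalone illustration of that check:

import torch

x = torch.randn(1, 3, 6, 4)                  # default layout: contiguous, i.e. NCHW
y = x.to(memory_format=torch.channels_last)  # NHWC dim order

print(x.is_contiguous())  # True  -> already NCHW, so input_to_nchw can return early
print(y.is_contiguous())  # False -> already NHWC, so input_to_nhwc can return early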
backends/xnnpack/_passes/remove_redundant_ops_pass.py (new file)

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+
+
+class RemoveRedundantOpsPass(XNNPACKPass):
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        original_nodes = list(graph.nodes)
+
+        # Store first subsequent visitation of to_copy node
+        prev = None
+        for node in original_nodes:
+            if len(node.all_input_nodes) == 0:
+                continue
+
+            # If we encounter a to_copy node, check if it is preceded by an opposite to_copy node
+            if node.target == exir_ops.edge.aten._to_copy.default:
+                if prev and ChannelsLastTaggedReshapePass.is_nchw_node(
+                    prev
+                ) != ChannelsLastTaggedReshapePass.is_nchw_node(node):
+                    # If we find an opposite to_copy node, remove both nodes
+                    prevPrev = prev.args[0]
+
+                    for user in node.users.copy():
+                        user.replace_input_with(node, prevPrev)
+
+                    graph.erase_node(node)
+                    graph.erase_node(prev)
+
+                    prev = None
+                    continue
+                prev = node
+            else:
+                prev = None
+
+        graph_module.recompile()
+
+        # Since we are overriding "call", we need to call the parent's "call"
+        # to retrace the graph and regenerate metadata
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
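The rewrite mechanics here (rewire the users of the second node onto the first node's input, then erase both) are standard torch.fx graph surgery. Below is a self-contained toy version of the same pattern; it uses a made-up cancelling pair (add 0 / sub 0) as a stand-in for the opposite _to_copy nodes, since _to_copy only appears after edge-dialect lowering:

import operator

import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        y = x + 0  # stand-in for the first _to_copy
        z = y - 0  # stand-in for the opposite _to_copy
        return z.relu()


gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is operator.sub:
        prev = node.args[0]
        if (
            isinstance(prev, fx.Node)
            and prev.op == "call_function"
            and prev.target is operator.add
        ):
            src = prev.args[0]
            node.replace_all_uses_with(src)  # rewire users, like replace_input_with above
            gm.graph.erase_node(node)  # erase the second node first (it now has no users)
            gm.graph.erase_node(prev)  # then its now-unused predecessor
gm.recompile()
print(gm.code)  # relu is now applied directly to x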
Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack._passes.channels_last_tagged_reshape_pass import (
+    ChannelsLastTaggedReshapePass,
+)
+from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
+from executorch.backends.xnnpack._passes.remove_redundant_ops_pass import (
+    RemoveRedundantOpsPass,
+)
+from executorch.backends.xnnpack.test.tester import RunPasses, Tester
+from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
+
+
+class TestChannelsLastTaggedReshapePass(unittest.TestCase):
+    PassStage = RunPasses(
+        [
+            DimOrderOpsRevertPass,
+            ConvertToLinearPass,
+            ChannelsLastTaggedReshapePass,
+            RemoveRedundantOpsPass,
+        ]
+    )
+
+    def setUp(self):
+        torch._dynamo.reset()
+
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ChannelsLastToContiguous(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            return self.conv1(y)
+
+    ChannelsLastToContiguousModule = ChannelsLastToContiguous()
+
+    class ContiguousToChannelsLast(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+            y = y.to(memory_format=torch.contiguous_format)
+            y = y.to(memory_format=torch.channels_last)
+
+            return self.conv1(y)
+
+    ContiguousToChannelsLastModule = ContiguousToChannelsLast()
+
+    def test_redundant_to_copy_op_removal(self):
+        (
+            Tester(self.ChannelsLastToContiguousModule, (torch.randn(1, 3, 6, 4),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
+
+    def test_redundant_to_copy_op_removal_2(self):
+        (
+            Tester(self.ContiguousToChannelsLastModule, (torch.randn(1, 3, 6, 4),))
+            .export()
+            .to_edge()
+            .run_passes(self.PassStage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_aten__to_copy_default": 1,
+                }
+            )
+            .run_method_and_compare_outputs()
+        )
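A plausible reading of the expected counts: the six back-to-back memory-format round trips in each forward cancel pairwise, leaving only the conversions ChannelsLastTaggedReshapePass needs at the conv boundary. In ChannelsLastToContiguousModule the tensor is contiguous when it reaches conv1, so a copy to NHWC plus a copy back to NCHW survive (count 2); in ContiguousToChannelsLastModule it is already channels-last there, so only the copy back after the conv survives (count 1).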
