Skip to content

Commit cbfdf78

Browse files
authored
Insert transposes around view_copy ops
Differential Revision: D64764224 Pull Request resolved: #6435
1 parent 28a213f commit cbfdf78

File tree

5 files changed

+126
-56
lines changed

5 files changed

+126
-56
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 79 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from executorch.backends.arm._passes.arm_pass_utils import (
1313
create_node,
1414
get_first_fake_tensor,
15+
insert_q_dq_pair,
1516
)
16-
from executorch.backends.arm.tosa_quant_utils import dq_op
17+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
1718
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
1819
from executorch.exir.dialects._ops import ops as exir_ops
1920
from executorch.exir.pass_base import ExportPass, PassResult
@@ -79,37 +80,89 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
7980

8081
return False
8182

83+
def insert_input_transpose(self, node, input_node, graph_module):
84+
quantize = input_node.target == dq_op
85+
q_params = input_node.args[1:] if quantize else None
86+
with graph_module.graph.inserting_before(node):
87+
permute_node = create_node(
88+
graph_module.graph,
89+
torch.ops.passthrough_to_tosa._transpose,
90+
args=(input_node, list(self.NHWC_inverse_order)),
91+
quantize=quantize,
92+
q_params=q_params,
93+
)
94+
node.replace_input_with(input_node, permute_node)
95+
96+
permute_node.meta["tosa_dim_order"] = tuple(
97+
range(len(input_node.meta["val"].size()))
98+
)
99+
100+
def insert_output_transpose(self, node, graph_module):
101+
with graph_module.graph.inserting_after(node):
102+
permute_node = create_node(
103+
graph_module.graph,
104+
torch.ops.passthrough_to_tosa._transpose,
105+
args=(node, list(self.NHWC_order)),
106+
)
107+
permute_node.meta["tosa_dim_order"] = self.NHWC_order
108+
node.meta["tosa_dim_order"] = (0, 1, 2, 3)
109+
users = [user for user in node.users if user != permute_node]
110+
for user in users:
111+
user.replace_input_with(node, permute_node)
112+
113+
quantize = node.args[0] == q_op
114+
if quantize:
115+
q_params = node.args[0].args[1:]
116+
insert_q_dq_pair(graph_module.graph, node, q_params)
117+
82118
def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
119+
"""
120+
Reshape operations are not equivalent in NCHW and NHWC.
121+
To get around this, transposes need to be added if the previous or new shape
122+
fulfil the following condition:
123+
C > 1 and (H or W > 1)
124+
125+
This is relevant for the following operations;
126+
squeeze: 4D -> 3D
127+
unsqueeze: <4D -> 4D
128+
view: <4D -> 4D
129+
view: 4D -> <4D
130+
view: 4D -> 4D
131+
"""
132+
133+
def transpose_condition(shape):
134+
if len(shape) != 4:
135+
return False
136+
C = shape[1]
137+
H = shape[2]
138+
W = shape[3]
139+
return C > 1 and (H > 1 or W > 1)
140+
83141
for node in graph_module.graph.nodes:
84142
if node.op != "call_function":
85143
continue
86144
if node.target == exir_ops.edge.aten.squeeze_copy.dims:
87145
input_node = node.args[0]
88-
if input_node.meta["val"].dim() == 4:
89-
with graph_module.graph.inserting_before(node):
90-
permute_node = create_node(
91-
graph_module.graph,
92-
torch.ops.passthrough_to_tosa._transpose,
93-
args=(input_node, list(self.NHWC_inverse_order)),
94-
)
95-
permute_node.meta["tosa_dim_order"] = tuple(
96-
range(len(input_node.meta["val"].size()))
97-
)
98-
node.replace_input_with(input_node, permute_node)
99-
100-
if node.target == exir_ops.edge.aten.unsqueeze_copy.default:
101-
if node.meta["val"].dim() == 4:
102-
with graph_module.graph.inserting_after(node):
103-
permute_node = create_node(
104-
graph_module.graph,
105-
torch.ops.passthrough_to_tosa._transpose,
106-
args=(node, list(self.NHWC_order)),
107-
)
108-
permute_node.meta["tosa_dim_order"] = self.NHWC_order
109-
node.meta["tosa_dim_order"] = (0, 1, 2, 3)
110-
users = [user for user in node.users if user != permute_node]
111-
for user in users:
112-
user.replace_input_with(node, permute_node)
146+
input_shape = input_node.meta["val"].shape
147+
if transpose_condition(input_shape):
148+
self.insert_input_transpose(node, input_node, graph_module)
149+
150+
elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
151+
output_shape = node.meta["val"].shape
152+
if transpose_condition(output_shape):
153+
self.insert_output_transpose(node, graph_module)
154+
155+
elif node.target == exir_ops.edge.aten.view_copy.default:
156+
input_node = node.args[0]
157+
158+
old_shape = input_node.meta["val"].shape
159+
new_shape = node.meta["val"].shape
160+
161+
if transpose_condition(old_shape):
162+
self.insert_input_transpose(node, input_node, graph_module)
163+
164+
if transpose_condition(new_shape):
165+
self.insert_output_transpose(node, graph_module)
113166

114167
def call(self, graph_module: torch.fx.GraphModule):
115168
for node in graph_module.graph.nodes:

backends/arm/test/misc/test_debug_feats.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,8 @@ def test_numerical_diff_prints(self):
107107
ArmTester(
108108
model,
109109
example_inputs=model.get_inputs(),
110-
compile_spec=common.get_tosa_compile_spec(),
110+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
111111
)
112-
.quantize()
113112
.export()
114113
.to_edge()
115114
.partition()

backends/arm/test/ops/test_layer_norm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _test_layernorm_tosa_MI_pipeline(
7474
ArmTester(
7575
model=module,
7676
example_inputs=test_data,
77-
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
77+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
7878
)
7979
.export()
8080
.check(["torch.ops.aten.layer_norm.default"])
@@ -93,7 +93,7 @@ def _test_layernorm_tosa_BI_pipeline(
9393
ArmTester(
9494
model=module,
9595
example_inputs=test_data,
96-
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
96+
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
9797
)
9898
.quantize()
9999
.check_not(["torch.ops.aten.layer_norm.default"])
@@ -148,7 +148,8 @@ def test_layer_norm_tosa_BI(
148148
self.LayerNorm(*model_params), (test_data,)
149149
)
150150

151-
@parameterized.expand(test_data_suite)
151+
# Skip tests that require transposes.
152+
@parameterized.expand(test_data_suite[:-2])
152153
def test_layer_norm_u55_BI(
153154
self,
154155
test_name: str,

backends/arm/test/ops/test_squeeze.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def _test_squeeze_tosa_BI_pipeline(
9595
.check_count({export_target: 1})
9696
.to_edge()
9797
.partition()
98-
.dump_artifact()
9998
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
10099
.to_executorch()
101100
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
@@ -156,7 +155,7 @@ def test_squeeze_u85_BI(
156155
test_tensor: torch.Tensor,
157156
):
158157
self._test_squeeze_ethosu_BI_pipeline(
159-
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
158+
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
160159
self.Squeeze(),
161160
(test_tensor,),
162161
"torch.ops.aten.squeeze.default",

backends/arm/test/ops/test_view.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,33 @@
2525
from parameterized import parameterized
2626

2727

28-
class TestSimpleView(unittest.TestCase):
28+
class TestView(unittest.TestCase):
2929
"""Tests the view operation."""
3030

3131
class View(torch.nn.Module):
3232

33-
sizes = [10, 15, 50, 100]
34-
test_parameters = [(torch.ones(n),) for n in sizes]
35-
36-
def forward(self, x: torch.Tensor):
37-
return x.view(-1, 5)
33+
needs_transpose_tests = [
34+
(torch.rand(100), (1, -1, 5, 2)),
35+
(torch.rand(10, 2, 1, 5), (1, -1, 5, 2)),
36+
(torch.rand(1, 2, 1, 9), (3, 1, 3, 2)),
37+
(torch.rand(2, 1, 1, 9), (3, 2, 3, 1)),
38+
(torch.rand(2, 50, 2, 1), (1, 200)),
39+
(torch.rand(2, 5, 2, 3), (1, 15, 4)),
40+
]
41+
42+
no_transpose_tests = [
43+
(torch.rand(2, 1, 1, 9), (3, 1, 3, 2)),
44+
(torch.rand(5, 10, 1, 1), (25, 2, 1, 1)),
45+
(torch.rand(10, 2), (1, 1, 5, 4)),
46+
(torch.rand(10, 10), (5, 1, 5, 4)),
47+
(torch.rand(1, 1, 1, 10), (1, 1, 10, 1)),
48+
(torch.rand(1, 1, 5, 10), (1, 1, 50, 1)),
49+
(torch.rand(5, 10, 1, 1), (1, 25, 2)),
50+
(torch.rand(2, 50, 1, 1), (1, 100)),
51+
]
52+
53+
def forward(self, x: torch.Tensor, new_shape):
54+
return x.view(new_shape)
3855

3956
def _test_view_tosa_MI_pipeline(
4057
self, module: torch.nn.Module, test_data: torch.Tensor
@@ -82,11 +99,7 @@ def _test_view_ethos_BI_pipeline(
8299
):
83100
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
84101
(
85-
ArmTester(
86-
module,
87-
example_inputs=test_data,
88-
compile_spec=common.get_u55_compile_spec(),
89-
)
102+
ArmTester(module, example_inputs=test_data, compile_spec=compile_spec)
90103
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
91104
.export()
92105
.check_count({"torch.ops.aten.view.default": 1})
@@ -110,18 +123,23 @@ def _test_view_u85_BI_pipeline(
110123
common.get_u85_compile_spec(), module, test_data
111124
)
112125

113-
@parameterized.expand(View.test_parameters)
114-
def test_view_tosa_MI(self, test_tensor: torch.Tensor):
115-
self._test_view_tosa_MI_pipeline(self.View(), (test_tensor,))
126+
@parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
127+
def test_view_tosa_MI(self, test_tensor: torch.Tensor, new_shape):
128+
self._test_view_tosa_MI_pipeline(self.View(), (test_tensor, new_shape))
129+
130+
@parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
131+
def test_view_tosa_BI(self, test_tensor: torch.Tensor, new_shape):
132+
self._test_view_tosa_BI_pipeline(self.View(), (test_tensor, new_shape))
116133

117-
@parameterized.expand(View.test_parameters)
118-
def test_view_tosa_BI(self, test_tensor: torch.Tensor):
119-
self._test_view_tosa_BI_pipeline(self.View(), (test_tensor,))
134+
@parameterized.expand(View.no_transpose_tests)
135+
def test_view_u55_BI(self, test_tensor: torch.Tensor, new_shape):
136+
self._test_view_u55_BI_pipeline(self.View(), (test_tensor, new_shape))
120137

121-
@parameterized.expand(View.test_parameters)
122-
def test_view_u55_BI(self, test_tensor: torch.Tensor):
123-
self._test_view_u55_BI_pipeline(self.View(), (test_tensor,))
138+
@parameterized.expand(View.needs_transpose_tests)
139+
@unittest.expectedFailure
140+
def test_view_transpose_u55_BI(self, test_tensor: torch.Tensor, new_shape):
141+
self._test_view_u55_BI_pipeline(self.View(), (test_tensor, new_shape))
124142

125-
@parameterized.expand(View.test_parameters)
126-
def test_view_u85_BI(self, test_tensor: torch.Tensor):
127-
self._test_view_u85_BI_pipeline(self.View(), (test_tensor,))
143+
@parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
144+
def test_view_u85_BI(self, test_tensor: torch.Tensor, new_shape):
145+
self._test_view_u85_BI_pipeline(self.View(), (test_tensor, new_shape))

0 commit comments

Comments
 (0)