Skip to content

Commit c7a3c3f

Browse files
Arm backend: Insert transposes between <4D and 4D tensors
Differential Revision: D64472653 Pull Request resolved: #6045
1 parent 01d8783 commit c7a3c3f

File tree

8 files changed

+136
-29
lines changed

8 files changed

+136
-29
lines changed

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,50 @@
99
from typing import cast
1010

1111
import torch
12-
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
12+
from executorch.backends.arm._passes.arm_pass_utils import (
13+
create_node,
14+
get_first_fake_tensor,
15+
)
1316
from executorch.backends.arm.tosa_quant_utils import dq_op
1417
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
18+
from executorch.exir.dialects._ops import ops as exir_ops
1519
from executorch.exir.pass_base import ExportPass, PassResult
20+
from torch.library import impl, Library
21+
22+
# Define lib with passthrough operators. The operators have no real meaning in edge IR
23+
# except for argument validaiton and a passthrough output. The operators will be used
24+
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
25+
# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
26+
lib = Library("passthrough_to_tosa", "DEF")
27+
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
28+
# to switch dim_order before the opertation. Changing tosa_dim_order is not sufficient
29+
# as we also need transpose the data into the correct data format.
30+
# By utilizing an edge IR passthrough operator we can keep the edge program in
31+
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
32+
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")
33+
34+
35+
@impl(lib, "_transpose")
36+
def _transpose_impl(*args, **kwargs):
37+
# Validate length of dim_order array
38+
dim = args[1]
39+
assert len(dim) <= 4
40+
# Pass-through in edge-IR
41+
return args[0]
1642

1743

1844
class AnnotateChannelsLastDimOrder(ExportPass):
1945
"""
2046
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
21-
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes.
22-
The annotated tosa_dim_order is used to permute the node's shape such that it
23-
gives a TOSA-compliant shape.
47+
that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
48+
when a transition between 3D and 4D tensors happen.
49+
The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
2450
"""
2551

52+
NHWC_order = (0, 2, 3, 1)
53+
NHWC_inverse_order = (0, 3, 1, 2)
54+
HWCM_order = (2, 3, 0, 1)
55+
2656
def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
2757
"""
2858
returns True for dq and w in the following sequences;
@@ -49,20 +79,56 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
4979

5080
return False
5181

82+
def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
83+
for node in graph_module.graph.nodes:
84+
if node.op != "call_function":
85+
continue
86+
if node.target == exir_ops.edge.aten.squeeze_copy.dims:
87+
input_node = node.args[0]
88+
if input_node.meta["val"].dim() == 4:
89+
with graph_module.graph.inserting_before(node):
90+
permute_node = create_node(
91+
graph_module.graph,
92+
torch.ops.passthrough_to_tosa._transpose,
93+
args=(input_node, list(self.NHWC_inverse_order)),
94+
)
95+
permute_node.meta["tosa_dim_order"] = tuple(
96+
range(len(input_node.meta["val"].size()))
97+
)
98+
node.replace_input_with(input_node, permute_node)
99+
100+
if node.target == exir_ops.edge.aten.unsqueeze_copy.default:
101+
if node.meta["val"].dim() == 4:
102+
with graph_module.graph.inserting_after(node):
103+
permute_node = create_node(
104+
graph_module.graph,
105+
torch.ops.passthrough_to_tosa._transpose,
106+
args=(node, list(self.NHWC_order)),
107+
)
108+
permute_node.meta["tosa_dim_order"] = self.NHWC_order
109+
node.meta["tosa_dim_order"] = (0, 1, 2, 3)
110+
users = [user for user in node.users if user != permute_node]
111+
for user in users:
112+
user.replace_input_with(node, permute_node)
113+
52114
def call(self, graph_module: torch.fx.GraphModule):
53-
NHWC_Order = (0, 2, 3, 1)
54-
HWCM_Order = (2, 3, 0, 1)
55115
for node in graph_module.graph.nodes:
56116
node_data = get_first_fake_tensor(node).data
57117

58-
if len(node_data.shape) == 4:
59-
dim_order = NHWC_Order
118+
if node_data.dim() == 4:
119+
dim_order = self.NHWC_order
60120
if self.is_weight_node_for_depthwise_conv2d(node):
61121
# The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
62122
# dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
63-
dim_order = HWCM_Order
123+
dim_order = self.HWCM_order
64124
else:
65125
dim_order = tuple(range(node_data.dim()))
66126
node.meta["tosa_dim_order"] = dim_order
127+
# Take care of cases when:
128+
# 4D (NHWC) -> >4D (NCH)
129+
# 3D (NCH) -> 4D (NHWC)
130+
self.insert_tosa_transposes(graph_module)
67131
graph_module.recompile()
132+
graph_module = super().call(graph_module).graph_module
133+
68134
return PassResult(graph_module, True)

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
op_squeeze,
3535
op_sub,
3636
op_sum,
37+
op_transpose,
3738
op_unsqueeze,
3839
op_view,
3940
)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import List
7+
8+
import serializer.tosa_serializer as ts
9+
import torch
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from serializer.tosa_serializer import TosaOp
16+
17+
18+
@register_node_visitor
19+
class TransposeVisitor(NodeVisitor):
20+
"""
21+
This node visitor targets the _transpose op defined in the
22+
passthrough_to_tosa library. Used when switching between tosa_dim_orders.
23+
Inserts a TOSA TRANSPOSE.
24+
"""
25+
26+
target = "_transpose"
27+
28+
def define_node(
29+
self,
30+
node: torch.fx.Node,
31+
tosa_graph: ts.TosaSerializer,
32+
inputs: List[TosaArg],
33+
output: TosaArg,
34+
is_quant_node: bool,
35+
) -> None:
36+
output_rank = len(output.shape)
37+
perms = [dim % output_rank for dim in inputs[1].special]
38+
attr = ts.TosaSerializerAttribute()
39+
attr.TransposeAttribute(perms)
40+
tosa_graph.addOperator(
41+
TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
42+
)

backends/arm/test/ops/test_expand.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from executorch.backends.arm.test.tester.arm_tester import ArmTester
2222

2323
from executorch.backends.xnnpack.test.tester.tester import Quantize
24+
from executorch.exir.backend.backend_details import CompileSpec
2425
from parameterized import parameterized
2526

2627

@@ -77,14 +78,14 @@ def _test_expand_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tupl
7778
)
7879

7980
def _test_expand_ethosu_BI_pipeline(
80-
self, module: torch.nn.Module, test_data: Tuple
81+
self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: Tuple
8182
):
8283
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
8384
(
8485
ArmTester(
8586
module,
8687
example_inputs=test_data,
87-
compile_spec=common.get_u55_compile_spec(),
88+
compile_spec=compile_spec,
8889
)
8990
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
9091
.export()
@@ -104,17 +105,14 @@ def test_expand_tosa_MI(self, test_input, multiples):
104105
def test_expand_tosa_BI(self, test_input, multiples):
105106
self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))
106107

107-
# Expected failure since tosa.TILE is unsupported by Vela.
108108
@parameterized.expand(Expand.test_parameters)
109-
@unittest.expectedFailure # TODO: MLBEDSW-9386
110109
def test_expand_u55_BI(self, test_input, multiples):
111110
self._test_expand_ethosu_BI_pipeline(
112-
self.Expand(), common.get_u55_compile_spec(), (test_input, multiples)
111+
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
113112
)
114113

115114
@parameterized.expand(Expand.test_parameters)
116-
@unittest.expectedFailure # TODO: MLBEDSW-9386
117115
def test_expand_u85_BI(self, test_input, multiples):
118116
self._test_expand_ethosu_BI_pipeline(
119-
self.Expand(), common.get_u85_compile_spec(), (test_input, multiples)
117+
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
120118
)

backends/arm/test/ops/test_repeat.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,12 @@ def test_repeat_tosa_BI(self, test_input, multiples):
107107
self._test_repeat_tosa_BI_pipeline(self.Repeat(), (test_input, multiples))
108108

109109
@parameterized.expand(Repeat.test_parameters)
110-
@unittest.expectedFailure # TODO: MLBEDSW-9386
111110
def test_repeat_u55_BI(self, test_input, multiples):
112111
self._test_repeat_ethosu_pipeline(
113112
common.get_u55_compile_spec(), self.Repeat(), (test_input, multiples)
114113
)
115114

116115
@parameterized.expand(Repeat.test_parameters)
117-
@unittest.expectedFailure # TODO: MLBEDSW-9386
118116
def test_repeat_u85_BI(self, test_input, multiples):
119117
self._test_repeat_ethosu_pipeline(
120118
common.get_u85_compile_spec(), self.Repeat(), (test_input, multiples)

backends/arm/test/ops/test_squeeze.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def forward(self, x: torch.Tensor, dim: int):
3838

3939
class SqueezeDims(torch.nn.Module):
4040
test_parameters: list[tuple[torch.Tensor, tuple[int]]] = [
41+
(torch.randn(1, 1, 5), (0, 1)),
4142
(torch.randn(1, 5, 5, 1), (0, -1)),
4243
(torch.randn(1, 5, 1, 5), (0, -2)),
4344
]
@@ -47,6 +48,7 @@ def forward(self, x: torch.Tensor, dims: tuple[int]):
4748

4849
class Squeeze(torch.nn.Module):
4950
test_parameters: list[tuple[torch.Tensor]] = [
51+
(torch.randn(1, 1, 5),),
5052
(torch.randn(1, 5, 5, 1),),
5153
(torch.randn(1, 5, 1, 5),),
5254
]
@@ -64,7 +66,7 @@ def _test_squeeze_tosa_MI_pipeline(
6466
ArmTester(
6567
module,
6668
example_inputs=test_data,
67-
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
69+
compile_spec=common.get_tosa_compile_spec(),
6870
)
6971
.export()
7072
.check_count({export_target: 1})
@@ -86,7 +88,7 @@ def _test_squeeze_tosa_BI_pipeline(
8688
ArmTester(
8789
module,
8890
example_inputs=test_data,
89-
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
91+
compile_spec=common.get_tosa_compile_spec(),
9092
)
9193
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
9294
.export()
@@ -184,7 +186,7 @@ def test_squeeze_dim_u55_BI(self, test_tensor: torch.Tensor, dim: int):
184186
@parameterized.expand(SqueezeDim.test_parameters)
185187
def test_squeeze_dim_u85_BI(self, test_tensor: torch.Tensor, dim: int):
186188
self._test_squeeze_ethosu_BI_pipeline(
187-
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
189+
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
188190
self.SqueezeDim(),
189191
(test_tensor, dim),
190192
"torch.ops.aten.squeeze.dim",
@@ -214,7 +216,7 @@ def test_squeeze_dims_u55_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
214216
@parameterized.expand(SqueezeDims.test_parameters)
215217
def test_squeeze_dims_u85_BI(self, test_tensor: torch.Tensor, dims: tuple[int]):
216218
self._test_squeeze_ethosu_BI_pipeline(
217-
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
219+
common.get_u85_compile_spec(),
218220
self.SqueezeDims(),
219221
(test_tensor, dims),
220222
"torch.ops.aten.squeeze.dims",

backends/arm/test/ops/test_unsqueeze.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
class TestSimpleUnsqueeze(unittest.TestCase):
2929
class Unsqueeze(torch.nn.Module):
30-
shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 5), (5, 4, 3)]
30+
shapes: list[int | Sequence[int]] = [5, (5, 5), (5, 4), (5, 4, 3)]
3131
test_parameters: list[tuple[torch.Tensor]] = [(torch.randn(n),) for n in shapes]
3232

3333
def forward(self, x: torch.Tensor, dim):
@@ -40,7 +40,7 @@ def _test_unsqueeze_tosa_MI_pipeline(
4040
ArmTester(
4141
module,
4242
example_inputs=test_data,
43-
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
43+
compile_spec=common.get_tosa_compile_spec(),
4444
)
4545
.export()
4646
.check_count({"torch.ops.aten.unsqueeze.default": 1})
@@ -59,7 +59,7 @@ def _test_unsqueeze_tosa_BI_pipeline(
5959
ArmTester(
6060
module,
6161
example_inputs=test_data,
62-
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=False),
62+
compile_spec=common.get_tosa_compile_spec(),
6363
)
6464
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
6565
.export()
@@ -102,18 +102,18 @@ def test_unsqueeze_tosa_MI(self, test_tensor: torch.Tensor):
102102
def test_unsqueeze_tosa_BI(self, test_tensor: torch.Tensor):
103103
self._test_unsqueeze_tosa_BI_pipeline(self.Unsqueeze(), (test_tensor, 0))
104104

105-
@parameterized.expand(Unsqueeze.test_parameters)
105+
@parameterized.expand(Unsqueeze.test_parameters[:-1])
106106
def test_unsqueeze_u55_BI(self, test_tensor: torch.Tensor):
107107
self._test_unsqueeze_ethosu_BI_pipeline(
108-
common.get_u55_compile_spec(permute_memory_to_nhwc=False),
108+
common.get_u55_compile_spec(),
109109
self.Unsqueeze(),
110110
(test_tensor, 0),
111111
)
112112

113113
@parameterized.expand(Unsqueeze.test_parameters)
114114
def test_unsqueeze_u85_BI(self, test_tensor: torch.Tensor):
115115
self._test_unsqueeze_ethosu_BI_pipeline(
116-
common.get_u85_compile_spec(permute_memory_to_nhwc=False),
116+
common.get_u85_compile_spec(),
117117
self.Unsqueeze(),
118118
(test_tensor, 0),
119119
)

examples/arm/setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ function setup_vela() {
261261
if [[ ! -e ethos-u-vela ]]; then
262262
git clone https://review.mlplatform.org/ml/ethos-u/ethos-u-vela
263263
repo_dir="${root_dir}/ethos-u-vela"
264-
base_rev=fe0eaa55c5ed319f78c01978f3b40eb11a9bcb38
264+
base_rev=57ce18c89ccc6f6309333dccb24ed30dc68b571f
265265
patch_repo
266266
fi
267267
cd "${root_dir}/ethos-u-vela"

0 commit comments

Comments
 (0)