Skip to content

Commit b86ae33

Browse files
oscarandersson8218Erik-Lundell
authored andcommitted
Permute permutation vector for op_permute
Permute vector needs to be permuted when dim_order != (0, 1, 2, 3) Change-Id: I2a35c6852376f9a57deeedd4fc38bda870e453a4 Signed-off-by: Oscar Andersson <[email protected]> Signed-off-by: Erik Lundell <[email protected]>
1 parent 6433646 commit b86ae33

File tree

2 files changed

+114
-18
lines changed

2 files changed

+114
-18
lines changed

backends/arm/operators/op_permute.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -18,6 +18,54 @@
1818
from serializer.tosa_serializer import TosaOp
1919

2020

21+
def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor:
22+
"""
23+
Converts a permutation vector of length N to a NxN matrix that describes the same permutation.
24+
for example:
25+
(1,0,2)
26+
->
27+
[0 1 0]
28+
|1 0 0|
29+
[0 0 1]
30+
"""
31+
N = len(permutation_vector)
32+
P = torch.zeros(N, N)
33+
for row_index, col_index in enumerate(permutation_vector):
34+
P[row_index][col_index] = 1
35+
return P
36+
37+
38+
def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]:
39+
"""
40+
Converts a NxN permutation matrix to a permutation vector of length N that describes the same permutation.
41+
[0 1 0]
42+
|1 0 0|
43+
[0 0 1]
44+
->
45+
(1,0,2)
46+
"""
47+
N = len(permutation_matrix)
48+
assert N == len(
49+
permutation_matrix[0]
50+
), f"A permutation matrix must be square, got shape {permutation_matrix.shape}"
51+
52+
p = [0] * N
53+
for row_index, row in enumerate(permutation_matrix):
54+
saw_one = False
55+
for col_index, value in enumerate(row):
56+
if value == 1:
57+
assert (
58+
not saw_one
59+
), f"A permutation matrix can only have one 1 per row, got row {row}."
60+
p[row_index] = col_index
61+
saw_one = True
62+
else:
63+
assert (
64+
value == 0
65+
), f"A permutation matrix only contains 1's and 0's, got value {value}."
66+
return p
67+
68+
2169
@register_node_visitor
2270
class PermuteVisitor(NodeVisitor):
2371
target = "aten.permute_copy.default"
@@ -40,8 +88,33 @@ def define_node(
4088
)
4189
return
4290

91+
# The permutation vector describes a permutation P in default Pytorch dim_order.
92+
# For rank 4, the default dim_order NCHW.
93+
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)
94+
permutation_vector = inputs[1].special
95+
96+
if output.dim_order != tuple(range(len(output.dim_order))):
97+
# the permutation vector can't be used directly if we are not in NCHW dim_order.
98+
# We need to first transform to NCHW, apply P,
99+
# and then transform back to the original dim_order.
100+
# This transformation, S, is also a permutation, with the dim_order as permutation vector.
101+
102+
# To do this, represent P and S with permutation matrices.
103+
# Matrices can handle chained transformations and inversion easily.
104+
S = permutation_vector_to_matrix(output.dim_order)
105+
# The inverse of a permutation matrix is its transpose.
106+
S_inverse = S.transpose(1, 0)
107+
P = permutation_vector_to_matrix(permutation_vector)
108+
109+
# The complete transformation is S * P * S_inverse.
110+
transformation_matrix = S.matmul(P.matmul(S_inverse))
111+
112+
# Luckily, since it is just a combination of permutations, the result is also a permutation
113+
# that can again be described by a new permutation vector.
114+
permutation_vector = permutation_matrix_to_vector(transformation_matrix)
115+
43116
attr = ts.TosaSerializerAttribute()
44-
attr.TransposeAttribute(inputs[1].special)
117+
attr.TransposeAttribute(permutation_vector)
45118
tosa_graph.addOperator(
46119
TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
47120
)

backends/arm/test/ops/test_permute.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,18 @@
1818
from executorch.backends.arm.test import common
1919
from executorch.backends.arm.test.tester.arm_tester import ArmTester
2020
from executorch.backends.xnnpack.test.tester.tester import Quantize
21+
from executorch.exir.backend.compile_spec_schema import CompileSpec
2122
from parameterized import parameterized
2223
from torchvision.ops import Permute
2324

2425
test_data_suite = [
2526
# (test_name,test_data,dims)
26-
("zeros", torch.zeros(10, 10, 10, 10), [1, 0, 3, 2]),
27-
("ones", torch.ones(10, 10, 10, 10), [3, 1, 0, 2]),
28-
("rand", torch.rand(10, 10, 10, 10) - 0.5, [0, 2, 3, 1]),
29-
("randn_pos", torch.randn(10, 10, 10) + 10, [2, 0, 1]),
30-
("randn_neg", torch.randn(10, 10, 10) - 10, [1, 2, 0]),
31-
("ramp", torch.arange(-16, 16, 0.2), [0]),
27+
("rank_2", torch.rand(10, 10), [1, 0]),
28+
("rank_3", torch.rand(10, 10, 10), [2, 0, 1]),
29+
("rank_3", torch.rand(10, 10, 10), [1, 2, 0]),
30+
("rank_4", torch.rand(1, 5, 1, 10), [0, 2, 3, 1]),
31+
("rank_4", torch.rand(1, 2, 5, 10), [1, 0, 2, 3]),
32+
("rank_4", torch.rand(1, 10, 10, 5), [2, 0, 1, 3]),
3233
]
3334

3435

@@ -46,13 +47,18 @@ def forward(self, x):
4647
return self.permute(x)
4748

4849
def _test_permute_tosa_MI_pipeline(
49-
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
50+
self,
51+
module: torch.nn.Module,
52+
test_data: Tuple[torch.tensor],
53+
permute_memory_to_nhwc: bool,
5054
):
5155
(
5256
ArmTester(
5357
module,
5458
example_inputs=test_data,
55-
compile_spec=common.get_tosa_compile_spec(),
59+
compile_spec=common.get_tosa_compile_spec(
60+
permute_memory_to_nhwc=permute_memory_to_nhwc
61+
),
5662
)
5763
.export()
5864
.check(["torch.ops.aten.permute.default"])
@@ -87,15 +93,18 @@ def _test_permute_tosa_BI_pipeline(
8793
.run_method_and_compare_outputs(inputs=test_data)
8894
)
8995

90-
def _test_permute_tosa_u55_BI_pipeline(
91-
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
96+
def _test_permute_ethos_BI_pipeline(
97+
self,
98+
module: torch.nn.Module,
99+
compile_spec: CompileSpec,
100+
test_data: Tuple[torch.Tensor],
92101
):
93102
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
94103
(
95104
ArmTester(
96105
module,
97106
example_inputs=test_data,
98-
compile_spec=common.get_u55_compile_spec(),
107+
compile_spec=compile_spec,
99108
)
100109
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
101110
.export()
@@ -106,24 +115,38 @@ def _test_permute_tosa_u55_BI_pipeline(
106115
.check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
107116
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
108117
.to_executorch()
118+
.serialize()
109119
)
110120

111121
@parameterized.expand(test_data_suite)
112122
def test_permute_tosa_MI(
113123
self, test_name: str, test_data: torch.Tensor, dims: list[int]
114124
):
115-
self._test_permute_tosa_MI_pipeline(self.Permute(dims=dims), (test_data,))
125+
self._test_permute_tosa_MI_pipeline(self.Permute(dims=dims), (test_data,), True)
126+
self._test_permute_tosa_MI_pipeline(
127+
self.Permute(dims=dims), (test_data,), False
128+
)
116129

117130
@parameterized.expand(test_data_suite)
118131
def test_permute_tosa_BI(
119132
self, test_name: str, test_data: torch.Tensor, dims: list[int]
120133
):
121134
self._test_permute_tosa_BI_pipeline(self.Permute(dims=dims), (test_data,))
122135

123-
# Expected to fail as Permute is not supported by the NPU
124-
@parameterized.expand(test_data_suite)
136+
# Expected to fail as TOSA.Transpose is not supported by Ethos-U55.
137+
@parameterized.expand(test_data_suite[0:1])
125138
@unittest.expectedFailure
126-
def test_permute_tosa_u55_BI(
139+
def test_permute_u55_BI(
127140
self, test_name: str, test_data: torch.Tensor, dims: list[int]
128141
):
129-
self._test_permute_tosa_u55_BI_pipeline(self.Permute(dims=dims), (test_data,))
142+
self._test_permute_ethos_BI_pipeline(
143+
self.Permute(dims=dims), common.get_u55_compile_spec(), (test_data,)
144+
)
145+
146+
@parameterized.expand(test_data_suite)
147+
def test_permute_u85_BI(
148+
self, test_name: str, test_data: torch.Tensor, dims: list[int]
149+
):
150+
self._test_permute_ethos_BI_pipeline(
151+
self.Permute(dims=dims), common.get_u85_compile_spec(), (test_data,)
152+
)

0 commit comments

Comments
 (0)