Skip to content

Commit f6778d5

Browse files
authored
Fix dim order in op_permute
Differential Revision: D64765057 Pull Request resolved: #6432
1 parent 169ddbf commit f6778d5

File tree

3 files changed

+352
-2
lines changed

3 files changed

+352
-2
lines changed

backends/arm/operators/op_permute.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -18,6 +18,54 @@
1818
from serializer.tosa_serializer import TosaOp
1919

2020

21+
def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor:
22+
"""
23+
Converts a permutation vector of length N to a NxN matrix that describes the same permutation.
24+
for example:
25+
(1,0,2)
26+
->
27+
[0 1 0]
28+
|1 0 0|
29+
[0 0 1]
30+
"""
31+
N = len(permutation_vector)
32+
P = torch.zeros(N, N)
33+
for row_index, col_index in enumerate(permutation_vector):
34+
P[row_index][col_index] = 1
35+
return P
36+
37+
38+
def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]:
39+
"""
40+
Converts a NxN permutation matrix to a permutation vector of length N that describes the same permutation.
41+
[0 1 0]
42+
|1 0 0|
43+
[0 0 1]
44+
->
45+
(1,0,2)
46+
"""
47+
N = len(permutation_matrix)
48+
assert N == len(
49+
permutation_matrix[0]
50+
), f"A permutation matrix must be square, got shape {permutation_matrix.shape}"
51+
52+
p = [0] * N
53+
for row_index, row in enumerate(permutation_matrix):
54+
saw_one = False
55+
for col_index, value in enumerate(row):
56+
if value == 1:
57+
assert (
58+
not saw_one
59+
), f"A permutation matrix can only have one 1 per row, got row {row}."
60+
p[row_index] = col_index
61+
saw_one = True
62+
else:
63+
assert (
64+
value == 0
65+
), f"A permutation matrix only contains 1's and 0's, got value {value}."
66+
return p
67+
68+
2169
@register_node_visitor
2270
class PermuteVisitor(NodeVisitor):
2371
target = "aten.permute_copy.default"
@@ -40,8 +88,33 @@ def define_node(
4088
)
4189
return
4290

91+
# The permutation vector describes a permutation P in default Pytorch dim_order.
92+
# For rank 4, the default dim_order NCHW.
93+
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)
94+
permutation_vector = inputs[1].special
95+
96+
if output.dim_order != tuple(range(len(output.dim_order))):
97+
# the permutation vector can't be used directly if we are not in NCHW dim_order.
98+
# We need to first transform to NCHW, apply P,
99+
# and then transform back to the original dim_order.
100+
# This transformation, S, is also a permutation, with the dim_order as permutation vector.
101+
102+
# To do this, represent P and S with permutation matrices.
103+
# Matrices can handle chained transformations and inversion easily.
104+
S = permutation_vector_to_matrix(output.dim_order)
105+
# The inverse of a permutation matrix is its transpose.
106+
S_inverse = S.transpose(1, 0)
107+
P = permutation_vector_to_matrix(permutation_vector)
108+
109+
# The complete transformation is S * P * S_inverse.
110+
transformation_matrix = S.matmul(P.matmul(S_inverse))
111+
112+
# Luckily, since it is just a combination of permutations, the result is also a permutation
113+
# that can again be described by a new permutation vector.
114+
permutation_vector = permutation_matrix_to_vector(transformation_matrix)
115+
43116
attr = ts.TosaSerializerAttribute()
44-
attr.TransposeAttribute(inputs[1].special)
117+
attr.TransposeAttribute(permutation_vector)
45118
tosa_graph.addOperator(
46119
TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
47120
)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
from typing import Tuple
10+
11+
import torch
12+
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
ArmQuantizer,
15+
get_symmetric_quantization_config,
16+
)
17+
18+
from executorch.backends.arm.test import common
19+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
20+
from executorch.backends.xnnpack.test.tester.tester import Quantize
21+
from parameterized import parameterized
22+
23+
24+
test_data_suite = [
25+
# (test_name, test_data)
26+
("zeros", torch.zeros(1, 10, 10, 10)),
27+
("ones", torch.ones(10, 10, 10)),
28+
("rand", torch.rand(10, 10) - 0.5),
29+
("randn_pos", torch.randn(10) + 10),
30+
("randn_neg", torch.randn(10) - 10),
31+
("ramp", torch.arange(-16, 16, 0.2)),
32+
]
33+
34+
35+
class TestHardTanh(unittest.TestCase):
36+
"""Tests HardTanh Operator."""
37+
38+
class HardTanh(torch.nn.Module):
39+
40+
def __init__(self):
41+
super().__init__()
42+
43+
self.hardTanh = torch.nn.Hardtanh()
44+
45+
def forward(self, x):
46+
return self.hardTanh(x)
47+
48+
def _test_hardtanh_tosa_MI_pipeline(
49+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
50+
):
51+
(
52+
ArmTester(
53+
module,
54+
example_inputs=test_data,
55+
compile_spec=common.get_tosa_compile_spec(),
56+
)
57+
.export()
58+
.check(["torch.ops.aten.hardtanh.default"])
59+
.check_not(["torch.ops.quantized_decomposed"])
60+
.to_edge()
61+
.partition()
62+
.check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
63+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
64+
.to_executorch()
65+
.run_method_and_compare_outputs(inputs=test_data)
66+
)
67+
68+
def _test_hardtanh_tosa_BI_pipeline(
69+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
70+
):
71+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
72+
(
73+
ArmTester(
74+
module,
75+
example_inputs=test_data,
76+
compile_spec=common.get_tosa_compile_spec(),
77+
)
78+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
79+
.export()
80+
.check_count({"torch.ops.aten.hardtanh.default": 1})
81+
.check(["torch.ops.quantized_decomposed"])
82+
.to_edge()
83+
.partition()
84+
.check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
85+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
86+
.to_executorch()
87+
.run_method_and_compare_outputs(inputs=test_data)
88+
)
89+
90+
def _test_hardtanh_tosa_u55_BI_pipeline(
91+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
92+
):
93+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
94+
(
95+
ArmTester(
96+
module,
97+
example_inputs=test_data,
98+
compile_spec=common.get_u55_compile_spec(),
99+
)
100+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
101+
.export()
102+
.check_count({"torch.ops.aten.hardtanh.default": 1})
103+
.check(["torch.ops.quantized_decomposed"])
104+
.to_edge()
105+
.partition()
106+
.check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
107+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
108+
.to_executorch()
109+
)
110+
111+
@parameterized.expand(test_data_suite)
112+
def test_hardtanh_tosa_MI(
113+
self,
114+
test_name: str,
115+
test_data: torch.Tensor,
116+
):
117+
self._test_hardtanh_tosa_MI_pipeline(self.HardTanh(), (test_data,))
118+
119+
@parameterized.expand(test_data_suite)
120+
def test_hardtanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
121+
self._test_hardtanh_tosa_BI_pipeline(self.HardTanh(), (test_data,))
122+
123+
@parameterized.expand(test_data_suite)
124+
def test_hardtanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
125+
self._test_hardtanh_tosa_u55_BI_pipeline(self.HardTanh(), (test_data,))

backends/arm/test/ops/test_permute.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
from typing import Tuple
10+
11+
import torch
12+
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
ArmQuantizer,
15+
get_symmetric_quantization_config,
16+
)
17+
18+
from executorch.backends.arm.test import common
19+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
20+
from executorch.backends.xnnpack.test.tester.tester import Quantize
21+
from executorch.exir.backend.compile_spec_schema import CompileSpec
22+
from parameterized import parameterized
23+
from torchvision.ops import Permute
24+
25+
test_data_suite = [
26+
# (test_name,test_data,dims)
27+
("rank_2", torch.rand(10, 10), [1, 0]),
28+
("rank_3", torch.rand(10, 10, 10), [2, 0, 1]),
29+
("rank_3", torch.rand(10, 10, 10), [1, 2, 0]),
30+
("rank_4", torch.rand(1, 5, 1, 10), [0, 2, 3, 1]),
31+
("rank_4", torch.rand(1, 2, 5, 10), [1, 0, 2, 3]),
32+
("rank_4", torch.rand(1, 10, 10, 5), [2, 0, 1, 3]),
33+
]
34+
35+
36+
class TestPermute(unittest.TestCase):
37+
"""Tests Permute Operator."""
38+
39+
class Permute(torch.nn.Module):
40+
41+
def __init__(self, dims: list[int]):
42+
super().__init__()
43+
44+
self.permute = Permute(dims=dims)
45+
46+
def forward(self, x):
47+
return self.permute(x)
48+
49+
def _test_permute_tosa_MI_pipeline(
50+
self,
51+
module: torch.nn.Module,
52+
test_data: Tuple[torch.tensor],
53+
permute_memory_to_nhwc: bool,
54+
):
55+
(
56+
ArmTester(
57+
module,
58+
example_inputs=test_data,
59+
compile_spec=common.get_tosa_compile_spec(
60+
permute_memory_to_nhwc=permute_memory_to_nhwc
61+
),
62+
)
63+
.export()
64+
.check(["torch.ops.aten.permute.default"])
65+
.check_not(["torch.ops.quantized_decomposed"])
66+
.to_edge()
67+
.partition()
68+
.check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
69+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
70+
.to_executorch()
71+
.run_method_and_compare_outputs(inputs=test_data)
72+
)
73+
74+
def _test_permute_tosa_BI_pipeline(
75+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
76+
):
77+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
78+
(
79+
ArmTester(
80+
module,
81+
example_inputs=test_data,
82+
compile_spec=common.get_tosa_compile_spec(),
83+
)
84+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
85+
.export()
86+
.check_count({"torch.ops.aten.permute.default": 1})
87+
.check(["torch.ops.quantized_decomposed"])
88+
.to_edge()
89+
.partition()
90+
.check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
91+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
92+
.to_executorch()
93+
.run_method_and_compare_outputs(inputs=test_data)
94+
)
95+
96+
def _test_permute_ethos_BI_pipeline(
97+
self,
98+
module: torch.nn.Module,
99+
compile_spec: CompileSpec,
100+
test_data: Tuple[torch.Tensor],
101+
):
102+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
103+
(
104+
ArmTester(
105+
module,
106+
example_inputs=test_data,
107+
compile_spec=compile_spec,
108+
)
109+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
110+
.export()
111+
.check_count({"torch.ops.aten.permute.default": 1})
112+
.check(["torch.ops.quantized_decomposed"])
113+
.to_edge()
114+
.partition()
115+
.check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
116+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
117+
.to_executorch()
118+
.serialize()
119+
)
120+
121+
@parameterized.expand(test_data_suite)
122+
def test_permute_tosa_MI(
123+
self, test_name: str, test_data: torch.Tensor, dims: list[int]
124+
):
125+
self._test_permute_tosa_MI_pipeline(self.Permute(dims=dims), (test_data,), True)
126+
self._test_permute_tosa_MI_pipeline(
127+
self.Permute(dims=dims), (test_data,), False
128+
)
129+
130+
@parameterized.expand(test_data_suite)
131+
def test_permute_tosa_BI(
132+
self, test_name: str, test_data: torch.Tensor, dims: list[int]
133+
):
134+
self._test_permute_tosa_BI_pipeline(self.Permute(dims=dims), (test_data,))
135+
136+
# Expected to fail as TOSA.Transpose is not supported by Ethos-U55.
137+
@parameterized.expand(test_data_suite[0:1])
138+
@unittest.expectedFailure
139+
def test_permute_u55_BI(
140+
self, test_name: str, test_data: torch.Tensor, dims: list[int]
141+
):
142+
self._test_permute_ethos_BI_pipeline(
143+
self.Permute(dims=dims), common.get_u55_compile_spec(), (test_data,)
144+
)
145+
146+
@parameterized.expand(test_data_suite)
147+
def test_permute_u85_BI(
148+
self, test_name: str, test_data: torch.Tensor, dims: list[int]
149+
):
150+
self._test_permute_ethos_BI_pipeline(
151+
self.Permute(dims=dims), common.get_u85_compile_spec(), (test_data,)
152+
)

0 commit comments

Comments
 (0)