Skip to content

Commit 6433646

Browse files
SaoirseARMErik-Lundell
authored andcommitted
Add missing unit tests for operators
* hardtanh * permute Change-Id: Ia1802bdc37d365af382835b3c14174d841892927
1 parent 0aa802d commit 6433646

File tree

2 files changed

+254
-0
lines changed

2 files changed

+254
-0
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
from typing import Tuple
10+
11+
import torch
12+
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
ArmQuantizer,
15+
get_symmetric_quantization_config,
16+
)
17+
18+
from executorch.backends.arm.test import common
19+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
20+
from executorch.backends.xnnpack.test.tester.tester import Quantize
21+
from parameterized import parameterized
22+
23+
24+
test_data_suite = [
25+
# (test_name, test_data)
26+
("zeros", torch.zeros(1, 10, 10, 10)),
27+
("ones", torch.ones(10, 10, 10)),
28+
("rand", torch.rand(10, 10) - 0.5),
29+
("randn_pos", torch.randn(10) + 10),
30+
("randn_neg", torch.randn(10) - 10),
31+
("ramp", torch.arange(-16, 16, 0.2)),
32+
]
33+
34+
35+
class TestHardTanh(unittest.TestCase):
36+
"""Tests HardTanh Operator."""
37+
38+
class HardTanh(torch.nn.Module):
39+
40+
def __init__(self):
41+
super().__init__()
42+
43+
self.hardTanh = torch.nn.Hardtanh()
44+
45+
def forward(self, x):
46+
return self.hardTanh(x)
47+
48+
def _test_hardtanh_tosa_MI_pipeline(
49+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
50+
):
51+
(
52+
ArmTester(
53+
module,
54+
example_inputs=test_data,
55+
compile_spec=common.get_tosa_compile_spec(),
56+
)
57+
.export()
58+
.check(["torch.ops.aten.hardtanh.default"])
59+
.check_not(["torch.ops.quantized_decomposed"])
60+
.to_edge()
61+
.partition()
62+
.check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
63+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
64+
.to_executorch()
65+
.run_method_and_compare_outputs(inputs=test_data)
66+
)
67+
68+
def _test_hardtanh_tosa_BI_pipeline(
69+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
70+
):
71+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
72+
(
73+
ArmTester(
74+
module,
75+
example_inputs=test_data,
76+
compile_spec=common.get_tosa_compile_spec(),
77+
)
78+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
79+
.export()
80+
.check_count({"torch.ops.aten.hardtanh.default": 1})
81+
.check(["torch.ops.quantized_decomposed"])
82+
.to_edge()
83+
.partition()
84+
.check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
85+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
86+
.to_executorch()
87+
.run_method_and_compare_outputs(inputs=test_data)
88+
)
89+
90+
def _test_hardtanh_tosa_u55_BI_pipeline(
91+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
92+
):
93+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
94+
(
95+
ArmTester(
96+
module,
97+
example_inputs=test_data,
98+
compile_spec=common.get_u55_compile_spec(),
99+
)
100+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
101+
.export()
102+
.check_count({"torch.ops.aten.hardtanh.default": 1})
103+
.check(["torch.ops.quantized_decomposed"])
104+
.to_edge()
105+
.partition()
106+
.check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
107+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
108+
.to_executorch()
109+
)
110+
111+
@parameterized.expand(test_data_suite)
112+
def test_hardtanh_tosa_MI(
113+
self,
114+
test_name: str,
115+
test_data: torch.Tensor,
116+
):
117+
self._test_hardtanh_tosa_MI_pipeline(self.HardTanh(), (test_data,))
118+
119+
@parameterized.expand(test_data_suite)
120+
def test_hardtanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
121+
self._test_hardtanh_tosa_BI_pipeline(self.HardTanh(), (test_data,))
122+
123+
@parameterized.expand(test_data_suite)
124+
def test_hardtanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
125+
self._test_hardtanh_tosa_u55_BI_pipeline(self.HardTanh(), (test_data,))

backends/arm/test/ops/test_permute.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
from typing import Tuple
10+
11+
import torch
12+
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
ArmQuantizer,
15+
get_symmetric_quantization_config,
16+
)
17+
18+
from executorch.backends.arm.test import common
19+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
20+
from executorch.backends.xnnpack.test.tester.tester import Quantize
21+
from parameterized import parameterized
22+
from torchvision.ops import Permute
23+
24+
test_data_suite = [
25+
# (test_name,test_data,dims)
26+
("zeros", torch.zeros(10, 10, 10, 10), [1, 0, 3, 2]),
27+
("ones", torch.ones(10, 10, 10, 10), [3, 1, 0, 2]),
28+
("rand", torch.rand(10, 10, 10, 10) - 0.5, [0, 2, 3, 1]),
29+
("randn_pos", torch.randn(10, 10, 10) + 10, [2, 0, 1]),
30+
("randn_neg", torch.randn(10, 10, 10) - 10, [1, 2, 0]),
31+
("ramp", torch.arange(-16, 16, 0.2), [0]),
32+
]
33+
34+
35+
class TestPermute(unittest.TestCase):
36+
"""Tests Permute Operator."""
37+
38+
class Permute(torch.nn.Module):
39+
40+
def __init__(self, dims: list[int]):
41+
super().__init__()
42+
43+
self.permute = Permute(dims=dims)
44+
45+
def forward(self, x):
46+
return self.permute(x)
47+
48+
def _test_permute_tosa_MI_pipeline(
49+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
50+
):
51+
(
52+
ArmTester(
53+
module,
54+
example_inputs=test_data,
55+
compile_spec=common.get_tosa_compile_spec(),
56+
)
57+
.export()
58+
.check(["torch.ops.aten.permute.default"])
59+
.check_not(["torch.ops.quantized_decomposed"])
60+
.to_edge()
61+
.partition()
62+
.check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
63+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
64+
.to_executorch()
65+
.run_method_and_compare_outputs(inputs=test_data)
66+
)
67+
68+
def _test_permute_tosa_BI_pipeline(
69+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
70+
):
71+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
72+
(
73+
ArmTester(
74+
module,
75+
example_inputs=test_data,
76+
compile_spec=common.get_tosa_compile_spec(),
77+
)
78+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
79+
.export()
80+
.check_count({"torch.ops.aten.permute.default": 1})
81+
.check(["torch.ops.quantized_decomposed"])
82+
.to_edge()
83+
.partition()
84+
.check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
85+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
86+
.to_executorch()
87+
.run_method_and_compare_outputs(inputs=test_data)
88+
)
89+
90+
def _test_permute_tosa_u55_BI_pipeline(
91+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
92+
):
93+
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
94+
(
95+
ArmTester(
96+
module,
97+
example_inputs=test_data,
98+
compile_spec=common.get_u55_compile_spec(),
99+
)
100+
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
101+
.export()
102+
.check_count({"torch.ops.aten.permute.default": 1})
103+
.check(["torch.ops.quantized_decomposed"])
104+
.to_edge()
105+
.partition()
106+
.check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
107+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
108+
.to_executorch()
109+
)
110+
111+
@parameterized.expand(test_data_suite)
112+
def test_permute_tosa_MI(
113+
self, test_name: str, test_data: torch.Tensor, dims: list[int]
114+
):
115+
self._test_permute_tosa_MI_pipeline(self.Permute(dims=dims), (test_data,))
116+
117+
@parameterized.expand(test_data_suite)
118+
def test_permute_tosa_BI(
119+
self, test_name: str, test_data: torch.Tensor, dims: list[int]
120+
):
121+
self._test_permute_tosa_BI_pipeline(self.Permute(dims=dims), (test_data,))
122+
123+
# Expected to fail as Permute is not supported by the NPU
124+
@parameterized.expand(test_data_suite)
125+
@unittest.expectedFailure
126+
def test_permute_tosa_u55_BI(
127+
self, test_name: str, test_data: torch.Tensor, dims: list[int]
128+
):
129+
self._test_permute_tosa_u55_BI_pipeline(self.Permute(dims=dims), (test_data,))

0 commit comments

Comments
 (0)