Skip to content

Commit ef2bfcd

Browse files
Arm backend: Refactor gt, ge, lt, le and eq tests to pipeline (#8828)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 3ece593 commit ef2bfcd

File tree

6 files changed

+620
-640
lines changed

6 files changed

+620
-640
lines changed

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ def is_node_supported(
195195
exir_ops.edge.aten.bitwise_xor.Tensor,
196196
exir_ops.edge.aten.amax.default,
197197
exir_ops.edge.aten.amin.default,
198+
exir_ops.edge.aten.eq.Tensor,
199+
exir_ops.edge.aten.ge.Tensor,
200+
exir_ops.edge.aten.gt.Tensor,
201+
exir_ops.edge.aten.le.Tensor,
202+
exir_ops.edge.aten.lt.Tensor,
198203
]
199204

200205
if node.target in unsupported_ops:

backends/arm/test/ops/test_eq.py

Lines changed: 123 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -1,145 +1,136 @@
11
# Copyright 2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

7-
import unittest
6+
from typing import Tuple
87

8+
import pytest
99
import torch
1010
from executorch.backends.arm.test import common
11-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
12-
from executorch.exir.backend.compile_spec_schema import CompileSpec
13-
from parameterized import parameterized
14-
15-
test_data_suite = [
16-
# (test_name, input, other,) See torch.eq() for info
17-
(
18-
"op_eq_rank1_ones",
19-
torch.ones(5),
20-
torch.ones(5),
21-
),
22-
(
23-
"op_eq_rank2_rand",
24-
torch.rand(4, 5),
25-
torch.rand(1, 5),
26-
),
27-
(
28-
"op_eq_rank3_randn",
29-
torch.randn(10, 5, 2),
30-
torch.randn(10, 5, 2),
31-
),
32-
(
33-
"op_eq_rank4_randn",
34-
torch.randn(3, 2, 2, 2),
35-
torch.randn(3, 2, 2, 2),
36-
),
37-
]
38-
39-
40-
class TestEqual(unittest.TestCase):
41-
class Equal(torch.nn.Module):
42-
def forward(
43-
self,
44-
input_: torch.Tensor,
45-
other_: torch.Tensor,
46-
):
47-
return input_ == other_
48-
49-
def _test_eq_tosa_MI_pipeline(
50-
self,
51-
compile_spec: list[CompileSpec],
52-
module: torch.nn.Module,
53-
test_data: tuple[torch.Tensor, torch.Tensor],
54-
):
55-
(
56-
ArmTester(
57-
module,
58-
example_inputs=test_data,
59-
compile_spec=compile_spec,
60-
)
61-
.export()
62-
.check_count({"torch.ops.aten.eq.Tensor": 1})
63-
.to_edge()
64-
.partition()
65-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
66-
.to_executorch()
67-
.run_method_and_compare_outputs(inputs=test_data)
68-
)
69-
70-
def _test_eq_tosa_BI_pipeline(
71-
self,
72-
compile_spec: list[CompileSpec],
73-
module: torch.nn.Module,
74-
test_data: tuple[torch.Tensor, torch.Tensor],
75-
):
76-
(
77-
ArmTester(
78-
module,
79-
example_inputs=test_data,
80-
compile_spec=compile_spec,
81-
)
82-
.quantize()
83-
.export()
84-
.check_count({"torch.ops.aten.eq.Tensor": 1})
85-
.check(["torch.ops.quantized_decomposed"])
86-
.to_edge()
87-
.partition()
88-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
89-
.to_executorch()
90-
.run_method_and_compare_outputs(inputs=test_data)
91-
)
92-
93-
@parameterized.expand(test_data_suite)
94-
def test_eq_tosa_MI(
95-
self,
96-
test_name: str,
97-
input_: torch.Tensor,
98-
other_: torch.Tensor,
99-
):
100-
test_data = (input_, other_)
101-
self._test_eq_tosa_MI_pipeline(
102-
common.get_tosa_compile_spec("TOSA-0.80+MI"), self.Equal(), test_data
103-
)
10411

105-
@parameterized.expand(test_data_suite)
106-
def test_eq_tosa_BI(
107-
self,
108-
test_name: str,
109-
input_: torch.Tensor,
110-
other_: torch.Tensor,
111-
):
112-
test_data = (input_, other_)
113-
self._test_eq_tosa_BI_pipeline(
114-
common.get_tosa_compile_spec("TOSA-0.80+BI"), self.Equal(), test_data
115-
)
116-
117-
@parameterized.expand(test_data_suite)
118-
@unittest.skip
119-
def test_eq_u55_BI(
120-
self,
121-
test_name: str,
122-
input_: torch.Tensor,
123-
other_: torch.Tensor,
124-
):
125-
test_data = (input_, other_)
126-
self._test_eq_tosa_BI_pipeline(
127-
common.get_u55_compile_spec(permute_memory_to_nhwc=True),
128-
self.Equal(),
129-
test_data,
130-
)
131-
132-
@parameterized.expand(test_data_suite)
133-
@unittest.skip
134-
def test_eq_u85_BI(
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU85PipelineBI,
14+
OpNotSupportedPipeline,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
aten_op = "torch.ops.aten.eq.Tensor"
20+
exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
21+
22+
input_t = Tuple[torch.Tensor]
23+
24+
25+
class Equal(torch.nn.Module):
26+
def __init__(self, input, other):
27+
super().__init__()
28+
self.input_ = input
29+
self.other_ = other
30+
31+
def forward(
13532
self,
136-
test_name: str,
13733
input_: torch.Tensor,
13834
other_: torch.Tensor,
13935
):
140-
test_data = (input_, other_)
141-
self._test_eq_tosa_BI_pipeline(
142-
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
143-
self.Equal(),
144-
test_data,
145-
)
36+
return input_ == other_
37+
38+
def get_inputs(self):
39+
return (self.input_, self.other_)
40+
41+
42+
op_eq_rank1_ones = Equal(
43+
torch.ones(5),
44+
torch.ones(5),
45+
)
46+
op_eq_rank2_rand = Equal(
47+
torch.rand(4, 5),
48+
torch.rand(1, 5),
49+
)
50+
op_eq_rank3_randn = Equal(
51+
torch.randn(10, 5, 2),
52+
torch.randn(10, 5, 2),
53+
)
54+
op_eq_rank4_randn = Equal(
55+
torch.randn(3, 2, 2, 2),
56+
torch.randn(3, 2, 2, 2),
57+
)
58+
59+
test_data_common = {
60+
"eq_rank1_ones": op_eq_rank1_ones,
61+
"eq_rank2_rand": op_eq_rank2_rand,
62+
"eq_rank3_randn": op_eq_rank3_randn,
63+
"eq_rank4_randn": op_eq_rank4_randn,
64+
}
65+
66+
67+
@common.parametrize("test_module", test_data_common)
68+
def test_eq_tosa_MI(test_module):
69+
pipeline = TosaPipelineMI[input_t](
70+
test_module, test_module.get_inputs(), aten_op, exir_op
71+
)
72+
pipeline.run()
73+
74+
75+
@common.parametrize("test_module", test_data_common)
76+
def test_eq_tosa_BI(test_module):
77+
pipeline = TosaPipelineBI[input_t](
78+
test_module, test_module.get_inputs(), aten_op, exir_op
79+
)
80+
pipeline.run()
81+
82+
83+
@common.parametrize("test_module", test_data_common)
84+
def test_eq_u55_BI(test_module):
85+
# EQUAL is not supported on U55.
86+
pipeline = OpNotSupportedPipeline[input_t](
87+
test_module,
88+
test_module.get_inputs(),
89+
"TOSA-0.80+BI+u55",
90+
{exir_op: 1},
91+
)
92+
pipeline.run()
93+
94+
95+
@common.parametrize("test_module", test_data_common)
96+
def test_eq_u85_BI(test_module):
97+
pipeline = EthosU85PipelineBI[input_t](
98+
test_module,
99+
test_module.get_inputs(),
100+
aten_op,
101+
exir_op,
102+
run_on_fvp=False,
103+
use_to_edge_transform_and_lower=True,
104+
)
105+
pipeline.run()
106+
107+
108+
@common.parametrize("test_module", test_data_common)
109+
@pytest.mark.skip(reason="The same as test_eq_u55_BI")
110+
def test_eq_u55_BI_on_fvp(test_module):
111+
# EQUAL is not supported on U55.
112+
pipeline = OpNotSupportedPipeline[input_t](
113+
test_module,
114+
test_module.get_inputs(),
115+
"TOSA-0.80+BI+u55",
116+
{exir_op: 1},
117+
)
118+
pipeline.run()
119+
120+
121+
@common.parametrize(
122+
"test_module",
123+
test_data_common,
124+
xfails={"eq_rank4_randn": "4D fails because boolean Tensors can't be subtracted"},
125+
)
126+
@common.SkipIfNoCorstone320
127+
def test_eq_u85_BI_on_fvp(test_module):
128+
pipeline = EthosU85PipelineBI[input_t](
129+
test_module,
130+
test_module.get_inputs(),
131+
aten_op,
132+
exir_op,
133+
run_on_fvp=True,
134+
use_to_edge_transform_and_lower=True,
135+
)
136+
pipeline.run()

0 commit comments

Comments
 (0)