Skip to content

Commit 6e0cb06

Browse files
Arm unittest refactor of Add and Conv2D test cases (#7541)
* Arm unittest refactor of Add and Conv test cases This is a showcase of multiple improvements which can be done across all tests: 1. Move pipeline definition from tests to a general (flexible) pipeline class - Define a default step of stages using e.g. TosaPipelineBI() - Add custom config or debug stages using helper functions s.a. .change_args(), .add_stage(), .dump() etc. - Run the full pipeline using .run() 2. Move towards a pure pytest approach to remove dependencies on unittest and parametrize 3. Separate tests running on FVP from tests not running on FVP rather than configuring this from the command line - FVP tests are skipped if not installed - To filter out tests one may instead use pytest markers/name filtering - This should give a clearer picture of what has been tested 4. Introduces one favored way of marking tests as xfails, in the parameterize decorator Co-authored-by: Erik Lundell <[email protected]>
1 parent 6115ce4 commit 6e0cb06

File tree

6 files changed

+685
-295
lines changed

6 files changed

+685
-295
lines changed

backends/arm/test/common.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,19 @@
99

1010
import tempfile
1111
from datetime import datetime
12+
1213
from pathlib import Path
14+
from typing import Any
1315

16+
import pytest
1417
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
1518
from executorch.backends.arm.tosa_specification import TosaSpecification
1619
from executorch.exir.backend.compile_spec_schema import CompileSpec
20+
from runner_utils import (
21+
arm_executor_runner_exists,
22+
corstone300_installed,
23+
corstone320_installed,
24+
)
1725

1826

1927
def get_time_formatted_path(path: str, log_prefix: str) -> str:
@@ -145,3 +153,44 @@ def get_u85_compile_spec_unbuilt(
145153
.dump_intermediate_artifacts_to(artifact_path)
146154
)
147155
return compile_spec # type: ignore[return-value]
156+
157+
158+
SkipIfNoCorstone300 = pytest.mark.skipif(
159+
not corstone300_installed() or not arm_executor_runner_exists("corstone-300"),
160+
reason="Did not find Corstone-300 FVP or executor_runner on path",
161+
)
162+
"""Skips a test if Corsone300 FVP is not installed, or if the executor runner is not built"""
163+
164+
SkipIfNoCorstone320 = pytest.mark.skipif(
165+
not corstone320_installed() or not arm_executor_runner_exists("corstone-320"),
166+
reason="Did not find Corstone-320 FVP or executor_runner on path",
167+
)
168+
"""Skips a test if Corsone320 FVP is not installed, or if the executor runner is not built."""
169+
170+
171+
def parametrize(
172+
arg_name: str, test_data: dict[str, Any], xfails: dict[str, str] = None
173+
):
174+
"""
175+
Custom version of pytest.mark.parametrize with some syntatic sugar and added xfail functionality
176+
- test_data is expected as a dict of (id, test_data) pairs
177+
- alllows to specifiy a dict of (id, failure_reason) pairs to mark specific tests as xfail
178+
"""
179+
if xfails is None:
180+
xfails = {}
181+
182+
def decorator_func(func):
183+
"""Test data is transformed from a dict of (id, data) pairs to a list of pytest params to work with the native pytests parametrize function"""
184+
pytest_testsuite = []
185+
for id, test_parameters in test_data.items():
186+
if id in xfails:
187+
pytest_param = pytest.param(
188+
test_parameters, id=id, marks=pytest.mark.xfail(reason=xfails[id])
189+
)
190+
else:
191+
pytest_param = pytest.param(test_parameters, id=id)
192+
pytest_testsuite.append(pytest_param)
193+
194+
return pytest.mark.parametrize(arg_name, pytest_testsuite)(func)
195+
196+
return decorator_func

backends/arm/test/ops/test_add.py

Lines changed: 136 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -5,169 +5,143 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
import unittest
98

109
from typing import Tuple
1110

12-
import pytest
1311
import torch
14-
from executorch.backends.arm.test import common, conftest
15-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16-
from executorch.exir.backend.compile_spec_schema import CompileSpec
17-
from parameterized import parameterized # type: ignore[import-untyped]
18-
19-
20-
class TestSimpleAdd(unittest.TestCase):
21-
"""Tests a single add op, x+x and x+y."""
22-
23-
class Add(torch.nn.Module):
24-
test_parameters = [
25-
(torch.FloatTensor([1, 2, 3, 5, 7]),),
26-
(3 * torch.ones(8),),
27-
(10 * torch.randn(8),),
28-
(torch.ones(1, 1, 4, 4),),
29-
(torch.ones(1, 3, 4, 2),),
30-
]
31-
32-
def forward(self, x):
33-
return x + x
34-
35-
class Add2(torch.nn.Module):
36-
test_parameters = [
37-
(
38-
torch.FloatTensor([1, 2, 3, 5, 7]),
39-
(torch.FloatTensor([2, 1, 2, 1, 10])),
40-
),
41-
(torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)),
42-
(torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
43-
(torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
44-
(10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
45-
]
46-
47-
def __init__(self):
48-
super().__init__()
49-
50-
def forward(self, x, y):
51-
return x + y
52-
53-
def _test_add_tosa_MI_pipeline(
54-
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
55-
):
56-
(
57-
ArmTester(
58-
module,
59-
example_inputs=test_data,
60-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
61-
)
62-
.export()
63-
.check_count({"torch.ops.aten.add.Tensor": 1})
64-
.check_not(["torch.ops.quantized_decomposed"])
65-
.to_edge()
66-
.partition()
67-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
68-
.to_executorch()
69-
.run_method_and_compare_outputs(inputs=test_data)
70-
)
71-
72-
def _test_add_tosa_BI_pipeline(
73-
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
74-
):
75-
(
76-
ArmTester(
77-
module,
78-
example_inputs=test_data,
79-
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
80-
)
81-
.quantize()
82-
.export()
83-
.check_count({"torch.ops.aten.add.Tensor": 1})
84-
.check(["torch.ops.quantized_decomposed"])
85-
.to_edge()
86-
.partition()
87-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
88-
.to_executorch()
89-
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
90-
)
91-
92-
def _test_add_ethos_BI_pipeline(
93-
self,
94-
module: torch.nn.Module,
95-
compile_spec: CompileSpec,
96-
test_data: Tuple[torch.Tensor],
97-
):
98-
tester = (
99-
ArmTester(
100-
module,
101-
example_inputs=test_data,
102-
compile_spec=compile_spec,
103-
)
104-
.quantize()
105-
.export()
106-
.check_count({"torch.ops.aten.add.Tensor": 1})
107-
.check(["torch.ops.quantized_decomposed"])
108-
.to_edge()
109-
.partition()
110-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
111-
.to_executorch()
112-
.serialize()
113-
)
114-
if conftest.is_option_enabled("corstone_fvp"):
115-
tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
116-
117-
return tester
118-
119-
@parameterized.expand(Add.test_parameters)
120-
def test_add_tosa_MI(self, test_data: torch.Tensor):
121-
test_data = (test_data,)
122-
self._test_add_tosa_MI_pipeline(self.Add(), test_data)
123-
124-
@parameterized.expand(Add.test_parameters)
125-
def test_add_tosa_BI(self, test_data: torch.Tensor):
126-
test_data = (test_data,)
127-
self._test_add_tosa_BI_pipeline(self.Add(), test_data)
128-
129-
@parameterized.expand(Add.test_parameters)
130-
@pytest.mark.corstone_fvp
131-
def test_add_u55_BI(self, test_data: torch.Tensor):
132-
test_data = (test_data,)
133-
self._test_add_ethos_BI_pipeline(
134-
self.Add(),
135-
common.get_u55_compile_spec(),
136-
test_data,
137-
)
138-
139-
@parameterized.expand(Add.test_parameters)
140-
@pytest.mark.corstone_fvp
141-
def test_add_u85_BI(self, test_data: torch.Tensor):
142-
test_data = (test_data,)
143-
self._test_add_ethos_BI_pipeline(
144-
self.Add(),
145-
common.get_u85_compile_spec(),
146-
test_data,
147-
)
148-
149-
@parameterized.expand(Add2.test_parameters)
150-
def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
151-
test_data = (operand1, operand2)
152-
self._test_add_tosa_MI_pipeline(self.Add2(), test_data)
153-
154-
@parameterized.expand(Add2.test_parameters)
155-
def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
156-
test_data = (operand1, operand2)
157-
self._test_add_tosa_BI_pipeline(self.Add2(), test_data)
158-
159-
@parameterized.expand(Add2.test_parameters)
160-
@pytest.mark.corstone_fvp
161-
def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
162-
test_data = (operand1, operand2)
163-
self._test_add_ethos_BI_pipeline(
164-
self.Add2(), common.get_u55_compile_spec(), test_data
165-
)
166-
167-
@parameterized.expand(Add2.test_parameters)
168-
@pytest.mark.corstone_fvp
169-
def test_add2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
170-
test_data = (operand1, operand2)
171-
self._test_add_ethos_BI_pipeline(
172-
self.Add2(), common.get_u85_compile_spec(), test_data
173-
)
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.test_pipeline import (
14+
EthosU55PipelineBI,
15+
EthosU85PipelineBI,
16+
TosaPipelineBI,
17+
TosaPipelineMI,
18+
)
19+
20+
aten_op = "torch.ops.aten.add.Tensor"
21+
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
22+
23+
input_t1 = Tuple[torch.Tensor] # Input x
24+
25+
26+
class Add(torch.nn.Module):
27+
def forward(self, x: torch.Tensor):
28+
return x + x
29+
30+
test_data: list[input_t1] = {
31+
"5d_float": (torch.FloatTensor([1, 2, 3, 5, 7]),),
32+
"1d_ones": ((3 * torch.ones(8),)),
33+
"1d_randn": (10 * torch.randn(8),),
34+
"4d_ones_1": (torch.ones(1, 1, 4, 4),),
35+
"4d_ones_2": (torch.ones(1, 3, 4, 2),),
36+
}
37+
38+
39+
input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y
40+
41+
42+
class Add2(torch.nn.Module):
43+
def forward(self, x: torch.Tensor, y: torch.Tensor):
44+
return x + y
45+
46+
test_data: list[input_t2] = {
47+
"5d_float": (
48+
torch.FloatTensor([1, 2, 3, 5, 7]),
49+
(torch.FloatTensor([2, 1, 2, 1, 10])),
50+
),
51+
"4d_ones": (torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)),
52+
"4d_randn_1": (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
53+
"4d_randn_2": (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
54+
"4d_randn_big": (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
55+
}
56+
57+
58+
@common.parametrize("test_data", Add.test_data)
59+
def test_add_tosa_MI(test_data: input_t1):
60+
pipeline = TosaPipelineMI[input_t1](Add(), test_data, aten_op, exir_op)
61+
pipeline.run()
62+
63+
64+
@common.parametrize("test_data", Add.test_data)
65+
def test_add_tosa_BI(test_data: input_t1):
66+
pipeline = TosaPipelineBI[input_t1](Add(), test_data, aten_op, exir_op)
67+
pipeline.run()
68+
69+
70+
@common.parametrize("test_data", Add.test_data)
71+
def test_add_u55_BI(test_data: input_t1):
72+
pipeline = EthosU55PipelineBI[input_t1](
73+
Add(), test_data, aten_op, exir_op, run_on_fvp=False
74+
)
75+
pipeline.run()
76+
77+
78+
@common.parametrize("test_data", Add.test_data)
79+
def test_add_u85_BI(test_data: input_t1):
80+
pipeline = EthosU85PipelineBI[input_t1](
81+
Add(), test_data, aten_op, exir_op, run_on_fvp=False
82+
)
83+
pipeline.run()
84+
85+
86+
@common.parametrize("test_data", Add.test_data)
87+
@common.SkipIfNoCorstone300
88+
def test_add_u55_BI_on_fvp(test_data: input_t1):
89+
pipeline = EthosU55PipelineBI[input_t1](
90+
Add(), test_data, aten_op, exir_op, run_on_fvp=True
91+
)
92+
pipeline.run()
93+
94+
95+
@common.parametrize("test_data", Add.test_data)
96+
@common.SkipIfNoCorstone320
97+
def test_add_u85_BI_on_fvp(test_data: input_t1):
98+
pipeline = EthosU85PipelineBI[input_t1](
99+
Add(), test_data, aten_op, exir_op, run_on_fvp=True
100+
)
101+
pipeline.run()
102+
103+
104+
@common.parametrize("test_data", Add2.test_data)
105+
def test_add2_tosa_MI(test_data: input_t2):
106+
pipeline = TosaPipelineMI[input_t2](Add2(), test_data, aten_op, exir_op)
107+
pipeline.run()
108+
109+
110+
@common.parametrize("test_data", Add2.test_data)
111+
def test_add2_tosa_BI(test_data: input_t2):
112+
pipeline = TosaPipelineBI[input_t2](Add2(), test_data, aten_op, exir_op)
113+
pipeline.run()
114+
115+
116+
@common.parametrize("test_data", Add2.test_data)
117+
def test_add2_u55_BI(test_data: input_t2):
118+
pipeline = EthosU55PipelineBI[input_t2](
119+
Add2(), test_data, aten_op, exir_op, run_on_fvp=False
120+
)
121+
pipeline.run()
122+
123+
124+
@common.parametrize("test_data", Add2.test_data)
125+
@common.SkipIfNoCorstone300
126+
def test_add2_u55_BI_on_fvp(test_data: input_t2):
127+
pipeline = EthosU55PipelineBI[input_t2](
128+
Add2(), test_data, aten_op, exir_op, run_on_fvp=True
129+
)
130+
pipeline.run()
131+
132+
133+
@common.parametrize("test_data", Add2.test_data)
134+
def test_add2_u85_BI(test_data: input_t2):
135+
pipeline = EthosU85PipelineBI[input_t2](
136+
Add2(), test_data, aten_op, exir_op, run_on_fvp=False
137+
)
138+
pipeline.run()
139+
140+
141+
@common.parametrize("test_data", Add2.test_data)
142+
@common.SkipIfNoCorstone320
143+
def test_add2_u85_BI_on_fvp(test_data: input_t2):
144+
pipeline = EthosU85PipelineBI[input_t2](
145+
Add2(), test_data, aten_op, exir_op, run_on_fvp=True
146+
)
147+
pipeline.run()

0 commit comments

Comments
 (0)