Skip to content

Commit 3fd1649

Browse files
Arm backend: Refactor misc tests for TOSA V1.0
Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I9be385f8d521479083019948f29b46390b0193d3
1 parent f8e7264 commit 3fd1649

10 files changed

+519
-588
lines changed

backends/arm/test/common.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,15 @@ def maybe_get_tosa_collate_path() -> str | None:
4747
tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH")
4848
if tosa_test_base:
4949
current_test = os.environ.get("PYTEST_CURRENT_TEST")
50-
#'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)'
51-
test_class = current_test.split("::")[1] # type: ignore[union-attr]
52-
test_name = current_test.split("::")[-1].split(" ")[0] # type: ignore[union-attr]
50+
# '::test_collate_tosa_BI_tests[randn] (call)'
51+
test_name = current_test.split("::")[1].split(" ")[0] # type: ignore[union-attr]
5352
if "BI" in test_name:
5453
tosa_test_base = os.path.join(tosa_test_base, "tosa-bi")
5554
elif "MI" in test_name:
5655
tosa_test_base = os.path.join(tosa_test_base, "tosa-mi")
5756
else:
5857
tosa_test_base = os.path.join(tosa_test_base, "other")
59-
return os.path.join(tosa_test_base, test_class, test_name)
58+
return os.path.join(tosa_test_base, test_name)
6059

6160
return None
6261

backends/arm/test/misc/test_custom_partition.py

Lines changed: 86 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,25 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7+
from typing import Tuple
78

89
import torch
910
from executorch.backends.arm.test import common
10-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
11-
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
11+
from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineMI
1212
from executorch.exir.backend.operator_support import (
1313
DontPartition,
1414
DontPartitionModule,
1515
DontPartitionName,
1616
)
1717
from executorch.exir.dialects._ops import ops as exir_ops
1818

19+
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x, y
20+
1921

2022
class CustomPartitioning(torch.nn.Module):
21-
inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
23+
inputs = {
24+
"randn": (torch.randn(10, 4, 5), torch.randn(10, 4, 5)),
25+
}
2226

2327
def forward(self, x: torch.Tensor, y: torch.Tensor):
2428
z = x + y
@@ -27,7 +31,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
2731

2832

2933
class NestedModule(torch.nn.Module):
30-
inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
34+
inputs = {
35+
"randn": (torch.randn(10, 4, 5), torch.randn(10, 4, 5)),
36+
}
3137

3238
def __init__(self):
3339
super().__init__()
@@ -39,192 +45,139 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
3945
return self.nested(a, b)
4046

4147

42-
def test_single_reject(caplog):
48+
@common.parametrize("test_data", CustomPartitioning.inputs)
49+
def test_single_reject(caplog, test_data: input_t1):
4350
caplog.set_level(logging.INFO)
4451

4552
module = CustomPartitioning()
46-
inputs = module.inputs
47-
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
53+
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
4854
check = DontPartition(exir_ops.edge.aten.sigmoid.default)
49-
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
50-
(
51-
ArmTester(
52-
module,
53-
example_inputs=inputs,
54-
compile_spec=compile_spec,
55-
)
56-
.export()
57-
.to_edge_transform_and_lower(partitioners=[partitioner])
58-
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
59-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
60-
.to_executorch()
61-
.run_method_and_compare_outputs(inputs=inputs)
55+
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
56+
pipeline.change_args(
57+
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
58+
)
59+
pipeline.change_args(
60+
"check_count.exir",
61+
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
6262
)
63+
pipeline.run()
6364
assert check.has_rejected_node()
6465
assert "Rejected by DontPartition" in caplog.text
6566

6667

67-
def test_multiple_reject():
68+
@common.parametrize("test_data", CustomPartitioning.inputs)
69+
def test_multiple_reject(test_data: input_t1):
6870
module = CustomPartitioning()
69-
inputs = module.inputs
70-
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
71+
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
7172
check = DontPartition(
7273
exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mul.Tensor
7374
)
74-
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
75-
(
76-
ArmTester(
77-
module,
78-
example_inputs=inputs,
79-
compile_spec=compile_spec,
80-
)
81-
.export()
82-
.to_edge_transform_and_lower(partitioners=[partitioner])
83-
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
84-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
85-
.to_executorch()
86-
.run_method_and_compare_outputs(inputs=inputs)
75+
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
76+
pipeline.change_args(
77+
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
78+
)
79+
pipeline.change_args(
80+
"check_count.exir",
81+
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
8782
)
83+
pipeline.run()
8884
assert check.has_rejected_node()
8985

9086

91-
def test_torch_op_reject(caplog):
87+
@common.parametrize("test_data", CustomPartitioning.inputs)
88+
def test_torch_op_reject(caplog, test_data: input_t1):
9289
caplog.set_level(logging.INFO)
9390

9491
module = CustomPartitioning()
95-
inputs = module.inputs
96-
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
9792
check = DontPartition(torch.ops.aten.sigmoid.default)
98-
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
99-
(
100-
ArmTester(
101-
module,
102-
example_inputs=inputs,
103-
compile_spec=compile_spec,
104-
)
105-
.export()
106-
.to_edge_transform_and_lower(partitioners=[partitioner])
107-
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
108-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
109-
.to_executorch()
110-
.run_method_and_compare_outputs(inputs=inputs)
93+
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
94+
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
95+
pipeline.change_args(
96+
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
11197
)
98+
pipeline.change_args(
99+
"check_count.exir",
100+
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
101+
)
102+
pipeline.run()
112103
assert check.has_rejected_node()
113104
assert "Rejected by DontPartition" in caplog.text
114105

115106

116-
def test_string_op_reject():
107+
@common.parametrize("test_data", CustomPartitioning.inputs)
108+
def test_string_op_reject(test_data: input_t1):
117109
module = CustomPartitioning()
118-
inputs = module.inputs
119-
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
120110
check = DontPartition("aten.sigmoid.default")
121-
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
122-
(
123-
ArmTester(
124-
module,
125-
example_inputs=inputs,
126-
compile_spec=compile_spec,
127-
)
128-
.export()
129-
.to_edge_transform_and_lower(partitioners=[partitioner])
130-
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
131-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
132-
.to_executorch()
133-
.run_method_and_compare_outputs(inputs=inputs)
111+
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
112+
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
113+
pipeline.change_args(
114+
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
134115
)
135-
116+
pipeline.change_args(
117+
"check_count.exir",
118+
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
119+
)
120+
pipeline.run()
136121
assert check.has_rejected_node()
137122

138123

139-
def test_name_reject(caplog):
124+
@common.parametrize("test_data", CustomPartitioning.inputs)
125+
def test_name_reject(caplog, test_data: input_t1):
140126
caplog.set_level(logging.INFO)
141127

142128
module = CustomPartitioning()
143-
inputs = module.inputs
144-
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
145129
check = DontPartitionName("mul", "sigmoid", exact=False)
146-
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
147-
(
148-
ArmTester(
149-
module,
150-
example_inputs=inputs,
151-
compile_spec=compile_spec,
152-
)
153-
.export()
154-
.to_edge_transform_and_lower(partitioners=[partitioner])
155-
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
156-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
157-
.to_executorch()
158-
.run_method_and_compare_outputs(inputs=inputs)
130+
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
131+
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
132+
pipeline.change_args(
133+
"check_count.exir",
134+
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
159135
)
136+
pipeline.run()
160137
assert check.has_rejected_node()
161138
assert "Rejected by DontPartitionName" in caplog.text
162139

163140

164-
def test_module_reject():
141+
@common.parametrize("test_data", CustomPartitioning.inputs)
142+
def test_module_reject(test_data: input_t1):
165143
module = NestedModule()
166-
inputs = module.inputs
167-
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
168144
check = DontPartitionModule(module_name="CustomPartitioning")
169-
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
170-
(
171-
ArmTester(
172-
module,
173-
example_inputs=inputs,
174-
compile_spec=compile_spec,
175-
)
176-
.export()
177-
.to_edge_transform_and_lower(partitioners=[partitioner])
178-
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
179-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
180-
.to_executorch()
181-
.run_method_and_compare_outputs(inputs=inputs)
145+
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
146+
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
147+
pipeline.change_args(
148+
"check_count.exir",
149+
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
182150
)
151+
pipeline.run()
183152
assert check.has_rejected_node()
184153

185154

186-
def test_inexact_module_reject(caplog):
155+
@common.parametrize("test_data", CustomPartitioning.inputs)
156+
def test_inexact_module_reject(caplog, test_data: input_t1):
187157
caplog.set_level(logging.INFO)
188158

189159
module = NestedModule()
190-
inputs = module.inputs
191-
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
192160
check = DontPartitionModule(module_name="Custom", exact=False)
193-
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
194-
(
195-
ArmTester(
196-
module,
197-
example_inputs=inputs,
198-
compile_spec=compile_spec,
199-
)
200-
.export()
201-
.to_edge_transform_and_lower(partitioners=[partitioner])
202-
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
203-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
204-
.to_executorch()
205-
.run_method_and_compare_outputs(inputs=inputs)
161+
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
162+
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
163+
pipeline.change_args(
164+
"check_count.exir",
165+
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
206166
)
167+
pipeline.run()
207168
assert check.has_rejected_node()
208169
assert "Rejected by DontPartitionModule" in caplog.text
209170

210171

211-
def test_module_instance_reject():
172+
@common.parametrize("test_data", CustomPartitioning.inputs)
173+
def test_module_instance_reject(test_data: input_t1):
212174
module = NestedModule()
213-
inputs = module.inputs
214-
compile_spec = common.get_tosa_compile_spec("TOSA-0.80+MI")
215175
check = DontPartitionModule(instance_name="nested")
216-
partitioner = TOSAPartitioner(compile_spec, additional_checks=[check])
217-
(
218-
ArmTester(
219-
module,
220-
example_inputs=inputs,
221-
compile_spec=compile_spec,
222-
)
223-
.export()
224-
.to_edge_transform_and_lower(partitioners=[partitioner])
225-
.check(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
226-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
227-
.to_executorch()
228-
.run_method_and_compare_outputs(inputs=inputs)
176+
pipeline = TosaPipelineMI[input_t1](module, test_data, [], exir_op=[])
177+
pipeline.change_args("to_edge_transform_and_lower", additional_checks=[check])
178+
pipeline.change_args(
179+
"check_count.exir",
180+
{"executorch_exir_dialects_edge__ops_aten_sigmoid_default": 1},
229181
)
182+
pipeline.run()
230183
assert check.has_rejected_node()

0 commit comments

Comments
 (0)