Skip to content

Commit 781ad6d

Browse files
committed
[Arm tester] Run delegate nodes using tosa_reference_model
Instead of executing just one delegate, we execute the graph and dispatch the delegate nodes to the tosa_reference_model. This makes the tester more flexible and enables running tests with multiple delegates, multiple outputs etc. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I5fb0192da2d187f29c025e0c51238a9cb7d7ec90
1 parent 08770b7 commit 781ad6d

File tree

9 files changed

+286
-198
lines changed

9 files changed

+286
-198
lines changed

backends/arm/_passes/tag_io_quant_pass.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

backends/arm/arm_backend.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -202,6 +202,20 @@ def is_tosa(compile_spec: List[CompileSpec]) -> bool:
202202
return False
203203

204204

205+
def is_quantize_io(compile_specs: List[CompileSpec]) -> bool:
206+
for spec in compile_specs:
207+
if spec.key == "quantize_io" and spec.value.decode() == "True":
208+
return True
209+
return False
210+
211+
212+
def get_tosa_version(compile_spec: List[CompileSpec]) -> TosaSpecification:
213+
for spec in compile_spec:
214+
if spec.key == "tosa_version":
215+
return TosaSpecification.create_from_string(spec.value.decode())
216+
raise RuntimeError("Could not find TOSA version in CompileSpec")
217+
218+
205219
def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
206220
for spec in compile_spec:
207221
if spec.key == "debug_artifact_path":

backends/arm/arm_partitioner.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -10,8 +10,10 @@
1010
from typing import Callable, final, List, Optional, Tuple
1111

1212
import torch
13-
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
14-
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
13+
from executorch.backends.arm.arm_backend import (
14+
ArmBackend,
15+
is_quantize_io,
16+
) # usort: skip
1517
from executorch.backends.arm.operator_support.tosa_supported_operators import (
1618
TOSASupportedOperators,
1719
)
@@ -23,7 +25,7 @@
2325
PartitionResult,
2426
)
2527
from executorch.exir.backend.utils import tag_constant_data
26-
from executorch.exir.passes import PassManager
28+
from executorch.exir.dialects._ops import ops as exir_ops
2729
from torch.export.exported_program import ExportedProgram
2830
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
2931

@@ -35,6 +37,22 @@
3537
logger.setLevel(logging.INFO)
3638

3739

40+
def is_quant_node(node: torch.fx.node.Node) -> bool:
41+
return node.target in {
42+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
43+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
44+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
45+
}
46+
47+
48+
def is_dequant_node(node: torch.fx.node.Node) -> bool:
49+
return node.target in {
50+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
51+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
52+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
53+
}
54+
55+
3856
@final
3957
class ArmPartitioner(Partitioner):
4058
def __init__(self, compile_spec: List[CompileSpec]) -> None:
@@ -43,6 +61,7 @@ def __init__(self, compile_spec: List[CompileSpec]) -> None:
4361
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
4462
# Run the CapabilityBasedPartitioner to return the largest possible
4563
# subgraphs containing the nodes with the tags
64+
4665
logger.info("ArmPartitioner::partition")
4766
partition_tags = {}
4867

@@ -52,28 +71,42 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
5271

5372
logger.info(f"Partitioning for {tosa_spec}")
5473

55-
for spec in self.delegation_spec.compile_specs:
56-
if spec.key == "quantize_io" and spec.value.decode() == "True":
57-
# Exclude IO quantization from the partition
58-
passes = PassManager(
59-
passes=[
60-
TagIOQuantPass(),
61-
]
62-
)
63-
passes(exported_program.graph_module)
64-
6574
capability_partitioner = CapabilityBasedPartitioner(
6675
exported_program.graph_module,
6776
TOSASupportedOperators(tosa_spec),
6877
allows_single_node_partition=True,
6978
)
7079
partition_list = capability_partitioner.propose_partitions()
7180
for partition in partition_list:
81+
tag = f"tag{partition.id}"
82+
83+
def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
84+
return (
85+
"delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
86+
)
87+
7288
for node in partition.nodes:
73-
tag = f"tag{partition.id}"
7489
node.meta["delegation_tag"] = tag
7590
partition_tags[tag] = self.delegation_spec
7691

92+
if not is_quantize_io(self.delegation_spec.compile_specs):
93+
continue
94+
95+
# De-tag outmost q-nodes upwards and dq-nodes downwards.
96+
# De-tag if at least one input/ output is not part of partition.
97+
for node in partition.nodes:
98+
if is_quant_node(node):
99+
for input in node.all_input_nodes:
100+
if not is_partitioned(input):
101+
del node.meta["delegation_tag"]
102+
break
103+
104+
if is_dequant_node(node):
105+
for user in node.users:
106+
if not is_partitioned(user):
107+
del node.meta["delegation_tag"]
108+
break
109+
77110
tag_constant_data(exported_program)
78111

79112
return PartitionResult(

backends/arm/test/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -50,7 +50,6 @@ def maybe_get_tosa_collate_path() -> str | None:
5050
tosa_test_base = os.path.join(tosa_test_base, "tosa-mi")
5151
else:
5252
tosa_test_base = os.path.join(tosa_test_base, "other")
53-
5453
return os.path.join(tosa_test_base, test_class, test_name)
5554

5655
return None
@@ -83,6 +82,7 @@ def get_tosa_compile_spec_unbuilt(
8382
.tosa_compile_spec(tosa_version)
8483
.set_permute_memory_format(permute_memory_to_nhwc)
8584
.dump_intermediate_artifacts_to(custom_path)
85+
.set_quantize_io(True)
8686
)
8787

8888
return compile_spec_builder
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
12+
13+
14+
class TestMultipleDelegates(unittest.TestCase):
15+
class MultipleDelegatesModule(torch.nn.Module):
16+
inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
17+
18+
def get_inputs(self):
19+
return self.inputs
20+
21+
def forward(self, x: torch.Tensor, y: torch.Tensor):
22+
z = x + y
23+
s = torch.sin(z)
24+
return s * z
25+
26+
def test_tosa_MI(self):
27+
module = self.MultipleDelegatesModule()
28+
inputs = module.get_inputs()
29+
(
30+
ArmTester(
31+
module,
32+
example_inputs=inputs,
33+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
34+
)
35+
.export()
36+
.to_edge_transform_and_lower()
37+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
38+
.to_executorch()
39+
.run_method_and_compare_outputs(inputs=inputs)
40+
)
41+
42+
def test_tosa_BI(self):
43+
module = self.MultipleDelegatesModule()
44+
inputs = module.get_inputs()
45+
(
46+
ArmTester(
47+
module,
48+
example_inputs=inputs,
49+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
50+
)
51+
.quantize()
52+
.export()
53+
.to_edge_transform_and_lower()
54+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 2})
55+
.to_executorch()
56+
.run_method_and_compare_outputs(inputs=inputs, qtol=1.0)
57+
)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
12+
13+
14+
class TestMultipleOutputs(unittest.TestCase):
15+
class MultipleOutputsModule(torch.nn.Module):
16+
inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
17+
18+
def get_inputs(self):
19+
return self.inputs
20+
21+
def forward(self, x: torch.Tensor, y: torch.Tensor):
22+
return (x * y, x.sum(dim=-1, keepdim=True))
23+
24+
def test_tosa_MI_pipeline(self):
25+
module = self.MultipleOutputsModule()
26+
inputs = module.get_inputs()
27+
(
28+
ArmTester(
29+
module,
30+
example_inputs=inputs,
31+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
32+
)
33+
.export()
34+
.to_edge_transform_and_lower()
35+
.to_executorch()
36+
.run_method_and_compare_outputs(inputs=inputs)
37+
)
38+
39+
def test_tosa_BI_pipeline(self):
40+
module = self.MultipleOutputsModule()
41+
inputs = module.get_inputs()
42+
(
43+
ArmTester(
44+
module,
45+
example_inputs=inputs,
46+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
47+
)
48+
.quantize()
49+
.export()
50+
.to_edge_transform_and_lower()
51+
.to_executorch()
52+
.run_method_and_compare_outputs(inputs=inputs, qtol=1.0)
53+
)

backends/arm/test/passes/test_tag_io_quant_pass.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

0 commit comments

Comments
 (0)