Skip to content

Commit c35df8b

Browse files
Arm backend: Add additional tosa_supported_op checks for BI (#8593)
Add additional tosa_supported_op checks for BI If a TosaSpecification without floating point support is used, additional checks will be made during partitioning to make sure that we don't partition operators that: - are not quantized properly, i.e. do not have a dq-q pair surrounding them. - should have been decomposed prior to quantization, e.g. div should be decomposed to a mul and recip before quantization. Signed-off-by: Oscar Andersson <[email protected]> Co-authored-by: Erik Lundell <[email protected]>
1 parent 68eb62f commit c35df8b

File tree

3 files changed

+242
-6
lines changed

3 files changed

+242
-6
lines changed

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414

1515
class FuseQuantizedActivationPass(ExportPass):
16-
def _is_fuseable_quantized_activation(self, node: Node):
16+
@staticmethod
17+
def _is_fuseable_quantized_activation(node: Node):
1718
"""Fuse activations that have a 0 lower bound and quantized with a qmin zero-point"""
1819
is_fuseable = node.target == exir_ops.edge.aten.relu.default
1920
if node.target == exir_ops.edge.aten.hardtanh.default:
@@ -29,7 +30,8 @@ def _is_fuseable_quantized_activation(self, node: Node):
2930
else:
3031
return False
3132

32-
def _is_fuseable_input(self, node: Node):
33+
@staticmethod
34+
def _is_fuseable_input(node: Node):
3335
return (
3436
node.target
3537
in (
@@ -45,11 +47,11 @@ def call(self, graph_module: torch.fx.GraphModule):
4547
if node.op != "call_function":
4648
continue
4749

48-
if not self._is_fuseable_quantized_activation(node):
50+
if not FuseQuantizedActivationPass._is_fuseable_quantized_activation(node):
4951
continue
5052

5153
input_node = node.args[0]
52-
if not self._is_fuseable_input(input_node):
54+
if not FuseQuantizedActivationPass._is_fuseable_input(input_node):
5355
continue
5456

5557
node.replace_all_uses_with(input_node)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 171 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,22 @@
55

66
# pyre-unsafe
77

8+
import itertools
89
import operator
10+
import typing
911
from typing import final, Optional, Sequence, Type
1012

13+
import torch
14+
1115
import torch.fx as fx
16+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
17+
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
18+
FuseQuantizedActivationPass,
19+
)
1220
from executorch.backends.arm.tosa_specification import TosaSpecification
1321
from executorch.exir.dialects._ops import ops as exir_ops
1422
from torch.fx.passes.operator_support import any_chain, chain, OperatorSupportBase
23+
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1524

1625

1726
class SupportedTOSAOperatorCheck(OperatorSupportBase):
@@ -27,7 +36,9 @@ def __init__(self, tosa_spec: TosaSpecification):
2736
targets: list[str] = []
2837

2938
@final
30-
def is_node_supported(self, submodules, node: fx.Node) -> bool:
39+
def is_node_supported(
40+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
41+
) -> bool:
3142
if node.target not in self.targets:
3243
return False
3344
return self.is_node_tosa_supported(node, self.tosa_spec)
@@ -75,6 +86,10 @@ def tosa_support_factory(
7586
tosa_spec: TosaSpecification,
7687
additional_checks: Optional[Sequence[OperatorSupportBase]] = None,
7788
) -> OperatorSupportBase:
89+
negative_checks: list[OperatorSupportBase] = []
90+
if not tosa_spec.support_float():
91+
negative_checks.append(NeedsDecompositionCheck())
92+
negative_checks.append(CheckProperQuantization())
7893
return chain(
7994
any_chain(
8095
BaseTOSASupportList(),
@@ -83,13 +98,16 @@ def tosa_support_factory(
8398
for check in get_registered_tosa_support_checks(tosa_spec)
8499
),
85100
),
101+
*negative_checks,
86102
*additional_checks if additional_checks else [],
87103
)
88104

89105

90106
class BaseTOSASupportList(OperatorSupportBase):
91107

92-
def is_node_supported(self, submodules, node: fx.Node) -> bool:
108+
def is_node_supported(
109+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
110+
) -> bool:
93111
supported = node.op == "call_function" and node.target in [
94112
exir_ops.edge.aten.abs.default,
95113
exir_ops.edge.aten.add.Tensor,
@@ -150,3 +168,154 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
150168
]
151169

152170
return supported
171+
172+
173+
class NeedsDecompositionCheck(OperatorSupportBase):
    """
    Reject operators that should have been decomposed before quantization.

    Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding
    the operator, and to get optimal quantization parameters for each operator. This check will reject operators
    that need to be decomposed.
    """

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """
        Return False for nodes that need decomposition and True for all other nodes.

        Args:
            submodules: Not used by this check; present to match the OperatorSupportBase interface.
            node: The node to inspect.
        """
        # Only operator calls can require decomposition.
        if node.op != "call_function":
            return True

        # mean.dim is decided entirely here: it is accepted only when its dim argument
        # is exactly [-1, -2] and rejected for any other dim. It is therefore not listed
        # among the targets below.
        if node.target == exir_ops.edge.aten.mean.dim:
            dim = node.args[1]
            return dim == [-1, -2]

        needs_decomp = node.target in (
            exir_ops.edge.aten.div.Tensor,
            exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
            exir_ops.edge.aten.native_layer_norm.default,
            exir_ops.edge.aten._softmax.default,
            exir_ops.edge.aten._log_softmax.default,
            exir_ops.edge.aten.var.correction,
            exir_ops.edge.aten.var.dim,
        )
        return not needs_decomp
200+
201+
202+
class CheckProperQuantization(OperatorSupportBase):
    """
    For targeted nodes, check that it has been quantized as expected. In most cases this means that a pair of quantize
    and dequantize nodes surrounds the node. This is necessary for table operators and operators that need to rescale
    activations.
    """

    dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
    q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default

    def _is_matmul_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ):
        """
        Find the matmul source partition containing this node and check that all its inputs and outputs are quantized.

        Every graph module in `submodules` is searched. The verdict is taken from the first matmul source
        partition that contains `node`. If no such partition exists in any graph module, the node is rejected.

        Args:
            submodules: Mapping of module names to modules; only fx.GraphModule values are searched.
            node: The bmm/mm node to look up.

        Returns:
            True if the partition's input nodes are all dequantize nodes and all users of its first output
            node are quantize nodes, otherwise False.
        """
        for graph_module in submodules.values():
            # Modules without an fx graph cannot contain the node.
            if not isinstance(graph_module, fx.GraphModule):
                continue

            matmul_partitions = get_source_partitions(
                graph_module.graph,
                [
                    torch.matmul,
                ],
                None,
            )
            matched_partition = next(
                (
                    partition
                    for partition in itertools.chain.from_iterable(
                        matmul_partitions.values()
                    )
                    if node in partition.nodes
                ),
                None,
            )
            if matched_partition is None:
                # The node may belong to another graph module; keep searching
                # instead of rejecting it right away.
                continue

            input_quantized = all(
                input_node.target == self.dq_op
                for input_node in matched_partition.input_nodes
            )
            if not input_quantized:
                return False

            # The partition containing the node has been found, so its verdict is final.
            return all(
                output_node_user.target == self.q_op
                for output_node_user in matched_partition.output_nodes[0].users
            )

        # No matmul source partition contains the node.
        return False

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """
        Return True if the node is not targeted by this check or is quantized as expected.

        Args:
            submodules: Mapping of module names to modules, forwarded to the matmul partition lookup.
            node: The node to inspect.
        """
        output_quantized = False
        input_quantized = False
        if node.target not in (
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.avg_pool2d.default,
            exir_ops.edge.aten.bmm.default,
            exir_ops.edge.aten.convolution.default,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.hardtanh.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.log.default,
            exir_ops.edge.aten.max_pool2d_with_indices.default,
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.rsqrt.default,
            exir_ops.edge.aten.sigmoid.default,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.tanh.default,
            exir_ops.edge.aten.upsample_nearest2d.vec,
        ):
            # Nodes outside the targeted set are not subject to this check.
            return True
        elif node.target in (
            exir_ops.edge.aten.bmm.default,
            exir_ops.edge.aten.mm.default,
        ):
            # bmm/mm nodes that originate from torch.matmul are validated on the whole
            # matmul source partition. Other bmm/mm nodes fall through to the generic
            # input/output checks below.
            source_fn_stack: list[typing.Any] = node.meta.get("source_fn_stack", [])
            if len(source_fn_stack) > 0:
                if source_fn_stack[-1][1] in (torch.matmul,):
                    return self._is_matmul_node_supported(submodules, node)

        elif node.target in (exir_ops.edge.aten.max_pool2d_with_indices.default,):
            # Every user must be a getitem whose own users are all quantize nodes.
            users = node.users
            output_quantized = all(
                user.target == operator.getitem
                and all(user_user.target == self.q_op for user_user in user.users)
                for user in users
            )
        elif FuseQuantizedActivationPass._is_fuseable_input(node):
            # The output counts as quantized when every user is a fuseable quantized activation.
            users = node.users
            output_quantized = all(
                FuseQuantizedActivationPass._is_fuseable_quantized_activation(user)
                for user in users
            )
        elif FuseQuantizedActivationPass._is_fuseable_quantized_activation(node):
            # The input counts as quantized when the first input node is a fuseable input.
            input_node = node.all_input_nodes[0]
            input_quantized = FuseQuantizedActivationPass._is_fuseable_input(input_node)

        # Generic input check: each input is a dequantize node or has a non-floating-point dtype.
        input_quantized = input_quantized or all(
            (input_node.target == self.dq_op)
            or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
            for input_node in node.all_input_nodes
        )

        if not input_quantized:
            return False

        # Generic output check: each user is a quantize node or has a non-floating-point dtype.
        output_quantized = output_quantized or all(
            (output_node.target == self.q_op)
            or (not get_first_fake_tensor(output_node).dtype.is_floating_point)
            for output_node in node.users
        )

        if not output_quantized:
            return False
        return True
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# Test that tosa_supported_operators reject operators that are not
7+
# quantized properly. This is typically a consequence of a torch op
8+
# such as Softplus that is decomposed into many other ops without
9+
# surrounding q/dq nodes.
10+
11+
from typing import Tuple
12+
13+
import torch
14+
from executorch.backends.arm.test import common
15+
16+
from executorch.backends.arm.test.tester.test_pipeline import (
17+
TosaPipelineBI,
18+
TosaPipelineMI,
19+
)
20+
21+
# Type of the positional inputs passed to the module under test.
input_t1 = Tuple[torch.Tensor]
# Operator names handed to the test pipelines as `aten_op`.
aten_op: list[str] = ["torch.ops.aten.add.Tensor", "torch.ops.aten.softplus.default"]
# Operator names handed to the test pipelines as `exir_op`; the BI test checks
# for every entry except the first (add) after partitioning.
exir_op: list[str] = [
    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
    "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
    "executorch_exir_dialects_edge__ops_aten_exp_default",
    "executorch_exir_dialects_edge__ops_aten_div_Tensor",
]


# Named test inputs, keyed by test-case id.
test_data: dict[str, input_t1] = {
    "3d_rand": (torch.rand(1, 5, 5),),
}
34+
35+
36+
class Module(torch.nn.Module):
    """Module under test: adds the input to itself and applies Softplus."""

    def __init__(self):
        super().__init__()
        self.softplus = torch.nn.Softplus()

    def forward(self, x: torch.Tensor):
        """Return softplus(x + x)."""
        doubled = x + x
        return self.softplus(doubled)
43+
44+
45+
@common.parametrize("test_data", test_data)
def test_softplus_tosa_MI(test_data: input_t1):
    """Run the add + Softplus module through the TosaPipelineMI pipeline."""
    module = Module()
    pipeline = TosaPipelineMI[input_t1](
        module,
        test_data=test_data,
        aten_op=aten_op,
        exir_op=exir_op,
    )
    # More than one delegate is expected, so the check_count.exir stage is dropped.
    pipeline.pop_stage("check_count.exir")
    pipeline.run()
53+
54+
55+
@common.parametrize("test_data", test_data)
def test_softplus_tosa_BI(test_data: input_t1):
    """Run the add + Softplus module through the TosaPipelineBI pipeline."""
    module = Module()
    pipeline = TosaPipelineBI[input_t1](
        module,
        test_data=test_data,
        aten_op=aten_op,
        exir_op=exir_op,
    )
    pipeline.pop_stage("check_not.exir")

    # Every op in exir_op except add (the first entry) should be rejected,
    # so check for them after the partition stage.
    ops_expected_after_partition = exir_op[1:]
    pipeline.add_stage_after(
        "partition",
        pipeline.tester.check,
        ops_expected_after_partition,
        suffix="exir_post_partition",
    )
    pipeline.run()

0 commit comments

Comments
 (0)