Skip to content

Commit a059981

Browse files
authored
Quantize compatible node + activation patterns as one block (#7555)
Annotate conv1d/conv2d/linear followed by relu/relu6 patterns as one block and fuse the activation into its parent. The activation will then be implicitly done in the tosa.rescale node that will have a -128 zero-point. Change-Id: I5bf1e2c91be21ab842012fbc20d159af7fe2222d
1 parent e2afedf commit a059981

File tree

4 files changed

+133
-4
lines changed

4 files changed

+133
-4
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
QuantizeFullArgument,
3838
RetraceFoldedDtypesPass,
3939
)
40+
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
41+
FuseQuantizedActivationPass,
42+
)
4043
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
4144
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
4245
KeepDimsFalseToSqueezePass,
@@ -73,6 +76,7 @@ def transform_to_backend_pipeline(
7376
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
7477
):
7578
"""Apply passes before transforming program to backend"""
79+
self.add_pass(FuseQuantizedActivationPass())
7680
self.add_pass(DecomposeLinearPass())
7781
self.add_pass(RemoveGetItemPass())
7882
self.add_pass(DecomposeLayerNormPass())
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm.tosa_quant_utils import q_op
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
from torch.fx import Node
11+
12+
13+
class FuseQuantizedActivationPass(ExportPass):
    """Remove relu / zero-floored hardtanh nodes that directly follow a conv or linear.

    An activation is removed only when its single user is a quantize op whose
    zero-point equals its qmin and when its producer is a convolution/linear
    node with no other users. The activation's users are rewired to the
    producer, so the quantize op consumes the producer's output directly.
    """

    def _is_fuseable_quantized_activation(self, node: Node):
        """Return True if `node` is an activation with a 0 lower bound that is
        quantized with a zero-point equal to qmin.

        Accepted activations are edge `relu` and edge `hardtanh` whose
        `min_val` is 0. The node must have exactly one user, and that user
        must be `q_op`.
        """
        if node.target == exir_ops.edge.aten.relu.default:
            is_fuseable = True
        elif node.target == exir_ops.edge.aten.hardtanh.default:
            # min_val is the second positional argument. If it was not passed
            # positionally, fall back to the keyword form and finally to the
            # operator's documented default (-1), instead of raising IndexError.
            if len(node.args) > 1:
                min_val = node.args[1]
            else:
                min_val = node.kwargs.get("min_val", -1.0)
            is_fuseable = min_val == 0
        else:
            is_fuseable = False

        if not is_fuseable or len(node.users) != 1:
            return False

        quant_node = next(iter(node.users))
        if quant_node.target != q_op:
            return False

        # Zero-point and qmin are read from positional arguments 2 and 3 of
        # the quantize node.
        zp = quant_node.args[2]
        qmin = quant_node.args[3]
        return zp == qmin

    def _is_fuseable_input(self, node: Node):
        """Return True if `node` is a convolution/linear node with exactly one user."""
        return (
            # The activation's first argument is not guaranteed to be a Node.
            isinstance(node, Node)
            and node.target
            in (
                exir_ops.edge.aten.convolution.default,
                exir_ops.edge.aten.linear.default,
            )
            and len(node.users) == 1
        )

    def call(self, graph_module: torch.fx.GraphModule):
        """Erase every fuseable activation and return the resulting PassResult.

        The graph module is recompiled and re-run through `ExportPass.call`
        only when at least one node was removed.
        """
        modified = False
        # Iterate over a snapshot: nodes are erased from the graph in the loop.
        for node in list(graph_module.graph.nodes):
            if node.op != "call_function":
                continue

            if not self._is_fuseable_quantized_activation(node):
                continue

            input_node = node.args[0]
            if not self._is_fuseable_input(input_node):
                continue

            # Rewire the quantize op to the conv/linear output, then drop the
            # now-unused activation.
            node.replace_all_uses_with(input_node)
            graph_module.graph.erase_node(node)
            modified = True

        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,41 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
8989
_annotate_output_qspec(node, quant_property.qspec)
9090

9191

92+
def _match_pattern(
93+
node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
94+
) -> bool:
95+
"""
96+
Check if there's a chain of node.ancestors? -> node -> node.descendant? that matches the
97+
chain provided in 'pattern'. If 'filter_fn' is provided, check that all the nodes in the
98+
chain pass the filtering.
99+
100+
Each 'pattern' element is composed of a list of disjunctive nodes types.
101+
"""
102+
assert len(pattern) == 2, "Only two-nodes patterns supported currently"
103+
104+
if node.target in pattern[0]:
105+
assert len(node.users) != 0
106+
parent = node
107+
child = next(iter(node.users))
108+
elif node.target in pattern[1]:
109+
assert len(node.args) != 0
110+
parent = node.args[0]
111+
child = node
112+
else:
113+
return False
114+
115+
if len(parent.users) != 1:
116+
return False
117+
118+
if parent.target not in pattern[0] or child.target not in pattern[1]:
119+
return False
120+
121+
if filter_fn is not None:
122+
return filter_fn(parent) and filter_fn(child)
123+
124+
return True
125+
126+
92127
_one_to_one = [
93128
torch.ops.aten.exp.default,
94129
torch.ops.aten.log.default,
@@ -164,7 +199,36 @@ def get_quant_properties( # noqa: C901
164199
bias_qspec = quantization_config.get_bias_qspec()
165200

166201
quant_properties = _OpQuantProperties()
167-
if node.target in (
202+
203+
def any_or_hardtanh_min_zero(n: Node):
204+
# Check that if the node is a hardtanh, its min_val is zero
205+
return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
206+
207+
if _match_pattern(
208+
node,
209+
[
210+
[
211+
torch.ops.aten.conv1d.default,
212+
torch.ops.aten.conv2d.default,
213+
torch.ops.aten.linear.default,
214+
],
215+
[torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
216+
],
217+
any_or_hardtanh_min_zero,
218+
):
219+
if node.target in (
220+
torch.ops.aten.conv1d.default,
221+
torch.ops.aten.conv2d.default,
222+
torch.ops.aten.linear.default,
223+
):
224+
quant_properties.quant_inputs = [
225+
_QuantProperty(0, input_act_qspec),
226+
_QuantProperty(1, weight_qspec, mark_annotated=True),
227+
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
228+
]
229+
else:
230+
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
231+
elif node.target in (
168232
torch.ops.aten.conv1d.default,
169233
torch.ops.aten.conv2d.default,
170234
torch.ops.aten.linear.default,

backends/arm/test/ops/test_conv_combos.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,11 @@ class ComboConvRelu6(torch.nn.Module):
137137
]
138138

139139
test_data = [
140-
(20 * torch.randn(1, 3, 256, 256),),
141-
(5 * torch.randn(1, 3, 256, 256),),
140+
(2 * torch.randn(1, 3, 256, 256),),
141+
(0.5 * torch.randn(1, 3, 256, 256),),
142142
(torch.randn(1, 3, 256, 256),),
143-
(-5 * torch.randn(1, 3, 256, 256),),
143+
(-0.5 * torch.randn(1, 3, 256, 256),),
144+
(-2 * torch.randn(1, 3, 256, 256),),
144145
]
145146

146147
def __init__(self):

0 commit comments

Comments
 (0)