Commit 843023a

Introduce a quantization folding pass with annotations
Fold DQ/Q nodes into the target operators specified to the pass.

Signed-off-by: Per Åstrand <[email protected]>
Change-Id: I8a09dc0b887dd5f3915ca157f578ecf51772a1a2
1 parent a7632f3

1 file changed: 105 additions, 0 deletions
@@ -0,0 +1,105 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

from typing import Callable, cast, Iterable

from executorch.backends.arm.tosa_quant_utils import QuantArgs

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node

class FoldAndAnnotateQParamsPass(ExportPass):
    """
    A pass that walks the graph and removes any DQ and Q nodes before and after
    the target node in the supplied list of operators.
    The quantization parameters from the DQ/Q nodes are stored as meta values to
    be accessible for later lowering and serialization passes.
    The assumption is that the quantization annotation adds DQ nodes for all
    tensor inputs to the target node and one Q node to the output.

    Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):

        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

        x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
        aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
        aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)

        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)

    Becomes:

        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)

        aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)

        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)

    The quantization parameters for x_dq and aten_add_tensor_q are stored in
    meta on the aten_add_tensor node.
    """

    def __init__(self, targeted_ops: Iterable[Callable]):
        super().__init__()
        self.targeted_ops = targeted_ops

    def call(self, graph_module: GraphModule) -> PassResult:
        q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
        dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default

        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue

            # Make sure we haven't already set qparams meta information on the node.
            assert "input_qparams" not in n.meta.keys()
            assert "output_qparams" not in n.meta.keys()

            # For the inputs and outputs, search the graph for quantization info and
            # store the information in a dict keyed by the order of the _tensor_
            # inputs, ignoring any other arguments to the target node.
            n.meta["input_qparams"] = {}
            n.meta["output_qparams"] = {}
            for i, arg in enumerate(n.args):
                if not isinstance(arg, Node):
                    continue
                if arg.target != dq_op:
                    continue

                # arg.target for argument i is a dequant node, extract the information.
                n.meta["input_qparams"][i] = QuantArgs.from_operator(
                    arg.target, arg.args
                )

                # arg.args[0] is the tensor input, replace the input usage.
                n.replace_input_with(arg, arg.args[0])
                graph_module.graph.erase_node(arg)

            # Copy the users, since we are modifying it.
            users_copy = copy.copy(n.users)
            for i, user in enumerate(users_copy):
                if user.target != q_op:
                    continue

                # A quantization node was found; store its quantization
                # parameters in the meta value.
                n.meta["output_qparams"][i] = QuantArgs.from_operator(
                    user.target, user.args
                )

                user.replace_all_uses_with(n)
                graph_module.graph.erase_node(user)

        # Retrace the graph to update the fake tensor types.
        graph_module = super().call(graph_module).graph_module

        graph_module.recompile()
        return PassResult(graph_module, True)
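
For context, a minimal usage sketch (not part of this commit). The import is hypothetical since the diff does not show where the new file lands in the tree, and edge_graph_module stands in for an already-annotated, quantized edge-dialect GraphModule:

from executorch.exir.dialects._ops import ops as exir_ops

# Hypothetical import; the file's path is not shown in this diff.
# from executorch.backends.arm.<new_module> import FoldAndAnnotateQParamsPass

# `edge_graph_module` is assumed: an edge-dialect GraphModule in which the
# quantization annotation step has inserted DQ/Q pairs around each add node.
targeted_ops = [exir_ops.edge.aten.add.Tensor]
result = FoldAndAnnotateQParamsPass(targeted_ops).call(edge_graph_module)

# After folding, the quantization parameters live on the target nodes:
for node in result.graph_module.graph.nodes:
    if node.op == "call_function" and node.target in targeted_ops:
        print(node.meta["input_qparams"], node.meta["output_qparams"])

Downstream lowering and serialization passes can then read input_qparams/output_qparams from node meta instead of pattern-matching DQ/Q chains in the graph.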
