Commit fd9eb28

Add helper functions for Q/DQ folding pass
Adds helper functions to retrieve QuantArgs from node.meta and cleans up the handling a bit by introducing the __eq__ operator for QuantArgs.

Signed-off-by: Per Åstrand <[email protected]>
Change-Id: I519a9a286a36a278f40ffb6c679192a54d9f940d
1 parent 2d39f78 commit fd9eb28

5 files changed, +77 -45 lines changed
backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 26 additions & 0 deletions
@@ -16,6 +16,32 @@
 from torch.fx import GraphModule, Node
 
 
+def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
+    """
+    Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
+    Raises a ValueError if the node doesn't have any parameters set.
+    """
+    if "input_qparams" not in node.meta.keys():
+        raise ValueError(f"No input quantization parameter found in node {node}")
+    input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
+    if len(input_qparams) == 0:
+        raise ValueError(f"No input quantization parameter found in node {node}")
+    return input_qparams
+
+
+def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
+    """
+    Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
+    Raises a ValueError if the node doesn't have any parameters set.
+    """
+    if "output_qparams" not in node.meta.keys():
+        raise ValueError(f"No output quantization parameter found in node {node}")
+    output_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
+    if len(output_qparams) == 0:
+        raise ValueError(f"No output quantization parameter found in node {node}")
+    return output_qparams
+
+
 class FoldAndAnnotateQParamsPass(ExportPass):
     """
     A pass that walks the graph and removes any DQ and Q nodes before and after the target
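
As a usage illustration (not part of this commit), a downstream node visitor could consume the new helpers instead of reaching into node.meta directly. summarize_qparams is a hypothetical function; only the two imported helpers come from the file above.

from torch.fx import Node

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
    get_output_qparams,
)


def summarize_qparams(node: Node) -> str:
    # Both helpers return dict[int, QuantArgs] keyed by input/output index and
    # raise ValueError when FoldAndAnnotateQParamsPass has not annotated the node.
    try:
        in_q = get_input_qparams(node)
        out_q = get_output_qparams(node)
    except ValueError:
        return f"{node}: no folded Q/DQ annotations"
    return f"{node}: {len(in_q)} quantized input(s), {len(out_q)} quantized output(s)"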

backends/arm/operators/op_add.py

Lines changed: 1 addition & 3 deletions
@@ -76,9 +76,7 @@ def define_node(
         if output.dtype == ts.DType.INT8:
             # Scale output back to 8 bit
             # pyre-ignore
-            tqutils.insert_rescale_node_back_to_int8(
-                tosa_graph, add_output, scale_back, node
-            )
+            tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)
 
 
 @register_node_visitor

backends/arm/operators/op_max.py

Lines changed: 14 additions & 21 deletions
@@ -5,11 +5,13 @@
 
 # pyre-unsafe
 
-from typing import cast, List
+from typing import List
 
 import executorch.backends.arm.tosa_quant_utils as tqutils
-
 import serializer.tosa_serializer as ts
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_input_qparams,
+)
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -38,30 +40,23 @@ def define_node(
     ) -> None:
         assert inputs[0].dtype == inputs[1].dtype
 
-        input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"])
-        min_output = output
-
+        max_output = output
         if inputs[0].dtype == ts.DType.INT8:
-            # insert RESCALEs to int32
-            x_scale = input_qparams[0].scale
-            x_zp = input_qparams[0].zp
-
-            y_scale = input_qparams[1].scale
-            y_zp = input_qparams[1].zp
-
+            input_qparams = get_input_qparams(node)
             assert (
-                x_zp == y_zp
-            ), "Different zp for inputs, MAX should be quantized with shared quantization!"
+                len(input_qparams) == 2
+            ), f"Both inputs need to have quantization information for {node}"
+            # insert RESCALEs to int32
             assert (
-                x_scale == y_scale
-            ), "Different scale for input, MAX should be quantized with shared quantization!"
+                input_qparams[0] == input_qparams[1]
+            ), "Both inputs must have the same quantization for MAX"
 
             operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                 tosa_graph, inputs, node
             )
 
             output.shape = tosa_shape(output.shape, output.dim_order)
-            min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
+            max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
         else:
             operand_inputs = inputs
@@ -71,11 +66,9 @@ def define_node(
                 operand_inputs[0].name,
                 operand_inputs[1].name,
             ],
-            [min_output.name],
+            [max_output.name],
         )
 
         if output.dtype == ts.DType.INT8:
             # insert RESCALE from int32 back to int8
-            tqutils.insert_rescale_node_back_to_int8(
-                tosa_graph, min_output, scale_back, node
-            )
+            tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node)
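
For intuition, here is a minimal numeric sketch (plain Python, independent of the TOSA serializer API) of the int8 MAX lowering above: subtract the shared zero point to move into a common int32 domain, take the elementwise maximum, then rescale to the output quantization. The values are invented, and the real pass folds these scale factors into RESCALE operators rather than computing in floating point.

scale, zp = 0.05, -3           # shared input quantization; the asserts above require this
out_scale, out_zp = 0.05, -3   # output quantization, kept identical for simplicity

x_q, y_q = [17, -40], [5, 60]                    # two int8-quantized operands
x32 = [q - zp for q in x_q]                      # rescale to int32: drop the zero point
y32 = [q - zp for q in y_q]
m32 = [max(a, b) for a, b in zip(x32, y32)]      # MAXIMUM in the int32 domain
out_q = [round(v * scale / out_scale) + out_zp for v in m32]  # rescale back to int8
assert out_q == [max(a, b) for a, b in zip(x_q, y_q)]  # elementwise max preserved: [17, 60]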

backends/arm/operators/op_min.py

Lines changed: 11 additions & 17 deletions
@@ -5,11 +5,14 @@
 
 # pyre-unsafe
 
-from typing import cast, List
+from typing import List
 
 import executorch.backends.arm.tosa_quant_utils as tqutils
 
 import serializer.tosa_serializer as ts
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_input_qparams,
+)
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -38,23 +41,16 @@ def define_node(
     ) -> None:
         assert inputs[0].dtype == inputs[1].dtype
 
-        input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"])
         min_output = output
-
        if inputs[0].dtype == ts.DType.INT8:
-            # insert RESCALEs to int32
-            x_scale = input_qparams[0].scale
-            x_zp = input_qparams[0].zp
-
-            y_scale = input_qparams[1].scale
-            y_zp = input_qparams[1].zp
-
+            input_qparams = get_input_qparams(node)
             assert (
-                x_zp == y_zp
-            ), "Different zp for inputs, MIN should be quantized with shared quantization!"
+                len(input_qparams) == 2
+            ), f"Both inputs need to have quantization information for {node}"
+            # insert RESCALEs to int32
             assert (
-                x_scale == y_scale
-            ), "Different scale for input, MIN should be quantized with shared quantization!"
+                input_qparams[0] == input_qparams[1]
+            ), "Both inputs must have the same quantization for MIN"
 
             operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                 tosa_graph, inputs, node
@@ -76,6 +72,4 @@ def define_node(
 
         if output.dtype == ts.DType.INT8:
             # insert RESCALE from int32 back to int8
-            tqutils.insert_rescale_node_back_to_int8(
-                tosa_graph, min_output, scale_back, node
-            )
+            tqutils.insert_rescale_op_to_int8(tosa_graph, min_output, scale_back, node)

backends/arm/tosa_quant_utils.py

Lines changed: 25 additions & 4 deletions
@@ -57,14 +57,19 @@ def insert_rescale_ops_to_int32(
     the graph upstream for DQ nodes.
     """
 
+    from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+        get_input_qparams,
+    )
+
     tensors = inputs.copy()
 
     # Reshape tensor according to TOSA dim order
     for tensor in tensors:
         dim_order = tensor.dim_order
         tensor.shape = [tensor.shape[i] for i in dim_order]
 
-    qargs = list(cast(dict[int, QuantArgs], node.meta["input_qparams"]).values())
+    input_qparams = get_input_qparams(node)
+    qargs = input_qparams.values()
 
     # Scale the int8 quantized input to a common scale in the integer
     # domain
@@ -84,7 +89,7 @@ def insert_rescale_ops_to_int32(
     return rescaled_nodes, min_scale
 
 
-def insert_rescale_node_back_to_int8(
+def insert_rescale_op_to_int8(
     tosa_graph: ts.TosaSerializer,
     last_tensor: TosaArg,
     scale: float,
@@ -102,9 +107,14 @@
     in the node meta dict as opposed to 'rescale_node_back_to_int8' which searches
     the graph downstream for Q nodes.
     """
-    assert len(node.meta["output_qparams"]) == 1
+    from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+        get_output_qparams,
+    )
+
+    output_qparams = get_output_qparams(node)
+    assert len(output_qparams) == 1, "More than one output not supported"
 
-    qargs_out = cast(dict[int, QuantArgs], node.meta["output_qparams"])[0]
+    qargs_out = output_qparams[0]
     output_rescale_scale = scale / qargs_out.scale
 
     # Rescale Back to INT8
@@ -136,6 +146,17 @@ def quantize_value(self, x):
     def dequantize_value(self, qx: int) -> float:
         return (qx - self.zp) * self.scale
 
+    def __eq__(self, other):
+        if isinstance(other, QuantArgs):
+            return (
+                self.scale == other.scale
+                and self.zp == other.zp
+                and self.qmin == other.qmin
+                and self.qmax == other.qmax
+                and self.dtype == other.dtype
+            )
+        return False
+
     @classmethod
     def from_operator(cls, op, args):
         if op in dq_q_ops:
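
To make the new __eq__ semantics concrete, here is a self-contained stand-in for QuantArgs (the real class lives in tosa_quant_utils.py; the quantize_value body is an assumption, since only its signature appears in the diff, while dequantize_value mirrors the diff exactly).

class QuantArgsSketch:
    """Stand-in with the same fields and equality semantics as QuantArgs."""

    def __init__(self, scale: float, zp: int, qmin: int, qmax: int, dtype: str):
        self.scale, self.zp = scale, zp
        self.qmin, self.qmax = qmin, qmax
        self.dtype = dtype  # a torch dtype upstream; a string suffices here

    def __eq__(self, other):
        if isinstance(other, QuantArgsSketch):
            return (self.scale, self.zp, self.qmin, self.qmax, self.dtype) == (
                other.scale, other.zp, other.qmin, other.qmax, other.dtype
            )
        return False

    def quantize_value(self, x: float) -> int:
        # Assumed implementation: round, shift by the zero point, clamp to [qmin, qmax].
        return min(self.qmax, max(self.qmin, round(x / self.scale) + self.zp))

    def dequantize_value(self, qx: int) -> float:
        return (qx - self.zp) * self.scale


a = QuantArgsSketch(0.05, -3, -128, 127, "int8")
b = QuantArgsSketch(0.05, -3, -128, 127, "int8")
c = QuantArgsSketch(0.10, 0, -128, 127, "int8")
assert a == b   # identical params: the MIN/MAX shared-quantization assert passes
assert a != c   # different scale/zp: the assert would fire
assert a.dequantize_value(a.quantize_value(1.0)) == 1.0  # exact round trip at 1.0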
