
Commit 1c39228

Add pass for replacing dq-q patterns with rescale
When an int8 op meets an int32 op, the int8 tensor is first dequantized to a float, which is then quantized to the desired int32 dtype. This produces an (int8 dq -> q int32) pattern that we can replace with a TOSA.RESCALE, since the two are approximately mathematically equivalent, differing only in how the rounding is done. This requires a few changes:
- Introduce a custom rescale op
- Create a pass to replace the dq-q pattern with the rescale op
- Implement a node_visitor for the rescale op

The change makes it possible to mix int8 and int32 quantization, as showcased in the new test_add_i32_tosa_BI test.

Signed-off-by: Erik Lundell <[email protected]>
Change-Id: Ifd475ade488fbbeb8395ac883986b19f8edfae5a
1 parent 8d96d74 commit 1c39228
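As a numeric illustration of the equivalence described in the commit message (the values and variable names below are hypothetical, not taken from the commit): dequantize computes (x - in_zp) * in_scale and quantize computes round(y / out_scale) + out_zp, so chaining them collapses into a single affine rescale with combined scale in_scale / out_scale.

# Illustrative sketch only; clamping to the output dtype range is omitted.
import torch

in_scale, in_zp = 0.05, 3        # hypothetical int8 qparams of the dq node
out_scale, out_zp = 2**-16, 0    # hypothetical int32 qparams of the q node (zp must be 0)

x_int8 = torch.tensor([-128, -1, 0, 7, 127], dtype=torch.int8)

# Reference path: dequantize to float, then quantize to int32.
dq = (x_int8.to(torch.float32) - in_zp) * in_scale
q_ref = torch.round(dq / out_scale + out_zp).to(torch.int32)

# Rescale path: one affine op with the combined scale, as InsertRescalePass builds it.
new_scale = in_scale / out_scale
q_rescale = torch.round((x_int8.to(torch.int32) - in_zp) * new_scale + out_zp).to(torch.int32)

print(torch.equal(q_ref, q_rescale))  # True for these values; a TOSA RESCALE may differ only in rounding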

7 files changed: +395 -1 lines changed


backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 1 deletion
@@ -49,6 +49,7 @@
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
+from executorch.backends.arm._passes.insert_rescales_pass import InsertRescalePass
 from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
@@ -72,6 +73,7 @@
     UnsqueezeScalarPlaceholdersPass,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
+
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
@@ -115,7 +117,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(AnnotateChannelsLastDimOrder())
-
+        self.add_pass(InsertRescalePass())
         return self._transform(exported_program.graph_module)
 
     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
@@ -153,6 +155,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(AnnotateChannelsLastDimOrder())
+        self.add_pass(InsertRescalePass())
 
         return self._transform(exported_program.graph_module)

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 3 additions & 0 deletions
@@ -131,6 +131,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n = cast(Node, n)
             if n.op != "call_function":
                 continue
+            # Don't fold chains of quant-ops into each other.
+            if n.target in (q_op, dq_op):
+                continue
 
             # Make sure we haven't already set qparams meta information on the node
             assert "input_qparams" not in n.meta.keys()
backends/arm/_passes/insert_rescales_pass.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from copy import copy
+from typing import cast
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch import Tensor
+from torch.fx import GraphModule, Node
+from torch.library import custom_op, register_fake
+
+logger = logging.getLogger(__name__)
+
+
+@custom_op("tosa::_rescale", mutates_args=())  # type: ignore[misc]
+def rescale(
+    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
+) -> Tensor:
+    logger.warning(
+        "Ran default implementation of tosa::_rescale. "
+        "This op is meant to always be inserted inside a partition and a correct default implementation is not implemented."
+    )
+    # Clone is needed to not return a reference when rescaling to the same dtype.
+    # This is a necessary requirement for non-mutating custom ops.
+    return x.to(dtype=dtype).clone()
+
+
+@register_fake("tosa::_rescale")  # type: ignore[misc]
+def rescale_fake(
+    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
+) -> Tensor:
+    """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
+    Additionally validates TOSA constraints of a RESCALE op.
+    """
+    if not (dtype == torch.int32 or dtype == torch.int8):
+        raise NotImplementedError(
+            "tosa::rescale currently only supports int32 and int8."
+        )
+    if dtype == torch.int32 and out_zp != 0:
+        raise ValueError(
+            "TOSA requires output_zp to be zero when the output dtype is int32."
+        )
+    if x.dtype == torch.int32 and in_zp != 0:
+        raise ValueError(
+            "TOSA requires input_zp to be zero when the input dtype is int32."
+        )
+    if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
+        raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
+    if dtype == torch.int8 and not -128 <= out_zp <= 127:
+        raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.")
+
+    return x.to(dtype=dtype).clone()
+
+
+class InsertRescalePass(ExportPass):
+    """Finds patterns of dq -> q, and replaces them
+    with passthrough_to_tosa::rescales.
+
+    Does not guarantee that the dtypes and zero points are valid
+    in TOSA, that is the job of the quantization annotator that
+    produced the dq and q nodes. The TOSA constraints are validated
+    in the fake implementation of passthrough_to_tosa:rescale.
+    """
+
+    def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
+        dq_args = QuantArgs.from_operator(node.target, node.args)
+        q_args = QuantArgs.from_operator(user.target, user.args)
+        new_scale = dq_args.scale / q_args.scale
+
+        with graph_module.graph.inserting_before(node):
+            rescale_node = create_node(
+                graph_module.graph,
+                torch.ops.tosa._rescale.default,
+                (
+                    node.all_input_nodes[0],
+                    q_args.dtype,
+                    new_scale,
+                    dq_args.zp,
+                    q_args.zp,
+                ),
+            )
+            rescale_node.meta = copy(user.meta)
+            user.replace_all_uses_with(rescale_node)
+            graph_module.graph.erase_node(user)
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        modified = False
+        for node in graph_module.graph.nodes:
+            node = cast(Node, node)
+
+            if node.target is not dq_op:
+                continue
+            # Copy users since we remove them while iterating, modifying the node.users list.
+            for user in copy(node.users):
+                if user.target is q_op:
+                    self.fold_dq_q_to_rescale(node, user, graph_module)
+                    modified = True
+            if len(node.users) == 0:
+                graph_module.graph.erase_node(node)
+
+        graph_module = super().call(graph_module).graph_module
+        graph_module.recompile()
+        return PassResult(graph_module, modified)
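
For readers unfamiliar with the torch.library decorators used in this new file, the following self-contained toy example (a hypothetical demo::_cast op, not part of the commit) shows the same registration pattern: the @custom_op body is the eager fallback implementation, while the @register_fake body only produces output metadata so that torch.export can trace through the op without executing it.

import torch
from torch import Tensor
from torch.library import custom_op, register_fake


@custom_op("demo::_cast", mutates_args=())
def _cast(x: Tensor, dtype: torch.dtype) -> Tensor:
    # Eager fallback: actually perform the cast; clone so the non-mutating
    # custom op never returns an alias of its input.
    return x.to(dtype=dtype).clone()


@register_fake("demo::_cast")
def _cast_fake(x: Tensor, dtype: torch.dtype) -> Tensor:
    # Fake/meta implementation: only shape and dtype matter here, so tracing
    # can propagate correct tensor metadata without running the op.
    return torch.empty_like(x, dtype=dtype)


class CastToInt32(torch.nn.Module):
    def forward(self, x):
        return torch.ops.demo._cast(x, torch.int32)


exported = torch.export.export(CastToInt32(), (torch.randn(4),))
print([node.target for node in exported.graph.nodes])  # contains demo._cast.default

InsertRescalePass relies on the same mechanism: the fake implementation of tosa::_rescale is what validates the TOSA constraints during export, while the eager implementation only exists as a logged fallback.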

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
     op_reciprocal,
     op_relu,
     op_repeat,
+    op_rescale,
     op_rshift,
     op_rsqrt,
     op_sigmoid,

backends/arm/operators/op_rescale.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import cast, List
+
+import executorch.backends.arm.tosa_quant_utils as tosa_quant_utils
+import serializer.tosa_serializer as ts  # type: ignore
+import torch
+
+import tosa.Op as TosaOp  # type: ignore
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
+from torch.fx import Node
+
+
+@register_node_visitor
+class RescaleVisitor(NodeVisitor):
+    target = "_rescale.default"
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+
+        input_dtype = inputs[0].dtype
+        output_dtype = cast(torch.dtype, node.args[1])
+        scale = cast(float, node.args[2])
+        input_zp = cast(int, node.args[3])
+        output_zp = cast(int, node.args[4])
+
+        # Skip int16 cases for now.
+        if input_dtype != map_dtype(torch.int8) and input_zp != 0:
+            raise ValueError(
+                f"If input dtype is not int8, input_zp must be 0. Got input_dtype={ts.DTypeNames[input_dtype]}, {input_zp=}"
+            )
+        if output_dtype != torch.int8 and output_zp != 0:
+            raise ValueError(
+                f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
+            )
+
+        scale_width = 32 if output_dtype == torch.int32 else 16
+        multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
+            scale, scale_width
+        )
+        attr_rescale = ts.TosaSerializerAttribute()
+        attr_rescale.RescaleAttribute(
+            input_zp=input_zp,
+            output_zp=output_zp,
+            multiplier=[multiplier],
+            shift=[shift],
+            scale32=output_dtype == torch.int32,
+            double_round=False,
+            per_channel=False,
+            input_unsigned=False,
+            output_unsigned=False,
+        )
+
+        tosa_graph.addOperator(
+            TosaOp.Op().RESCALE, [inputs[0].name], [output.name], attr_rescale
+        )
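
The RESCALE attribute encodes the float scale as a fixed-point multiplier and shift via the backend's compute_multiplier_and_shift helper. The sketch below shows the general idea of that encoding (an approximation for illustration, not the actual ExecuTorch implementation): the scale is represented as multiplier * 2**(-shift), with the multiplier normalized to use most of the available scale_width bits.

import math

def approx_multiplier_and_shift(scale: float, scale_width: int = 32):
    # Illustrative approximation: encode `scale` as multiplier * 2**(-shift),
    # with the multiplier normalized into [2**(scale_width - 2), 2**(scale_width - 1)).
    mantissa, exponent = math.frexp(scale)   # scale = mantissa * 2**exponent, mantissa in [0.5, 1)
    shift = scale_width - 1 - exponent       # pick the shift that fills the available bits
    multiplier = round(mantissa * (1 << (scale_width - 1)))
    return multiplier, shift

multiplier, shift = approx_multiplier_and_shift(0.05, 32)
print(multiplier, shift, multiplier / (1 << shift))  # multiplier * 2**(-shift) ~= 0.05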

backends/arm/test/ops/test_add.py

Lines changed: 38 additions & 0 deletions
@@ -9,13 +9,19 @@
 from typing import Tuple
 
 import torch
+from executorch.backends.arm.arm_backend import get_tosa_version
+from executorch.backends.arm.quantizer import arm_quantizer
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineBI,
     EthosU85PipelineBI,
     TosaPipelineBI,
     TosaPipelineMI,
 )
+from executorch.backends.xnnpack.test.tester import Quantize
+from torch.ao.quantization.observer import HistogramObserver
+from torch.ao.quantization.quantizer import QuantizationSpec
+
 
 aten_op = "torch.ops.aten.add.Tensor"
 exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
@@ -67,6 +73,38 @@ def test_add_tosa_BI(test_data: input_t1):
     pipeline.run()
 
 
+@common.parametrize("test_data", Add.test_data)
+def test_add_i32_tosa_BI(test_data: input_t1):
+    pipeline = TosaPipelineBI[input_t1](Add(), test_data, aten_op, exir_op)
+
+    # Create a quantizer with int8 quantization on the input and output but int32 on everything else.
+    quantizer = arm_quantizer.ArmQuantizer(
+        get_tosa_version(common.get_tosa_compile_spec("TOSA-0.80+BI"))
+    )
+    quantizer.set_io(arm_quantizer.get_symmetric_quantization_config())
+    observer_options = {"eps": 2**-16}
+    observer = HistogramObserver.with_args(**observer_options)
+    input_act_qspec = QuantizationSpec(
+        torch.int32,
+        observer,
+        qscheme=torch.per_tensor_symmetric,
+        quant_max=2**31 - 1,
+        quant_min=-(2**31),
+    )
+    # This quantization_config will be set as the global config.
+    quantization_config = arm_quantizer.QuantizationConfig(
+        input_act_qspec, None, None, None
+    )
+    quantize_stage = Quantize(quantizer, quantization_config)
+    pipeline.change_args("quantize", quantize_stage)
+
+    # Check that we get the additional (dq -> q) ops after export.
+    pipeline.add_stage_after(
+        "export", pipeline.tester.check_count, {"torch.ops.quantized_decomposed": 8}
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_data", Add.test_data)
 def test_add_u55_BI(test_data: input_t1):
     pipeline = EthosU55PipelineBI[input_t1](
