Add pass for replacing dq-q patterns with rescale #8415

Merged: 1 commit, Feb 13, 2025

5 changes: 4 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
@@ -49,6 +49,7 @@
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_rescales_pass import InsertRescalePass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
@@ -72,6 +73,7 @@
UnsqueezeScalarPlaceholdersPass,
)
from executorch.backends.arm.tosa_specification import TosaSpecification

from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.pass_manager import PassManager
@@ -115,7 +117,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ConvertSqueezesToViewPass())

self.add_pass(AnnotateChannelsLastDimOrder())

self.add_pass(InsertRescalePass())
return self._transform(exported_program.graph_module)

def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
@@ -153,6 +155,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ConvertSqueezesToViewPass())

self.add_pass(AnnotateChannelsLastDimOrder())
self.add_pass(InsertRescalePass())

return self._transform(exported_program.graph_module)

@@ -131,6 +131,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
n = cast(Node, n)
if n.op != "call_function":
continue
# Don't fold chains of quant-ops into each other.
if n.target in (q_op, dq_op):
continue

# Make sure we haven't already set qparams meta information on the node
assert "input_qparams" not in n.meta.keys()
109 changes: 109 additions & 0 deletions backends/arm/_passes/insert_rescales_pass.py
@@ -0,0 +1,109 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from copy import copy
from typing import cast

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
from executorch.exir.pass_base import ExportPass, PassResult
from torch import Tensor
from torch.fx import GraphModule, Node
from torch.library import custom_op, register_fake

logger = logging.getLogger(__name__)


@custom_op("tosa::_rescale", mutates_args=()) # type: ignore[misc]
def rescale(
x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
) -> Tensor:
logger.warning(
"Ran default implementation of tosa::_rescale."
"This op is meant to always be inserted inside a partition and a correct default implementation is not implemented."
)
# Clone is needed to not return a reference when rescaling to the same dtype.
# This is a necessary requirement for non-mutating custom ops.
return x.to(dtype=dtype).clone()


@register_fake("tosa::_rescale") # type: ignore[misc]
def rescale_fake(
x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
) -> Tensor:
"""Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
Additionally validates TOSA constraints of a RESCALE op.
"""
if not (dtype == torch.int32 or dtype == torch.int8):
raise NotImplementedError(
"tosa::rescale currently only supports int32 and int8."
)
if dtype == torch.int32 and out_zp != 0:
raise ValueError(
"TOSA requires output_zp to be zero when the output dtype is int32."
)
if x.dtype == torch.int32 and in_zp != 0:
raise ValueError(
"TOSA requires input_zp to be zero when the input dtype is int32."
)
if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
if dtype == torch.int8 and not -128 <= out_zp <= 127:
raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.")

return x.to(dtype=dtype).clone()


class InsertRescalePass(ExportPass):
"""Finds patterns of dq -> q, and replaces them
with passthrough_to_tosa::rescales.

Does not garantuee that the dtypes and zero points are valid
in TOSA, that is the job of the quantization annotator that
produced the dq and q nodes. The TOSA constraints are validated
in the fake implementation of passthrough_to_tosa:rescale.
"""

def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
dq_args = QuantArgs.from_operator(node.target, node.args)
q_args = QuantArgs.from_operator(user.target, user.args)
new_scale = dq_args.scale / q_args.scale

with graph_module.graph.inserting_before(node):
rescale_node = create_node(
graph_module.graph,
torch.ops.tosa._rescale.default,
(
node.all_input_nodes[0],
q_args.dtype,
new_scale,
dq_args.zp,
q_args.zp,
),
)
rescale_node.meta = copy(user.meta)
user.replace_all_uses_with(rescale_node)
graph_module.graph.erase_node(user)

def call(self, graph_module: GraphModule) -> PassResult:
modified = False
for node in graph_module.graph.nodes:
node = cast(Node, node)

if node.target is not dq_op:
continue
# Copy users since we remove them while iterating, modifying the node.users list.
for user in copy(node.users):
if user.target is q_op:
self.fold_dq_q_to_rescale(node, user, graph_module)
modified = True
if len(node.users) == 0:
graph_module.graph.erase_node(node)

graph_module = super().call(graph_module).graph_module
graph_module.recompile()
return PassResult(graph_module, modified)
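
For intuition (not part of the PR), here is a minimal sketch of why a dq -> q pair can be folded into a single rescale whose scale is dq_scale / q_scale, mirroring the new_scale computed in fold_dq_q_to_rescale above. The reference functions below are illustrative floating-point Python, not the TOSA fixed-point implementation, and their names are made up for this example.

```python
import torch

def dq_then_q(x_int8, dq_scale, dq_zp, q_scale, q_zp):
    # Reference behaviour of the dequantize -> quantize pair the pass removes.
    x_fp = (x_int8.to(torch.float32) - dq_zp) * dq_scale
    y = torch.round(x_fp / q_scale) + q_zp
    return torch.clamp(y, -128, 127).to(torch.int8)

def rescale_ref(x_int8, scale, in_zp, out_zp):
    # Floating-point reference of a single rescale with scale = dq_scale / q_scale.
    y = torch.round((x_int8.to(torch.float32) - in_zp) * scale) + out_zp
    return torch.clamp(y, -128, 127).to(torch.int8)

# With scales chosen so dq_scale / q_scale is exact in float32, both paths agree bit-for-bit.
x = torch.randint(-128, 128, (16,), dtype=torch.int8)
a = dq_then_q(x, dq_scale=0.5, dq_zp=3, q_scale=0.25, q_zp=-1)
b = rescale_ref(x, scale=0.5 / 0.25, in_zp=3, out_zp=-1)
assert torch.equal(a, b)
```
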
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -31,6 +31,7 @@
op_reciprocal,
op_relu,
op_repeat,
op_rescale,
op_rshift,
op_rsqrt,
op_sigmoid,
70 changes: 70 additions & 0 deletions backends/arm/operators/op_rescale.py
@@ -0,0 +1,70 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, List

import executorch.backends.arm.tosa_quant_utils as tosa_quant_utils
import serializer.tosa_serializer as ts # type: ignore
import torch

import tosa.Op as TosaOp # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from torch.fx import Node


@register_node_visitor
class RescaleVisitor(NodeVisitor):
target = "_rescale.default"

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

input_dtype = inputs[0].dtype
output_dtype = cast(torch.dtype, node.args[1])
scale = cast(float, node.args[2])
input_zp = cast(int, node.args[3])
output_zp = cast(int, node.args[4])

# Skip int16 cases for now.
if input_dtype != map_dtype(torch.int8) and input_zp != 0:
raise ValueError(
f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
)
if output_dtype != torch.int8 and output_zp != 0:
raise ValueError(
f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
)

scale_width = 32 if output_dtype == torch.int32 else 16
multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
scale, scale_width
)
attr_rescale = ts.TosaSerializerAttribute()
attr_rescale.RescaleAttribute(
input_zp=input_zp,
output_zp=output_zp,
multiplier=[multiplier],
shift=[shift],
scale32=output_dtype == torch.int32,
double_round=False,
per_channel=False,
input_unsigned=False,
output_unsigned=False,
)

tosa_graph.addOperator(
TosaOp.Op().RESCALE, [inputs[0].name], [output.name], attr_rescale
)
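
For context, a hedged sketch of how a float scale is commonly converted into the integer multiplier and shift that the TOSA RESCALE attribute expects. This is an illustration under assumptions, not the actual tosa_quant_utils.compute_multiplier_and_shift; the scale_width of 32 vs 16 mirrors the choice the visitor makes above.

```python
import math

def approx_multiplier_and_shift(scale: float, scale_width: int = 32):
    """Illustrative only: find (multiplier, shift) so that scale ~= multiplier / 2**shift,
    with the multiplier fitting in scale_width - 1 bits."""
    if scale == 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(scale)  # scale = mantissa * 2**exponent, 0.5 <= mantissa < 1
    shift = scale_width - 1 - exponent
    multiplier = round(mantissa * (1 << (scale_width - 1)))
    if multiplier == (1 << (scale_width - 1)):  # rounding pushed the mantissa up to 1.0; renormalize
        multiplier //= 2
        shift -= 1
    return multiplier, shift

# RESCALE then computes, roughly:
#   output = round((input - input_zp) * multiplier / 2**shift) + output_zp
m, s = approx_multiplier_and_shift(0.5 / 0.25)
print(m, s, m / (1 << s))  # multiplier / 2**shift is ~2.0
```
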
38 changes: 38 additions & 0 deletions backends/arm/test/ops/test_add.py
@@ -9,13 +9,19 @@
from typing import Tuple

import torch
from executorch.backends.arm.arm_backend import get_tosa_version
Contributor:

This seems to have broken the unittest-arm job:
https://github.com/pytorch/executorch/actions/runs/13311971065/job/37176555490#step:15:11475

_______________________ ERROR collecting ops/test_add.py _______________________
ImportError while importing test module '/pytorch/executorch/backends/arm/test/ops/test_add.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/opt/conda/envs/py_3.10/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
backends/arm/test/ops/test_add.py:12: in <module>
    from executorch.backends.arm.arm_backend import get_tosa_version
E   ImportError: cannot import name 'get_tosa_version' from 'executorch.backends.arm.arm_backend' (/opt/conda/envs/py_3.10/lib/python3.10/site-packages/executorch/backends/arm/arm_backend.py)

Collaborator @zingo, Feb 13, 2025:
Yes, I also think it did, sorry for the mess-up. I'll confirm with a revert to make sure it fixes it. If so, let's merge the revert and fix/retry this PR again later.

from executorch.backends.arm.quantizer import arm_quantizer
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)
from executorch.backends.xnnpack.test.tester import Quantize
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer import QuantizationSpec


aten_op = "torch.ops.aten.add.Tensor"
exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
@@ -67,6 +73,38 @@ def test_add_tosa_BI(test_data: input_t1):
pipeline.run()


@common.parametrize("test_data", Add.test_data)
def test_add_i32_tosa_BI(test_data: input_t1):
pipeline = TosaPipelineBI[input_t1](Add(), test_data, aten_op, exir_op)

# Create a quantizer with int8 quantization on the input and output but int32 on everything else.
quantizer = arm_quantizer.ArmQuantizer(
get_tosa_version(common.get_tosa_compile_spec("TOSA-0.80+BI"))
)
quantizer.set_io(arm_quantizer.get_symmetric_quantization_config())
observer_options = {"eps": 2**-16}
observer = HistogramObserver.with_args(**observer_options)
input_act_qspec = QuantizationSpec(
torch.int32,
observer,
qscheme=torch.per_tensor_symmetric,
quant_max=2**31 - 1,
quant_min=-(2**31),
)
# This quantization_config will be set as global config.
quantization_config = arm_quantizer.QuantizationConfig(
input_act_qspec, None, None, None
)
quantize_stage = Quantize(quantizer, quantization_config)
pipeline.change_args("quantize", quantize_stage)

# Check that we get the additional (dq -> q) pairs in the exported graph.
pipeline.add_stage_after(
"export", pipeline.tester.check_count, {"torch.ops.quantized_decomposed": 8}
)
pipeline.run()


@common.parametrize("test_data", Add.test_data)
def test_add_u55_BI(test_data: input_t1):
pipeline = EthosU55PipelineBI[input_t1](