
Add div decomposition in ArmQuantizer #5267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 3 commits

Changes from all commits
1 change: 1 addition & 0 deletions backends/arm/arm_partitioner.py
@@ -57,6 +57,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten._softmax.default,
2 changes: 1 addition & 1 deletion backends/arm/operators/__init__.py
@@ -15,7 +15,6 @@
op_cat,
op_conv2d,
op_dequant,
op_div,
op_exp,
op_full,
op_get_item,
@@ -26,6 +25,7 @@
op_mul,
op_permute,
op_quant,
op_reciprocal,
op_relu,
op_repeat,
op_rsqrt,
79 changes: 79 additions & 0 deletions backends/arm/operators/op_reciprocal.py
@@ -0,0 +1,79 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
dequantize_value,
get_quant_node_args,
QuantArgs,
quantize_value,
)
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class DivVisitor(NodeVisitor):
target = "aten.reciprocal.default"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
# 1/X

if is_quant_node:
input = inputs[0]
input_qargs = get_quant_node_args(node.all_input_nodes[0])
output_qargs = get_quant_node_args(list(node.users)[0])

div_table = div_table_8bit(input_qargs, output_qargs)

table_attr = ts.TosaSerializerAttribute()
table_attr.TableAttribute(div_table)
tosa_graph.addOperator(
TosaOp.Op().TABLE, [input.name], [output.name], table_attr
)

else:
tosa_graph.addOperator(
TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
)


def div_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
"""
Returns a 256-entry table mapping quantized input values in [qmin, qmax] to quantized reciprocal (1/x) values.
"""

def div(x):
# Convert quantized input to floating point div input space.
v1 = dequantize_value(x, in_quantargs)
# Compute div.
v2 = 1.0 / v1
# Convert div output back to quantized space.
v3 = quantize_value(v2, out_quantargs)

return v3

return [
div(x)
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
]
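For readers following the TABLE lowering above, here is a minimal, self-contained sketch of the table construction. The QuantArgs fields and the quantize/dequantize helpers below are simplified stand-ins, not the repo's tosa_quant_utils implementations, whose exact signatures may differ:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class QuantArgs:
    scale: float
    zp: int
    qmin: int = -128
    qmax: int = 127


def dequantize_value(qx: int, q: QuantArgs) -> np.float64:
    # Affine dequantization: real = (code - zero_point) * scale.
    return (np.float64(qx) - q.zp) * q.scale


def quantize_value(x: np.float64, q: QuantArgs) -> np.int8:
    # Affine quantization with saturation to the int8 range.
    return np.int8(np.clip(np.round(x / q.scale) + q.zp, q.qmin, q.qmax))


def reciprocal_table_8bit(in_q: QuantArgs, out_q: QuantArgs) -> list:
    # One entry per int8 input code: dequantize, take 1/x in float,
    # re-quantize into the output's scale/zero-point. Division by an
    # exactly-zero code yields +/-inf, which the clip saturates.
    with np.errstate(divide="ignore"):
        return [
            quantize_value(1.0 / dequantize_value(qx, in_q), out_q)
            for qx in np.linspace(in_q.qmin, in_q.qmax, 256, dtype=np.int8)
        ]


in_q = QuantArgs(scale=1 / 32, zp=0)   # inputs cover [-4.0, ~3.97]
out_q = QuantArgs(scale=1 / 16, zp=0)  # outputs cover [-8.0, ~7.94]
table = reciprocal_table_8bit(in_q, out_q)
assert len(table) == 256
assert table[32 + 128] == 16  # input code 32 is 1.0; 1/1.0 re-quantizes to 16
```

At runtime the backend emits a single TOSA TABLE op that indexes this precomputed array, so the quantized reciprocal costs one lookup per element instead of a floating-point division.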
10 changes: 10 additions & 0 deletions backends/arm/passes/arm_pass_manager.py
@@ -17,10 +17,14 @@
from executorch.backends.arm.passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
from executorch.backends.arm.passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
ConvertMeanDimToAveragePool,
)
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
from executorch.backends.arm.passes.scalars_to_attribute_pass import (
ScalarsToAttributePass,
)
from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -40,6 +44,7 @@ def transform_to_backend_pipeline(
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(ConvertMeanDimToAveragePool())
self.add_pass(DecomposeDivPass())
self.add_pass(ConvertSplitToSlicePass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
@@ -48,3 +53,8 @@ def transform_to_backend_pipeline(
self.add_pass(AnnotateChannelsLastDimOrder())

return self._transform(exported_program.graph_module)

def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
self.add_pass(DecomposeDivPass())

[Inline review comment, Contributor]
I guess the motivation for this API is clear and this is allowed; however, I do want to make sure that modifying the graph in the Quantizer is something we are OK with from a general ET lowering point of view. I also see SDPA as a valid use case.

cc @kimishpatel, @mergennachin - any comments?

[Reply from the author (Collaborator)]
Ok - I'll resolve the merge conflict and mark this as ready for review tomorrow unless I get new comments.

self.add_pass(ScalarsToAttributePass())
return self._transform(graph_module)
45 changes: 45 additions & 0 deletions backends/arm/passes/decompose_div_pass.py
@@ -0,0 +1,45 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_div_decomposition(op) -> tuple:
"""
Returns the (reciprocal_op, mul_op) pair; which ops are returned depends on
whether the div op comes from exir_ops.edge or torch.ops.aten.
"""
if op == exir_ops.edge.aten.div.Tensor:
return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
if op == torch.ops.aten.div.Tensor:
return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
raise RuntimeError(f"Can't get div decomposition for op {op}")


class DecomposeDivPass(ExportPass):
"""
This pass decomposes div into a mul and a reciprocal node.

Example:
y = div(a,b)
Becomes:
x = reciprocal(b)
y = mul(a,x)
"""

def call_operator(self, op, args, kwargs, meta):
if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
return super().call_operator(op, args, kwargs, meta)

reciprocal_op, mul_op = get_div_decomposition(op)

numerator = args[0]
denominator = args[1]
reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta)

return super().call_operator(mul_op, (numerator, reciprocal), {}, meta)
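The rewrite rests on a simple identity: div(a, b) equals mul(a, reciprocal(b)) up to floating-point rounding. A quick sanity check in plain PyTorch, independent of the pass machinery:

```python
import torch

# div(a, b) and mul(a, reciprocal(b)) agree to within rounding error,
# which is what lets DecomposeDivPass rewrite one into the other.
a = torch.randn(2, 3)
b = torch.rand(2, 3) + 0.5  # keep denominators away from zero

assert torch.allclose(torch.div(a, b), torch.mul(a, torch.reciprocal(b)))
```

After the pass runs, the graph contains only reciprocal and mul nodes, both of which the Arm backend already knows how to quantize and lower (reciprocal via the TABLE lookup above).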
69 changes: 69 additions & 0 deletions backends/arm/passes/scalars_to_attribute_pass.py
@@ -0,0 +1,69 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Union

import torch
from executorch.backends.arm.tosa_mapping import extract_tensor_meta

from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.fx import GraphModule, Node


class ScalarsToAttributePass(ExportPass):
"""
For ops in 'targeted_ops', convert inputs that are scalar values
to attribute Nodes that output the same value.
"""

targeted_ops = [
torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Tensor,
torch.ops.aten.sub_.Tensor,
torch.ops.aten.mul.Tensor,
torch.ops.aten.div.Tensor,
]

def call(self, graph_module: GraphModule) -> PassResult:
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.op != "call_function" or n.target not in self.targeted_ops:
continue

biggest_rank = 1
for arg in n.args:
if isinstance(arg, Node):
_, shape, _ = extract_tensor_meta(arg.meta)
biggest_rank = max(biggest_rank, len(shape))

new_args = []
for arg in n.args:
if isinstance(arg, Node):
new_args.append(arg)
continue

prefix = "_tensor_constant_"
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
tensor_constant_name = get_new_attr_name(graph_module)
float_tensor = torch.tensor(
float(cast(Union[int, float], arg))
).reshape((1,) * biggest_rank)
graph_module.register_buffer(tensor_constant_name, float_tensor)
fake_mode = n.meta["val"].fake_mode

with graph_module.graph.inserting_before(n):
get_attr_node = graph_module.graph.create_node(
"get_attr", tensor_constant_name, (), {}
)
get_attr_node.meta["val"] = fake_mode.from_tensor(
float_tensor, static_shapes=True
)
new_args.append(get_attr_node)
n.args = tuple(new_args)

graph_module.recompile()
return PassResult(graph_module, True)
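To illustrate what the pass produces, the sketch below reproduces its effect by hand: a scalar operand becomes a constant tensor reshaped to (1,) * biggest_rank so it broadcasts against the highest-rank tensor operand. This is plain PyTorch mimicking the result, not the FX pass itself:

```python
import torch

x = torch.randn(2, 3, 4)  # rank-3 tensor operand
scalar = 0.5              # scalar operand of, e.g., aten.mul.Tensor

# The pass registers the scalar as a buffer reshaped to (1,) * biggest_rank,
# here (1, 1, 1), so it broadcasts cleanly against the tensor operand.
scalar_as_attr = torch.tensor(float(scalar)).reshape((1,) * x.dim())

assert scalar_as_attr.shape == (1, 1, 1)
assert torch.equal(x * scalar, x * scalar_as_attr)
```

Materializing scalars this way means every operand of the targeted ops is a proper tensor node, which simplifies annotation and the later TOSA lowering.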
5 changes: 3 additions & 2 deletions backends/arm/quantizer/arm_quantizer.py
@@ -19,10 +19,10 @@

import torch
import torch.nn.functional as F
from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager

from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.arm_quantizer_utils import (
convert_scalars_to_attrs,
mark_nodes_as_annotated,
propagate_annotation,
)
@@ -317,7 +317,8 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
"""An initial pass for transforming the graph to prepare it for annotation.
Currently transforms scalar values to tensor attributes.
"""
return convert_scalars_to_attrs(model)

return ArmPassManager().transform_for_annotation_pipeline(graph_module=model)

def annotate(self, model: GraphModule) -> GraphModule:
"""Performs the quantization annotation on the graph.
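For context, transform_for_annotation is invoked by the PT2E quantization flow itself, so the new pipeline runs before any annotator sees the graph. The end-to-end sketch below is an assumption-laden illustration: the ArmQuantizer and get_symmetric_quantization_config imports follow this PR's tree, and the capture API (torch.export here) varies across ExecuTorch versions:

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Assumed imports, per this PR's source tree; exact paths may differ.
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)


class DivModule(torch.nn.Module):
    def forward(self, x, y):
        return x / y  # aten.div.Tensor


quantizer = ArmQuantizer().set_global(get_symmetric_quantization_config())
example_inputs = (torch.randn(1, 4), torch.rand(1, 4) + 0.5)
graph_module = torch.export.export(DivModule(), example_inputs).module()

# prepare_pt2e calls quantizer.transform_for_annotation, which now runs
# DecomposeDivPass and ScalarsToAttributePass, so the annotators see
# reciprocal + mul nodes instead of div.
prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)  # calibrate
quantized = convert_pt2e(prepared)
```

Routing this hook through ArmPassManager, rather than a standalone helper, keeps all graph rewrites in one place, which is what the review comment above is probing.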
42 changes: 1 addition & 41 deletions backends/arm/quantizer/arm_quantizer_utils.py
@@ -12,12 +12,11 @@
#

import operator
from typing import Callable, cast, List, Union
from typing import Callable, cast, List

import torch
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
@@ -199,42 +198,3 @@ def propagate_annotation(model: GraphModule) -> None:
output_qspec=shared_qspec,
_annotated=True,
)


def convert_scalars_to_attrs(model: GraphModule) -> GraphModule:
"""For ops in 'targeted_ops', convert inputs that are scalar values
to attribute Nodes that output the same value.
#TODO Seems like this should be a pass.
"""
targeted_ops = [
torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Tensor,
torch.ops.aten.mul.Tensor,
]
for n in model.graph.nodes:
n = cast(Node, n)
if n.op != "call_function" or n.target not in targeted_ops:
continue
args = list(n.args)
new_args = []
for i in range(len(args)):
if isinstance(args[i], Node):
new_args.append(args[i])
continue
prefix = "_tensor_constant_"
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
tensor_constant_name = get_new_attr_name(model)
float_tensor = torch.tensor(float(cast(Union[int, float], args[i])))
model.register_buffer(tensor_constant_name, float_tensor)
fake_mode = n.meta["val"].fake_mode
with model.graph.inserting_before(n):
get_attr_node = model.graph.create_node(
"get_attr", tensor_constant_name, (), {}
)
get_attr_node.meta["val"] = fake_mode.from_tensor(
float_tensor, static_shapes=True
)
new_args.append(get_attr_node)
n.args = tuple(new_args)
model.recompile()
return model
19 changes: 7 additions & 12 deletions backends/arm/quantizer/quantization_annotation/mul_annotator.py
@@ -4,19 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
import operator
from typing import Callable, List, Optional

import torch
import torch.fx
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


@register_annotator("mul")
@@ -25,14 +21,13 @@ def _annotate_mul(
quantization_config: QuantizationConfig,
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
mul_partitions = get_source_partitions(
gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
)
mul_partitions = list(itertools.chain.from_iterable(mul_partitions.values()))

annotated_partitions = []
for mul_partition in mul_partitions:
annotated_partitions.append(mul_partition.nodes)
mul_node = mul_partition.output_nodes[0]
for node in gm.graph.nodes:
if node.target not in (torch.ops.aten.mul.Tensor,):
continue
mul_node = node
annotated_partitions.append([mul_node])
if arm_quantizer_utils.is_annotated(mul_node):
continue

backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py
@@ -35,11 +35,12 @@ def _annotate_one_to_one(
Typical ops are ops implemented with a lookup table.
"""
annotated_partitions = []
one_to_one_ops = {
one_to_one_ops = (
torch.ops.aten.exp.default,
torch.ops.aten.log.default,
torch.ops.aten.reciprocal.default,
torch.ops.aten.rsqrt.default,
}
)
for node in gm.graph.nodes:
if node.op != "call_function" or node.target not in one_to_one_ops:
continue