Commit c5fdebd

Erik-Lundell authored and facebook-github-bot committed
Add div decomposition in ArmQuantizer (#5267)
Summary: A suggestion of what a pass to decompose a div in ArmQuantizer might look like.

Pull Request resolved: #5267
Reviewed By: mergennachin, manuelcandales
Differential Revision: D62874421
Pulled By: digantdesai
fbshipit-source-id: 423242d31a524970129ff78bfee1d1c5011a0d42
1 parent f8cec53 commit c5fdebd
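At its core, the change replaces the direct div lowering with the identity a / b = a * reciprocal(b), so division is handled through the existing reciprocal and mul paths. A quick numerical sketch of the identity the new DecomposeDivPass relies on:

import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4).abs() + 0.1  # keep denominators away from zero

# y = div(a, b) is rewritten as y = mul(a, reciprocal(b))
torch.testing.assert_close(torch.div(a, b), torch.mul(a, torch.reciprocal(b)))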

File tree

12 files changed: +351, -90 lines changed


backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions

@@ -57,6 +57,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.repeat.default,
+            exir_ops.edge.aten.reciprocal.default,
             exir_ops.edge.aten.relu.default,
             exir_ops.edge.aten.rsqrt.default,
             exir_ops.edge.aten._softmax.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,6 @@
     op_cat,
     op_conv2d,
     op_dequant,
-    op_div,
     op_exp,
     op_full,
     op_get_item,
@@ -26,6 +25,7 @@
     op_mul,
     op_permute,
     op_quant,
+    op_reciprocal,
     op_relu,
     op_repeat,
     op_rsqrt,
backends/arm/operators/op_reciprocal.py

Lines changed: 79 additions & 0 deletions (new file)

# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import numpy as np

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
    dequantize_value,
    get_quant_node_args,
    QuantArgs,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class DivVisitor(NodeVisitor):
    target = "aten.reciprocal.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        # 1/X

        if is_quant_node:
            input = inputs[0]
            input_qargs = get_quant_node_args(node.all_input_nodes[0])
            output_qargs = get_quant_node_args(list(node.users)[0])

            div_table = div_table_8bit(input_qargs, output_qargs)

            table_attr = ts.TosaSerializerAttribute()
            table_attr.TableAttribute(div_table)
            tosa_graph.addOperator(
                TosaOp.Op().TABLE, [input.name], [output.name], table_attr
            )

        else:
            tosa_graph.addOperator(
                TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
            )


def div_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
    """
    Returns a table mapping 256 entries to div([qmin,qmax])
    """

    def div(x):
        # Convert quantized input to floating point div input space.
        v1 = dequantize_value(x, in_quantargs)
        # Compute div.
        v2 = 1.0 / v1
        # Convert div output back to quantized space.
        v3 = quantize_value(v2, out_quantargs)
        return v3

    return [
        div(x)
        for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
    ]
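In the quantized path above, reciprocal becomes a single TOSA TABLE op: all 256 possible int8 input values map to a precomputed quantized 1/x. Below is a standalone sketch of the same construction, assuming the usual affine scheme (q = round(x / scale) + zero_point) for the quantize_value/dequantize_value helpers in tosa_quant_utils; the names here are illustrative, not the in-tree API.

import numpy as np

def reciprocal_table_8bit(in_scale, in_zp, out_scale, out_zp, qmin=-128, qmax=127):
    """256-entry table: quantized input value -> quantized 1/x output."""
    table = []
    for q in np.linspace(qmin, qmax, 256, dtype=np.int8):
        x = (int(q) - in_zp) * in_scale            # dequantize
        if x == 0.0:
            table.append(qmax)                     # saturate around the pole at x == 0
            continue
        q_out = round((1.0 / x) / out_scale) + out_zp  # reciprocal, then requantize
        table.append(int(np.clip(q_out, qmin, qmax)))
    return table

Since the table enumerates every representable input value, the runtime cost is one lookup per element, independent of the function being folded in.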

backends/arm/passes/arm_pass_manager.py

Lines changed: 10 additions & 0 deletions

@@ -17,10 +17,14 @@
 from executorch.backends.arm.passes.convert_split_to_slice import (
     ConvertSplitToSlicePass,
 )
+from executorch.backends.arm.passes.decompose_div_pass import DecomposeDivPass
 from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
     ConvertMeanDimToAveragePool,
 )
 from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
+from executorch.backends.arm.passes.scalars_to_attribute_pass import (
+    ScalarsToAttributePass,
+)
 from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -40,6 +44,7 @@ def transform_to_backend_pipeline(
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(ConvertMeanDimToAveragePool())
+        self.add_pass(DecomposeDivPass())
         self.add_pass(ConvertSplitToSlicePass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
@@ -48,3 +53,8 @@ def transform_to_backend_pipeline(
                 self.add_pass(AnnotateChannelsLastDimOrder())

         return self._transform(exported_program.graph_module)
+
+    def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
+        self.add_pass(DecomposeDivPass())
+        self.add_pass(ScalarsToAttributePass())
+        return self._transform(graph_module)
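For orientation, a sketch of how the two pipelines are driven (keyword names follow the signatures above; the surrounding objects and exact invocation are assumed): transform_for_annotation_pipeline runs on a plain GraphModule before quantization annotation, while transform_to_backend_pipeline runs on an ExportedProgram at lowering time.

from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager

# Before annotation: decompose div and lift scalar operands to attributes.
annotated_gm = ArmPassManager().transform_for_annotation_pipeline(graph_module=gm)

# At lowering: the full backend pipeline over the exported program.
lowered_gm = ArmPassManager().transform_to_backend_pipeline(
    exported_program=exported_program, compile_spec=compile_spec
)

A fresh ArmPassManager instance per pipeline avoids accumulating passes across calls, since add_pass appends to the manager's internal pass list.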
backends/arm/passes/decompose_div_pass.py

Lines changed: 45 additions & 0 deletions (new file)

# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_div_decomposition(op) -> tuple:
    """
    Returns the (reciprocal_op, mul_op) pair; which ops are returned depends on
    whether the div op is an exir_ops edge op or a torch.ops.aten op.
    """
    if op == exir_ops.edge.aten.div.Tensor:
        return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
    if op == torch.ops.aten.div.Tensor:
        return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
    raise RuntimeError(f"Can't get div decomposition for op {op}")


class DecomposeDivPass(ExportPass):
    """
    This pass decomposes div into a mul and a reciprocal node.

    Example:
        y = div(a, b)
    Becomes:
        x = reciprocal(b)
        y = mul(a, x)
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
            return super().call_operator(op, args, kwargs, meta)

        reciprocal_op, mul_op = get_div_decomposition(op)

        numerator = args[0]
        denominator = args[1]
        reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta)

        return super().call_operator(mul_op, (numerator, reciprocal), {}, meta)
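A hedged usage sketch (in-tree the pass runs via ArmPassManager; this assumes ExportPass instances are callable on a GraphModule and return a PassResult): after the pass, the graph should contain reciprocal and mul nodes in place of div.

import torch
from executorch.backends.arm.passes.decompose_div_pass import DecomposeDivPass

class Div(torch.nn.Module):
    def forward(self, a, b):
        return a / b

ep = torch.export.export(Div(), (torch.rand(4), torch.rand(4)))
result = DecomposeDivPass()(ep.graph_module)
targets = {n.target for n in result.graph_module.graph.nodes if n.op == "call_function"}
assert torch.ops.aten.div.Tensor not in targets  # replaced by reciprocal + mul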
backends/arm/passes/scalars_to_attribute_pass.py

Lines changed: 69 additions & 0 deletions (new file)

# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Union

import torch
from executorch.backends.arm.tosa_mapping import extract_tensor_meta

from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.fx import GraphModule, Node


class ScalarsToAttributePass(ExportPass):
    """
    For ops in 'targeted_ops', convert inputs that are scalar values
    to attribute Nodes that output the same value.
    """

    targeted_ops = [
        torch.ops.aten.add.Tensor,
        torch.ops.aten.sub.Tensor,
        torch.ops.aten.sub_.Tensor,
        torch.ops.aten.mul.Tensor,
        torch.ops.aten.div.Tensor,
    ]

    def call(self, graph_module: GraphModule) -> PassResult:
        for n in graph_module.graph.nodes:
            n = cast(Node, n)
            if n.op != "call_function" or n.target not in self.targeted_ops:
                continue

            biggest_rank = 1
            for arg in n.args:
                if isinstance(arg, Node):
                    _, shape, _ = extract_tensor_meta(arg.meta)
                    biggest_rank = max(biggest_rank, len(shape))

            new_args = []
            for arg in n.args:
                if isinstance(arg, Node):
                    new_args.append(arg)
                    continue

                prefix = "_tensor_constant_"
                get_new_attr_name = get_new_attr_name_with_prefix(prefix)
                tensor_constant_name = get_new_attr_name(graph_module)
                float_tensor = torch.tensor(
                    float(cast(Union[int, float], arg))
                ).reshape((1,) * biggest_rank)
                graph_module.register_buffer(tensor_constant_name, float_tensor)
                fake_mode = n.meta["val"].fake_mode

                with graph_module.graph.inserting_before(n):
                    get_attr_node = graph_module.graph.create_node(
                        "get_attr", tensor_constant_name, (), {}
                    )
                    get_attr_node.meta["val"] = fake_mode.from_tensor(
                        float_tensor, static_shapes=True
                    )
                    new_args.append(get_attr_node)
            n.args = tuple(new_args)

        graph_module.recompile()
        return PassResult(graph_module, True)
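The reshape to (1,) * biggest_rank is what keeps broadcasting semantics intact once the scalar becomes a buffer; a small sketch of the equivalence the pass relies on:

import torch

x = torch.rand(2, 3)     # rank-2 tensor operand
scalar = 2.0             # scalar operand of e.g. mul.Tensor

# What the pass effectively registers: a rank-matched constant buffer.
attr = torch.tensor(float(scalar)).reshape((1,) * x.dim())  # shape (1, 1)

torch.testing.assert_close(x * scalar, x * attr)  # broadcasting gives the same result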

backends/arm/quantizer/arm_quantizer.py

Lines changed: 3 additions & 2 deletions

@@ -19,10 +19,10 @@

 import torch
 import torch.nn.functional as F
+from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager

 from executorch.backends.arm.quantizer import arm_quantizer_utils
 from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-    convert_scalars_to_attrs,
     mark_nodes_as_annotated,
     propagate_annotation,
 )
@@ -318,7 +318,8 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         """An initial pass for transforming the graph to prepare it for annotation.
         Currently transforms scalar values to tensor attributes.
         """
-        return convert_scalars_to_attrs(model)
+
+        return ArmPassManager().transform_for_annotation_pipeline(graph_module=model)

     def annotate(self, model: GraphModule) -> GraphModule:
         """Performs the quantization annotation on the graph.

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 41 deletions

@@ -12,12 +12,11 @@
 #

 import operator
-from typing import Callable, cast, List, Union
+from typing import Callable, cast, List

 import torch
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -199,42 +198,3 @@ def propagate_annotation(model: GraphModule) -> None:
         output_qspec=shared_qspec,
         _annotated=True,
     )
-
-
-def convert_scalars_to_attrs(model: GraphModule) -> GraphModule:
-    """For ops in 'targeted_ops', convert inputs that are scalar values
-    to attribute Nodes that output the same value.
-    #TODO Seems like this should be a pass.
-    """
-    targeted_ops = [
-        torch.ops.aten.add.Tensor,
-        torch.ops.aten.sub.Tensor,
-        torch.ops.aten.mul.Tensor,
-    ]
-    for n in model.graph.nodes:
-        n = cast(Node, n)
-        if n.op != "call_function" or n.target not in targeted_ops:
-            continue
-        args = list(n.args)
-        new_args = []
-        for i in range(len(args)):
-            if isinstance(args[i], Node):
-                new_args.append(args[i])
-                continue
-            prefix = "_tensor_constant_"
-            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
-            tensor_constant_name = get_new_attr_name(model)
-            float_tensor = torch.tensor(float(cast(Union[int, float], args[i])))
-            model.register_buffer(tensor_constant_name, float_tensor)
-            fake_mode = n.meta["val"].fake_mode
-            with model.graph.inserting_before(n):
-                get_attr_node = model.graph.create_node(
-                    "get_attr", tensor_constant_name, (), {}
-                )
-                get_attr_node.meta["val"] = fake_mode.from_tensor(
-                    float_tensor, static_shapes=True
-                )
-            new_args.append(get_attr_node)
-        n.args = tuple(new_args)
-    model.recompile()
-    return model

backends/arm/quantizer/quantization_annotation/mul_annotator.py

Lines changed: 7 additions & 12 deletions

@@ -4,19 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-# pyre-unsafe
-
-import itertools
-import operator
 from typing import Callable, List, Optional

 import torch
+import torch.fx
 from executorch.backends.arm.quantizer import arm_quantizer_utils
 from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import Node
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


 @register_annotator("mul")
@@ -25,14 +21,13 @@ def _annotate_mul(
     quantization_config: QuantizationConfig,
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
-    mul_partitions = get_source_partitions(
-        gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
-    )
-    mul_partitions = list(itertools.chain.from_iterable(mul_partitions.values()))
+
     annotated_partitions = []
-    for mul_partition in mul_partitions:
-        annotated_partitions.append(mul_partition.nodes)
-        mul_node = mul_partition.output_nodes[0]
+    for node in gm.graph.nodes:
+        if node.target not in (torch.ops.aten.mul.Tensor,):
+            continue
+        mul_node = node
+        annotated_partitions.append([mul_node])
         if arm_quantizer_utils.is_annotated(mul_node):
             continue
backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py

Lines changed: 3 additions & 2 deletions

@@ -35,11 +35,12 @@ def _annotate_one_to_one(
     Typical ops are ops implemented with a lookup table.
     """
     annotated_partitions = []
-    one_to_one_ops = {
+    one_to_one_ops = (
         torch.ops.aten.exp.default,
         torch.ops.aten.log.default,
+        torch.ops.aten.reciprocal.default,
         torch.ops.aten.rsqrt.default,
-    }
+    )
     for node in gm.graph.nodes:
         if node.op != "call_function" or node.target not in one_to_one_ops:
             continue
