Commit 2edce3d

Add decomposition of div op for ArmBackend
Implements a pass that decomposes aten.div into aten.reciprocal and aten.mul. The same decomposition runs in the Quantizer before annotation, so quantization annotations are placed on the decomposed operators. Also adds infrastructure for running passes in the ArmQuantizer. Signed-off-by: Erik Lundell <[email protected]> Change-Id: Idd1698dc5fc82ab42b68094b405fb3a08804a45e
1 parent 14cb90b commit 2edce3d
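
The decomposition relies on the identity a / b == a * reciprocal(b). A minimal standalone sketch of the numeric equivalence, using plain torch ops (the tensor values here are chosen for illustration only):

    import torch

    a = torch.randn(2, 3)
    b = torch.rand(2, 3) + 0.5  # keep denominators away from zero

    # div rewritten as mul(a, reciprocal(b)); equal up to float rounding.
    assert torch.allclose(torch.div(a, b), torch.mul(a, torch.reciprocal(b)))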

File tree

9 files changed: +150 −128 lines


backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@
     op_cat,
     op_conv2d,
     op_dequant,
-    op_div,
     op_exp,
     op_full,
     op_get_item,

backends/arm/operators/op_div.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

backends/arm/passes/arm_pass_manager.py

Lines changed: 10 additions & 0 deletions
@@ -15,10 +15,14 @@
 from executorch.backends.arm.passes.convert_split_to_slice import (
     ConvertSplitToSlicePass,
 )
+from executorch.backends.arm.passes.decompose_div_pass import DecomposeDivPass
 from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
     ConvertMeanDimToAveragePool,
 )
 from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
+from executorch.backends.arm.passes.scalars_to_attribute_pass import (
+    ScalarsToAttributePass,
+)
 from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -37,6 +41,7 @@ def transform_to_backend_pipeline(
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(ConvertMeanDimToAveragePool())
+        self.add_pass(DecomposeDivPass())
         self.add_pass(ConvertSplitToSlicePass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
@@ -45,3 +50,8 @@ def transform_to_backend_pipeline(
         self.add_pass(AnnotateChannelsLastDimOrder())

         return self._transform(graph_module)
+
+    def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
+        self.add_pass(DecomposeDivPass())
+        self.add_pass(ScalarsToAttributePass())
+        return self._transform(graph_module)
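
A minimal sketch of driving the new annotation pipeline directly, assuming a module captured with torch.export (the Div module and inputs here are hypothetical):

    import torch

    from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager


    class Div(torch.nn.Module):
        def forward(self, x, y):
            return x / y


    example_inputs = (torch.randn(4), torch.rand(4) + 0.5)
    graph_module = torch.export.export(Div(), example_inputs).module()

    # Runs DecomposeDivPass and ScalarsToAttributePass, then re-traces.
    graph_module = ArmPassManager().transform_for_annotation_pipeline(
        graph_module=graph_module
    )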
backends/arm/passes/decompose_div_pass.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+def get_div_decomposition(op) -> tuple:
+    """
+    Returns the (reciprocal_op, mul_op) pair; which ops are returned depends
+    on whether the div op comes from exir_ops or torch.ops.aten.
+    """
+    if op == exir_ops.edge.aten.div.Tensor:
+        return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
+    if op == torch.ops.aten.div.Tensor:
+        return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
+    raise RuntimeError(f"Can't get div decomposition for op {op}")
+
+
+class DecomposeDivPass(ExportPass):
+    """
+    This pass decomposes div into a mul and a reciprocal node.
+
+    Example:
+        y = div(a, b)
+    Becomes:
+        x = reciprocal(b)
+        y = mul(a, x)
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
+            return super().call_operator(op, args, kwargs, meta)
+
+        reciprocal_op, mul_op = get_div_decomposition(op)
+
+        numerator = args[0]
+        denominator = args[1]
+        reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta)
+
+        return super().call_operator(mul_op, (numerator, reciprocal), {}, meta)
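
A quick way to see the pass in action, assuming a torch.export'ed module whose graph contains aten.div.Tensor (module name hypothetical):

    import torch

    from executorch.backends.arm.passes.decompose_div_pass import DecomposeDivPass


    class Div(torch.nn.Module):
        def forward(self, x, y):
            return torch.div(x, y)


    gm = torch.export.export(Div(), (torch.randn(4), torch.rand(4) + 0.5)).module()
    gm = DecomposeDivPass()(gm).graph_module

    # The div node is gone; reciprocal + mul remain in its place.
    assert not any(n.target == torch.ops.aten.div.Tensor for n in gm.graph.nodes)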
backends/arm/passes/scalars_to_attribute_pass.py

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast, Union
+
+import torch
+from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+from torch.fx import GraphModule, Node
+
+
+class ScalarsToAttributePass(ExportPass):
+    """
+    For ops in 'targeted_ops', convert inputs that are scalar values
+    to attribute Nodes that output the same value.
+    """
+
+    targeted_ops = [
+        torch.ops.aten.add.Tensor,
+        torch.ops.aten.sub.Tensor,
+        torch.ops.aten.mul.Tensor,
+    ]
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.op != "call_function" or n.target not in self.targeted_ops:
+                continue
+
+            biggest_rank = 1
+            for arg in n.args:
+                if isinstance(arg, Node):
+                    _, shape, _ = extract_tensor_meta(arg.meta)
+                    biggest_rank = max(biggest_rank, len(shape))
+
+            new_args = []
+            for arg in n.args:
+                if isinstance(arg, Node):
+                    new_args.append(arg)
+                    continue
+
+                prefix = "_tensor_constant_"
+                get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+                tensor_constant_name = get_new_attr_name(graph_module)
+                float_tensor = torch.tensor(
+                    float(cast(Union[int, float], arg))
+                ).reshape((1,) * biggest_rank)
+                graph_module.register_buffer(tensor_constant_name, float_tensor)
+                fake_mode = n.meta["val"].fake_mode
+
+                with graph_module.graph.inserting_before(n):
+                    get_attr_node = graph_module.graph.create_node(
+                        "get_attr", tensor_constant_name, (), {}
+                    )
+                    get_attr_node.meta["val"] = fake_mode.from_tensor(
+                        float_tensor, static_shapes=True
+                    )
+                    new_args.append(get_attr_node)
+            n.args = tuple(new_args)
+
+        graph_module.recompile()
+        return PassResult(graph_module, True)
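
A sketch of the pass's effect on a scalar operand, assuming a torch.export'ed module whose node metadata carries FakeTensor vals (module name hypothetical):

    import torch

    from executorch.backends.arm.passes.scalars_to_attribute_pass import (
        ScalarsToAttributePass,
    )


    class AddScalar(torch.nn.Module):
        def forward(self, x):
            return x + 2.0


    gm = torch.export.export(AddScalar(), (torch.randn(2, 2),)).module()
    gm = ScalarsToAttributePass()(gm).graph_module

    # The scalar 2.0 now enters the graph as a buffer reshaped to the rank
    # of the tensor operand (shape (1, 1)) via a get_attr node.
    assert any(n.op == "get_attr" for n in gm.graph.nodes)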

backends/arm/quantizer/arm_quantizer.py

Lines changed: 3 additions & 2 deletions
@@ -17,10 +17,10 @@

 import torch
 import torch.nn.functional as F
+from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager

 from executorch.backends.arm.quantizer import arm_quantizer_utils
 from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-    convert_scalars_to_attrs,
     mark_nodes_as_annotated,
     propagate_annotation,
 )
@@ -315,7 +315,8 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         """An initial pass for transforming the graph to prepare it for annotation.
         Currently transforms scalar values to tensor attributes.
         """
-        return convert_scalars_to_attrs(model)
+
+        return ArmPassManager().transform_for_annotation_pipeline(graph_module=model)

     def annotate(self, model: GraphModule) -> GraphModule:
         """Performs the quantization annotation on the graph.

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 41 deletions
@@ -10,12 +10,11 @@
 #

 import operator
-from typing import Callable, cast, List, Union
+from typing import Callable, cast, List

 import torch
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -194,42 +193,3 @@ def propagate_annotation(model: GraphModule) -> None:
         output_qspec=shared_qspec,
         _annotated=True,
     )
-
-
-def convert_scalars_to_attrs(model: GraphModule) -> GraphModule:
-    """For ops in 'targeted_ops', convert inputs that are scalar values
-    to attribute Nodes that output the same value.
-    #TODO Seems like this should be a pass.
-    """
-    targeted_ops = [
-        torch.ops.aten.add.Tensor,
-        torch.ops.aten.sub.Tensor,
-        torch.ops.aten.mul.Tensor,
-    ]
-    for n in model.graph.nodes:
-        n = cast(Node, n)
-        if n.op != "call_function" or n.target not in targeted_ops:
-            continue
-        args = list(n.args)
-        new_args = []
-        for i in range(len(args)):
-            if isinstance(args[i], Node):
-                new_args.append(args[i])
-                continue
-            prefix = "_tensor_constant_"
-            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
-            tensor_constant_name = get_new_attr_name(model)
-            float_tensor = torch.tensor(float(cast(Union[int, float], args[i])))
-            model.register_buffer(tensor_constant_name, float_tensor)
-            fake_mode = n.meta["val"].fake_mode
-            with model.graph.inserting_before(n):
-                get_attr_node = model.graph.create_node(
-                    "get_attr", tensor_constant_name, (), {}
-                )
-                get_attr_node.meta["val"] = fake_mode.from_tensor(
-                    float_tensor, static_shapes=True
-                )
-                new_args.append(get_attr_node)
-            n.args = tuple(new_args)
-    model.recompile()
-    return model

backends/arm/quantizer/quantization_annotation/mul_annotator.py

Lines changed: 7 additions & 10 deletions
@@ -4,17 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import itertools
-import operator
 from typing import Callable, List, Optional

 import torch
+import torch.fx
 from executorch.backends.arm.quantizer import arm_quantizer_utils
 from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import Node
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


 @register_annotator("mul")
@@ -23,14 +21,13 @@ def _annotate_mul(
     quantization_config: QuantizationConfig,
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
-    mul_partitions = get_source_partitions(
-        gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
-    )
-    mul_partitions = list(itertools.chain.from_iterable(mul_partitions.values()))
+
     annotated_partitions = []
-    for mul_partition in mul_partitions:
-        annotated_partitions.append(mul_partition.nodes)
-        mul_node = mul_partition.output_nodes[0]
+    for node in gm.graph.nodes:
+        if node.target not in (torch.ops.aten.mul.Tensor,):
+            continue
+        mul_node = node
+        annotated_partitions.append([mul_node])
         if arm_quantizer_utils.is_annotated(mul_node):
             continue
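
The annotator now matches aten.mul.Tensor nodes directly rather than via get_source_partitions. A sketch of the idea, assuming gm is the graph module after transform_for_annotation:

    # mul nodes created by graph passes (e.g. DecomposeDivPass) carry no
    # source-partition metadata, so direct target matching is what finds them.
    mul_nodes = [
        n
        for n in gm.graph.nodes
        if n.op == "call_function" and n.target == torch.ops.aten.mul.Tensor
    ]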
