
Commit af8f22b

Add decomposition of div op for ArmBackend
Implements a pass that decomposes aten.div into aten.reciprocal and aten.mul. This is done so that the Quantizer can get quantization annotations on the decomposed operators. Also adds infrastructure for passes in ArmQuantizer.

Signed-off-by: Erik Lundell <[email protected]>
Change-Id: Idd1698dc5fc82ab42b68094b405fb3a08804a45e
1 parent 03c78e1 commit af8f22b

File tree

backends/arm/operators/__init__.py
backends/arm/passes/arm_pass_manager.py
backends/arm/passes/decompose_div_pass.py
backends/arm/passes/scalars_to_attribute_pass.py
backends/arm/quantizer/arm_quantizer.py
backends/arm/quantizer/arm_quantizer_utils.py
backends/arm/quantizer/quantization_annotation/mul_annotator.py
backends/arm/test/ops/test_div.py

8 files changed, +152 -87 lines changed

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
     op_cat,
     op_conv2d,
     op_dequant,
-    op_div,
     op_exp,
     op_full,
     op_get_item,

backends/arm/passes/arm_pass_manager.py

Lines changed: 10 additions & 0 deletions
@@ -17,10 +17,14 @@
 from executorch.backends.arm.passes.convert_split_to_slice import (
     ConvertSplitToSlicePass,
 )
+from executorch.backends.arm.passes.decompose_div_pass import DecomposeDivPass
 from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
     ConvertMeanDimToAveragePool,
 )
 from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
+from executorch.backends.arm.passes.scalars_to_attribute_pass import (
+    ScalarsToAttributePass,
+)
 from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -39,6 +43,7 @@ def transform_to_backend_pipeline(
         self.add_pass(RemoveClonePass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(ConvertMeanDimToAveragePool())
+        self.add_pass(DecomposeDivPass())
         self.add_pass(ConvertSplitToSlicePass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
@@ -47,3 +52,8 @@ def transform_to_backend_pipeline(
                     self.add_pass(AnnotateChannelsLastDimOrder())
 
         return self._transform(graph_module)
+
+    def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
+        self.add_pass(DecomposeDivPass())
+        self.add_pass(ScalarsToAttributePass())
+        return self._transform(graph_module)
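
The new transform_for_annotation_pipeline is what ArmQuantizer.transform_for_annotation now calls (see the arm_quantizer.py diff below). A minimal usage sketch, illustrative only, assuming graph_module is a torch.fx.GraphModule obtained from exporting the model:

from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager

# DecomposeDivPass rewrites aten.div as reciprocal + mul, then
# ScalarsToAttributePass lifts remaining scalar operands into buffers,
# so the quantizer only sees tensor-tensor ops it can annotate.
transformed = ArmPassManager().transform_for_annotation_pipeline(
    graph_module=graph_module
)
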
backends/arm/passes/decompose_div_pass.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+def get_div_decomposition(op) -> tuple:
+    """
+    Returns the (reciprocal_op, mul_op) pair, where the ops depend on whether
+    the div op is an exir_ops or a torch.ops.aten op.
+    """
+    if op == exir_ops.edge.aten.div.Tensor:
+        return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
+    if op == torch.ops.aten.div.Tensor:
+        return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
+    raise RuntimeError(f"Can't get div decomposition for op {op}")
+
+
+class DecomposeDivPass(ExportPass):
+    """
+    This pass decomposes div into a mul and a reciprocal node.
+
+    Example:
+        y = div(a, b)
+    Becomes:
+        x = reciprocal(b)
+        y = mul(a, x)
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
+            return super().call_operator(op, args, kwargs, meta)
+
+        reciprocal_op, mul_op = get_div_decomposition(op)
+
+        numerator = args[0]
+        denominator = args[1]
+        reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta)
+
+        return super().call_operator(mul_op, (numerator, reciprocal), {}, meta)
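
The pass relies on the identity a / b == a * reciprocal(b). A quick plain-PyTorch sanity check of that identity, illustrative only and not part of the commit:

import torch

a, b = torch.randn(4), torch.rand(4) + 0.5  # keep the denominator away from zero
assert torch.allclose(a / b, a * torch.reciprocal(b))
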
backends/arm/passes/scalars_to_attribute_pass.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast, Union
+
+import torch
+from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+from torch.fx import GraphModule, Node
+
+
+class ScalarsToAttributePass(ExportPass):
+    """
+    For ops in 'targeted_ops', convert inputs that are scalar values
+    to attribute Nodes that output the same value.
+    """
+
+    targeted_ops = [
+        torch.ops.aten.add.Tensor,
+        torch.ops.aten.sub.Tensor,
+        torch.ops.aten.sub_.Tensor,
+        torch.ops.aten.mul.Tensor,
+        torch.ops.aten.div.Tensor,
+    ]
+
+    def call(self, graph_module: GraphModule) -> GraphModule:
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.op != "call_function" or n.target not in self.targeted_ops:
+                continue
+
+            biggest_rank = 1
+            for arg in n.args:
+                if isinstance(arg, Node):
+                    _, shape, _ = extract_tensor_meta(arg.meta)
+                    biggest_rank = max(biggest_rank, len(shape))
+
+            new_args = []
+            for arg in n.args:
+                if isinstance(arg, Node):
+                    new_args.append(arg)
+                    continue
+
+                prefix = "_tensor_constant_"
+                get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+                tensor_constant_name = get_new_attr_name(graph_module)
+                float_tensor = torch.tensor(
+                    float(cast(Union[int, float], arg))
+                ).reshape((1,) * biggest_rank)
+                graph_module.register_buffer(tensor_constant_name, float_tensor)
+                fake_mode = n.meta["val"].fake_mode
+
+                with graph_module.graph.inserting_before(n):
+                    get_attr_node = graph_module.graph.create_node(
+                        "get_attr", tensor_constant_name, (), {}
+                    )
+                    get_attr_node.meta["val"] = fake_mode.from_tensor(
+                        float_tensor, static_shapes=True
+                    )
+                    new_args.append(get_attr_node)
+            n.args = tuple(new_args)
+
+        graph_module.recompile()
+        return PassResult(graph_module, True)
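
Illustrative before/after of the pass on a rank-2 operand (hypothetical node names, not part of the commit):

# Before: the scalar rides along as a plain Python argument.
#     mul = torch.ops.aten.mul.Tensor(x, 2.0)
# After: the scalar lives in a registered buffer reshaped to (1, 1),
# matching the biggest operand rank so that it broadcasts, and is read
# through a get_attr node.
#     _tensor_constant_0 = self._tensor_constant_0
#     mul = torch.ops.aten.mul.Tensor(x, _tensor_constant_0)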

backends/arm/quantizer/arm_quantizer.py

Lines changed: 3 additions & 2 deletions
@@ -19,10 +19,10 @@
 
 import torch
 import torch.nn.functional as F
+from executorch.backends.arm.passes.arm_pass_manager import ArmPassManager
 
 from executorch.backends.arm.quantizer import arm_quantizer_utils
 from executorch.backends.arm.quantizer.arm_quantizer_utils import (
-    convert_scalars_to_attrs,
     mark_nodes_as_annotated,
     propagate_annotation,
 )
@@ -317,7 +317,8 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         """An initial pass for transforming the graph to prepare it for annotation.
         Currently transforms scalar values to tensor attributes.
         """
-        return convert_scalars_to_attrs(model)
+
+        return ArmPassManager().transform_for_annotation_pipeline(graph_module=model)
 
     def annotate(self, model: GraphModule) -> GraphModule:
         """Performs the quantization annotation on the graph.

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 41 deletions
@@ -12,12 +12,11 @@
 #
 
 import operator
-from typing import Callable, cast, List, Union
+from typing import Callable, cast, List
 
 import torch
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -196,42 +195,3 @@ def propagate_annotation(model: GraphModule) -> None:
         output_qspec=shared_qspec,
         _annotated=True,
     )
-
-
-def convert_scalars_to_attrs(model: GraphModule) -> GraphModule:
-    """For ops in 'targeted_ops', convert inputs that are scalar values
-    to attribute Nodes that output the same value.
-    #TODO Seems like this should be a pass.
-    """
-    targeted_ops = [
-        torch.ops.aten.add.Tensor,
-        torch.ops.aten.sub.Tensor,
-        torch.ops.aten.mul.Tensor,
-    ]
-    for n in model.graph.nodes:
-        n = cast(Node, n)
-        if n.op != "call_function" or n.target not in targeted_ops:
-            continue
-        args = list(n.args)
-        new_args = []
-        for i in range(len(args)):
-            if isinstance(args[i], Node):
-                new_args.append(args[i])
-                continue
-            prefix = "_tensor_constant_"
-            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
-            tensor_constant_name = get_new_attr_name(model)
-            float_tensor = torch.tensor(float(cast(Union[int, float], args[i])))
-            model.register_buffer(tensor_constant_name, float_tensor)
-            fake_mode = n.meta["val"].fake_mode
-            with model.graph.inserting_before(n):
-                get_attr_node = model.graph.create_node(
-                    "get_attr", tensor_constant_name, (), {}
-                )
-                get_attr_node.meta["val"] = fake_mode.from_tensor(
-                    float_tensor, static_shapes=True
-                )
-            new_args.append(get_attr_node)
-        n.args = tuple(new_args)
-    model.recompile()
-    return model

backends/arm/quantizer/quantization_annotation/mul_annotator.py

Lines changed: 7 additions & 12 deletions
@@ -4,19 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-unsafe
-
-import itertools
-import operator
 from typing import Callable, List, Optional
 
 import torch
+import torch.fx
 from executorch.backends.arm.quantizer import arm_quantizer_utils
 from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
 from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
 from torch.ao.quantization.quantizer import QuantizationAnnotation
 from torch.fx import Node
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
 
 @register_annotator("mul")
@@ -25,14 +21,13 @@ def _annotate_mul(
     quantization_config: QuantizationConfig,
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
-    mul_partitions = get_source_partitions(
-        gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
-    )
-    mul_partitions = list(itertools.chain.from_iterable(mul_partitions.values()))
+
     annotated_partitions = []
-    for mul_partition in mul_partitions:
-        annotated_partitions.append(mul_partition.nodes)
-        mul_node = mul_partition.output_nodes[0]
+    for node in gm.graph.nodes:
+        if node.target not in (torch.ops.aten.mul.Tensor,):
+            continue
+        mul_node = node
+        annotated_partitions.append([mul_node])
         if arm_quantizer_utils.is_annotated(mul_node):
             continue
 
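Because transform_for_annotation now decomposes div and lifts scalars before annotation runs, mul ops reliably appear as plain aten.mul.Tensor call_function nodes, so source partitions are no longer needed. A minimal sketch of the matching idiom (hypothetical helper, not in the commit):

import torch
from torch.fx import GraphModule

def find_mul_nodes(gm: GraphModule):
    # Direct node matching replaces get_source_partitions: every mul of
    # interest is an aten.mul.Tensor call_function node at this point.
    return [
        n
        for n in gm.graph.nodes
        if n.op == "call_function" and n.target == torch.ops.aten.mul.Tensor
    ]
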
backends/arm/test/ops/test_div.py

Lines changed: 17 additions & 31 deletions
@@ -28,8 +28,8 @@
         ),
         (
             "op_div_rank1_rand",
-            torch.rand(5),
-            torch.rand(5),
+            torch.rand(5) * 5,
+            torch.rand(5) * 5,
             None,
         ),
         (
@@ -70,8 +70,8 @@
         ),
         (
             "op_div_rank4_large_randn",
-            200 * torch.randn(5, 10, 25, 20),
-            torch.rand(5, 10, 25, 20),
+            200 * torch.randn(5, 10, 25, 20) + 1,
+            torch.rand(5, 10, 25, 20) + 1,
             None,
         ),
     ]
@@ -81,26 +81,18 @@ class TestDiv(unittest.TestCase):
     """Tests division"""
 
     class Div(torch.nn.Module):
-        def __init__(
-            self,
-            input_: Union[torch.Tensor, torch.types.Number],
-            other_: Union[torch.Tensor, torch.types.Number],
-            rounding_mode: Optional[str] = None,
-        ):
-            super().__init__()
-            self.rounding_mode = rounding_mode
 
         def forward(
             self,
             input_: Union[torch.Tensor, torch.types.Number],
             other_: Union[torch.Tensor, torch.types.Number],
             rounding_mode: Optional[str] = None,
         ):
-            if self.rounding_mode is None:
+            if rounding_mode is None:
                 return torch.div(input=input_, other=other_)
             else:
                 return torch.div(
-                    input=input_, other=other_, rounding_mode=self.rounding_mode
+                    input=input_, other=other_, rounding_mode=rounding_mode
                 )
 
     def _test_div_tosa_MI_pipeline(
@@ -133,13 +125,15 @@ def _test_div_tosa_BI_pipeline(
             )
             .quantize()
             .export()
-            .check_count({"torch.ops.aten.div.Tensor": 1})
+            .check_count(
+                {"torch.ops.aten.reciprocal.default": 1, "torch.ops.aten.mul.Tensor": 1}
+            )
             .check(["torch.ops.quantized_decomposed"])
             .to_edge()
             .partition()
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
-            .run_method_and_compare_outputs(inputs=test_data)
+            .run_method_and_compare_outputs(inputs=test_data, atol=1, rtol=0.1)
         )
 
     def _test_div_u55_BI_pipeline(
153147
)
154148
.quantize()
155149
.export()
156-
.check_count({"torch.ops.aten.div.Tensor": 1})
150+
.check_count(
151+
{"torch.ops.aten.reciprocal.default": 1, "torch.ops.aten.mul.Tensor": 1}
152+
)
157153
.check(["torch.ops.quantized_decomposed"])
158154
.to_edge()
159155
.partition()
@@ -170,14 +166,9 @@ def test_div_tosa_MI(
         rounding_mode: Optional[str] = None,
     ):
         test_data = (input_, other_)
-        self._test_div_tosa_MI_pipeline(
-            self.Div(input_, other_, rounding_mode=rounding_mode), test_data
-        )
+        self._test_div_tosa_MI_pipeline(self.Div(), test_data)
 
-    # Expected to fail since ArmQuantizer cannot quantize a Div layer
-    # TODO(MLETORCH-129)
     @parameterized.expand(test_data_suite)
-    @unittest.expectedFailure
     def test_div_tosa_BI(
         self,
         test_name: str,
@@ -187,12 +178,9 @@ def test_div_tosa_BI(
     ):
 
         test_data = (input_, other_)
-        self._test_div_tosa_BI_pipeline(
-            self.Div(input=input_, other=other_, rounding_mode=rounding_mode), test_data
-        )
+        self._test_div_tosa_BI_pipeline(self.Div(), test_data)
 
-    # Expected to fail since ArmQuantizer cannot quantize a Div layer
-    # TODO(MLETORCH-129)
+    # Fails due to Vela error.
     @parameterized.expand(test_data_suite)
     @unittest.expectedFailure
     def test_div_u55_BI(
def test_div_u55_BI(
@@ -203,6 +191,4 @@ def test_div_u55_BI(
         rounding_mode: Optional[str] = None,
     ):
         test_data = (input_, other_)
-        self._test_div_u55_BI_pipeline(
-            self.Div(input=input_, other=other_, rounding_mode=rounding_mode), test_data
-        )
+        self._test_div_u55_BI_pipeline(self.Div(), test_data)
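
The shifted test inputs (keeping the denominator away from zero) and the looser atol/rtol track a property of the decomposition: the reciprocal's dynamic range sets the quantization scale, so denominators near zero degrade accuracy everywhere. A rough illustration, assuming a naive symmetric int8 scheme rather than the exact quantizer used:

import torch

def quant_err(b):
    recip = torch.reciprocal(b)
    scale = recip.abs().max() / 127      # one scale for the whole tensor
    q = torch.round(recip / scale) * scale
    return (q - recip).abs().max()

print(quant_err(torch.tensor([0.01, 0.5, 5.0])))  # large: scale set by 1/0.01
print(quant_err(torch.tensor([1.0, 1.5, 2.0])))   # small: 1/b stays in [0.5, 1]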
