Commit 7b5d925

Add reduce_sum op to ArmBackend
Adds node visitor, tests and annotator.

Change-Id: I002f5203e855b9489fc5e67095ec30b2b3ce0077
1 parent 83c95df commit 7b5d925

9 files changed: +359 -0 lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions

@@ -19,6 +19,9 @@
     ConvertSplitToSlicePass,
 )
 from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
+from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
+    InsertSqueezeAfterSumPass,
+)
 from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
     ConvertMeanDimToAveragePool,
 )
@@ -47,6 +50,7 @@ def transform_to_backend_pipeline(
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(ConvertMeanDimToAveragePool())
         self.add_pass(DecomposeDivPass())
+        self.add_pass(InsertSqueezeAfterSumPass())
         self.add_pass(ConvertSplitToSlicePass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
backends/arm/_passes/insert_squeeze_after_sum_pass.py

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast
+
+import torch
+import torch.fx
+from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair
+
+from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class InsertSqueezeAfterSumPass(ExportPass):
+    """
+    In PyTorch, the default behaviour of Tensor.sum is to squeeze
+    the dimension that is summed (keep_dim = False).
+    However, in TOSA, REDUCE_SUM always preserves the
+    rank of the input (keep_dim = True).
+    To get a 1-1 mapping in the sum lowering, normalize the
+    keep_dim = False case to keep_dim = True and add squeeze ops.
+
+    Original:
+        sum(dims, keep_dim = False)
+    After pass:
+        sum(dims, keep_dim = True)
+        (q)
+        (dq)
+        squeeze(dim = dims)
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target != exir_ops.edge.aten.sum.dim_IntList:
+                continue
+            sum_node = cast(torch.fx.Node, node)
+            keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
+            if keep_dim:
+                continue
+
+            dim_list = cast(list[int], sum_node.args[1])
+            quantized = is_quant_node(sum_node)
+            if quantized:
+                qparams = get_quant_node_args(sum_node.all_input_nodes[0])
+                qparams = qparams + (torch.int8,)
+            else:
+                qparams = None
+
+            # Add keep_dim = True arg to sum node.
+            sum_node.args = sum_node.args[0:2] + (True,)
+
+            with graph_module.graph.inserting_after(sum_node):
+                squeeze_node = create_node(
+                    graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
+                )
+                sum_node.replace_all_uses_with(squeeze_node)
+                squeeze_node.args = (sum_node, dim_list)
+                if quantized:
+                    sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
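The docstring above captures the semantic gap this pass bridges. As a quick illustration, a minimal sketch of the equivalence in plain PyTorch (the tensor shape and dim below are hypothetical, chosen only for the example):

    import torch

    x = torch.randn(2, 3, 4)

    # PyTorch default: the summed dimension is removed (keep_dim = False).
    out_default = x.sum(dim=1)                # shape: (2, 4)

    # TOSA REDUCE_SUM semantics: rank is preserved (keep_dim = True)...
    out_keepdim = x.sum(dim=1, keepdim=True)  # shape: (2, 1, 4)

    # ...so the pass appends a squeeze to restore the expected shape.
    out_squeezed = out_keepdim.squeeze(1)     # shape: (2, 4)

    assert torch.equal(out_default, out_squeezed)

For quantized graphs, the pass additionally wraps the rewritten sum in a q/dq pair, mirroring the (q)/(dq) lines in the docstring.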

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions

@@ -63,6 +63,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.sum.dim_IntList,
             exir_ops.edge.aten.view_copy.default,
             exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.mean.dim,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@
     op_softmax,
     op_squeeze,
     op_sub,
+    op_sum,
    op_unsqueeze,
     op_view,
 )

backends/arm/operators/op_sum.py

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+# Copyright 2023-2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast, List
+
+import executorch.backends.arm.tosa_quant_utils as tqutils
+import executorch.backends.arm.tosa_utils as tutils
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class SumVisitor(NodeVisitor):
+    target = "aten.sum.dim_IntList"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        input_node = inputs[0]
+        input_shape = list(input_node.shape)
+        dim_list = cast(list[int], inputs[1].special)
+        dim_list = [dim % len(input_node.shape) for dim in dim_list]
+        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
+        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
+
+        if is_quant_node:
+
+            # Rescale input to 32 bit
+            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
+                [node.all_input_nodes[0]], tosa_graph
+            )
+
+            prev_node = rescaled_inputs[0]
+            reduced_shape = input_shape
+
+            # Reduce all dims in dim_list one-by-one.
+            for dim in dim_list:
+                # When reduced, the size of the dim becomes 1.
+                reduced_shape[dim] = 1
+
+                attr = ts.TosaSerializerAttribute()
+                attr.AxisAttribute(input_node.dim_order.index(dim))
+
+                next_node = tosa_graph.addIntermediate(
+                    tutils.tosa_shape(reduced_shape, input_node.dim_order),
+                    dtype=ts.DType.INT32,
+                )
+
+                tosa_graph.addOperator(
+                    TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
+                )
+
+                prev_node = next_node
+            tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph)
+        else:
+            input_name = input_node.name
+            reduced_shape = input_shape
+
+            # Reduce all dims in dim_list one-by-one.
+            for dim in dim_list:
+                # When reduced, the size of the dim becomes 1.
+                reduced_shape[dim] = 1
+
+                attr = ts.TosaSerializerAttribute()
+                attr.AxisAttribute(input_node.dim_order.index(dim))
+
+                if dim == dim_list[-1]:
+                    output_name = output.name
+                else:
+                    output_name = tosa_graph.addIntermediate(
+                        tutils.tosa_shape(reduced_shape, input_node.dim_order),
+                        dtype=ts.DType.FP32,
+                    ).name
+
+                tosa_graph.addOperator(
+                    TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
+                )
+
+                input_name = output_name
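Both branches lower a multi-axis sum into a chain of single-axis REDUCE_SUM operators, shrinking one dimension to size 1 per step. A small sketch of the shape bookkeeping, assuming a hypothetical (2, 3, 4) input reduced over dims [0, 2]:

    # Each REDUCE_SUM keeps the rank, so the reduced dim becomes size 1.
    input_shape = [2, 3, 4]
    dim_list = [0, 2]

    reduced_shape = list(input_shape)
    for dim in dim_list:
        reduced_shape[dim] = 1
        print(reduced_shape)  # [1, 3, 4] after the first step, [1, 3, 1] after the second

Only the final REDUCE_SUM in the float branch writes to the node's real output; every earlier step goes through an intermediate tensor, and the quantized branch additionally rescales to INT32 before the chain and back to int8 after it.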

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 0 deletions

@@ -272,6 +272,7 @@ class ArmQuantizer(Quantizer):
         "cat",
         "one_to_one",
         "generic",
+        "sum",
     ]

     def __init__(self) -> None:

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -61,4 +61,5 @@ def decorator(annotator: AnnotatorType):
     one_to_one_annotator,
     sigmoid_annotator,
     sub_annotator,
+    sum_annotator,
 )
backends/arm/quantizer/quantization_annotation/sum_annotator.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, cast, List, Optional
+
+import torch
+from executorch.backends.arm.quantizer import arm_quantizer_utils
+from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
+from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
+
+from torch.ao.quantization.quantizer import (
+    QuantizationAnnotation,
+    QuantizationSpecBase,
+    SharedQuantizationSpec,
+)
+from torch.fx import Node
+
+
+@register_annotator("sum")
+def _annotate_sum(
+    gm: torch.fx.GraphModule,
+    quantization_config: QuantizationConfig,
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    annotated_partitions = []
+    for node in gm.graph.nodes:
+        if node.target is not torch.ops.aten.sum.dim_IntList:
+            continue
+        if filter_fn and not filter_fn(node):
+            continue
+
+        sum_node = node
+        if arm_quantizer_utils.is_annotated(sum_node):
+            continue
+
+        input_act = sum_node.args[0]
+
+        if not isinstance(input_act, Node):
+            continue
+        if not arm_quantizer_utils.is_input_ok_for_quantization(input_act, gm):
+            continue
+
+        input_act_qspec = cast(
+            Optional[QuantizationSpecBase], quantization_config.get_input_act_qspec()
+        )
+        input_qspec_map = {input_act: input_act_qspec}
+        shared_with_input0_qspec = SharedQuantizationSpec((input_act, sum_node))
+
+        sum_node.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=shared_with_input0_qspec,
+            _annotated=True,
+        )
+        annotated_partitions.append([sum_node])
+    return annotated_partitions
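The annotator ties the output qspec to the input via SharedQuantizationSpec, so sum reuses its input's quantization parameters. For context, a hedged sketch of how this might be exercised through the PT2E flow (the config helper and the exact capture API are assumptions, not part of this commit, and vary by PyTorch version):

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    from executorch.backends.arm.quantizer.arm_quantizer import (
        ArmQuantizer,
        get_symmetric_quantization_config,  # assumed helper, mirroring other quantizers
    )


    class SumModule(torch.nn.Module):
        def forward(self, x):
            return x.sum(dim=1)


    example_inputs = (torch.randn(2, 3, 4),)
    # Exact capture API differs across PyTorch releases; torch.export is one option.
    gm = torch.export.export(SumModule(), example_inputs).module()

    quantizer = ArmQuantizer().set_global(get_symmetric_quantization_config())
    prepared = prepare_pt2e(gm, quantizer)  # _annotate_sum marks the sum node here
    prepared(*example_inputs)               # calibrate with representative data
    converted = convert_pt2e(prepared)      # materialize q/dq ops around the sum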
