Skip to content

Commit aea1b4a

Browse files
committed
Add reduce_sum op to ArmBackend
Adds node visitor, tests and annotator Change-Id: I002f5203e855b9489fc5e67095ec30b2b3ce0077
1 parent a6b213b commit aea1b4a

File tree

9 files changed

+355
-0
lines changed

9 files changed

+355
-0
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
6363
exir_ops.edge.aten._softmax.default,
6464
exir_ops.edge.aten.slice_copy.Tensor,
6565
exir_ops.edge.aten.sub.Tensor,
66+
exir_ops.edge.aten.sum.dim_IntList,
6667
exir_ops.edge.aten.view_copy.default,
6768
exir_ops.edge.aten.clone.default,
6869
exir_ops.edge.aten.mean.dim,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
op_softmax,
3535
op_squeeze,
3636
op_sub,
37+
op_sum,
3738
op_unsqueeze,
3839
op_view,
3940
)

backends/arm/operators/op_sum.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import cast, List
7+
8+
import executorch.backends.arm.tosa_quant_utils as tqutils
9+
import executorch.backends.arm.tosa_utils as tutils
10+
11+
import serializer.tosa_serializer as ts
12+
from executorch.backends.arm.operators.node_visitor import (
13+
NodeVisitor,
14+
register_node_visitor,
15+
)
16+
from executorch.backends.arm.tosa_mapping import TosaArg
17+
from serializer.tosa_serializer import TosaOp
18+
from torch.fx import Node
19+
20+
21+
@register_node_visitor
22+
class AddVisitor(NodeVisitor):
23+
target = "aten.sum.dim_IntList"
24+
25+
def __init__(self, *args):
26+
super().__init__(*args)
27+
28+
def define_node(
29+
self,
30+
node: Node,
31+
tosa_graph: ts.TosaSerializer,
32+
inputs: List[TosaArg],
33+
output: TosaArg,
34+
is_quant_node: bool,
35+
) -> None:
36+
input_node = inputs[0]
37+
input_shape = list(input_node.shape)
38+
dim_list = cast(list[int], inputs[1].special)
39+
dim_list = [dim % len(input_node.shape) for dim in dim_list]
40+
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
41+
assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
42+
43+
if is_quant_node:
44+
45+
# Rescale input to 32 bit
46+
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
47+
[node.all_input_nodes[0]], tosa_graph
48+
)
49+
50+
prev_node = rescaled_inputs[0]
51+
reduced_shape = input_shape
52+
53+
# Reduce all dims in dim_list one-by-one.
54+
for dim in dim_list:
55+
# When reduced, the size of the dim becomes 1.
56+
reduced_shape[dim] = 1
57+
58+
attr = ts.TosaSerializerAttribute()
59+
attr.AxisAttribute(input_node.dim_order.index(dim))
60+
61+
next_node = tosa_graph.addIntermediate(
62+
tutils.tosa_shape(reduced_shape, input_node.dim_order),
63+
dtype=ts.DType.INT32,
64+
)
65+
66+
tosa_graph.addOperator(
67+
TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
68+
)
69+
70+
prev_node = next_node
71+
tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph)
72+
else:
73+
input_name = input_node.name
74+
reduced_shape = input_shape
75+
76+
# Reduce all dims in dim_list one-by-one.
77+
for dim in dim_list:
78+
# When reduced, the size of the dim becomes 1
79+
reduced_shape[dim] = 1
80+
81+
attr = ts.TosaSerializerAttribute()
82+
attr.AxisAttribute(input_node.dim_order.index(dim))
83+
84+
if dim == dim_list[-1]:
85+
output_name = output.name
86+
else:
87+
output_name = tosa_graph.addIntermediate(
88+
tutils.tosa_shape(reduced_shape, input_node.dim_order),
89+
dtype=ts.DType.FP32,
90+
).name
91+
92+
tosa_graph.addOperator(
93+
TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
94+
)
95+
96+
input_name = output_name

backends/arm/passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
ConvertSplitToSlicePass,
1919
)
2020
from executorch.backends.arm.passes.decompose_div_pass import DecomposeDivPass
21+
from executorch.backends.arm.passes.insert_squeeze_after_sum_pass import (
22+
InsertSqueezeAfterSumPass,
23+
)
2124
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
2225
ConvertMeanDimToAveragePool,
2326
)
@@ -45,6 +48,7 @@ def transform_to_backend_pipeline(
4548
self.add_pass(ConvertExpandCopyToRepeatPass())
4649
self.add_pass(ConvertMeanDimToAveragePool())
4750
self.add_pass(DecomposeDivPass())
51+
self.add_pass(InsertSqueezeAfterSumPass())
4852
self.add_pass(ConvertSplitToSlicePass())
4953
for spec in compile_spec:
5054
if spec.key == "permute_memory_format":
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import cast
8+
9+
import torch
10+
import torch.fx
11+
from executorch.backends.arm.passes.arm_pass_utils import create_node, insert_q_dq_pair
12+
13+
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, PassResult
16+
17+
18+
class InsertSqueezeAfterSumPass(ExportPass):
19+
"""
20+
In Pytorch, the default behaviour of Tensor.sum is to squeeze
21+
the dimension that is summed (keep_dim = False).
22+
However, in TOSA, REDUCE_SUM always preserves the
23+
rank of the input (keep_dim = True).
24+
To get a 1-1 mapping in the sum lowering, normalize the
25+
keep_dim = False case to keep_dim = True and add squeeze ops.
26+
27+
Original:
28+
sum(dims, keep_dim = False)
29+
After pass:
30+
sum(dims, keep_dim = True)
31+
(q)
32+
(dq)
33+
squeeze(dim = dims)
34+
"""
35+
36+
def call(self, graph_module: torch.fx.GraphModule):
37+
for node in graph_module.graph.nodes:
38+
if node.op != "call_function":
39+
continue
40+
if node.target != exir_ops.edge.aten.sum.dim_IntList:
41+
continue
42+
sum_node = cast(torch.fx.Node, node)
43+
keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
44+
if keep_dim:
45+
continue
46+
47+
dim_list = cast(list[int], sum_node.args[1])
48+
quantized = is_quant_node(sum_node)
49+
if quantized:
50+
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
51+
qparams = qparams + (torch.int8,)
52+
else:
53+
qparams = None
54+
55+
# Add keep_dim = True arg to sum node.
56+
sum_node.args = sum_node.args[0:2] + (True,)
57+
58+
with graph_module.graph.inserting_after(sum_node):
59+
squeeze_node = create_node(
60+
graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
61+
)
62+
sum_node.replace_all_uses_with(squeeze_node)
63+
squeeze_node.args = (sum_node, dim_list)
64+
if quantized:
65+
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
66+
graph_module.graph.eliminate_dead_code()
67+
graph_module.recompile()
68+
graph_module = super().call(graph_module).graph_module
69+
return PassResult(graph_module, True)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ class ArmQuantizer(Quantizer):
272272
"cat",
273273
"one_to_one",
274274
"generic",
275+
"sum",
275276
]
276277

277278
def __init__(self) -> None:

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,5 @@ def decorator(annotator: AnnotatorType):
6161
one_to_one_annotator,
6262
sigmoid_annotator,
6363
sub_annotator,
64+
sum_annotator,
6465
)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Callable, List, Optional
7+
8+
import torch
9+
from executorch.backends.arm.quantizer import arm_quantizer_utils
10+
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
11+
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
12+
from torch.ao.quantization.quantizer import (
13+
QuantizationAnnotation,
14+
SharedQuantizationSpec,
15+
)
16+
from torch.fx import Node
17+
18+
19+
@register_annotator("sum")
20+
def _annotate_sum(
21+
gm: torch.fx.GraphModule,
22+
quantization_config: QuantizationConfig,
23+
filter_fn: Optional[Callable[[Node], bool]] = None,
24+
) -> Optional[List[List[Node]]]:
25+
annotated_partitions = []
26+
for node in gm.graph.nodes:
27+
if node.target is not torch.ops.aten.sum.dim_IntList:
28+
continue
29+
if filter_fn and not filter_fn(node):
30+
continue
31+
32+
sum_node = node
33+
if arm_quantizer_utils.is_annotated(sum_node):
34+
continue
35+
36+
input_act = sum_node.args[0]
37+
38+
if not isinstance(input_act, Node):
39+
continue
40+
if not arm_quantizer_utils.is_input_ok_for_quantization(input_act, gm):
41+
continue
42+
43+
input_act_qspec = quantization_config.get_input_act_qspec()
44+
input_qspec_map = {input_act: input_act_qspec}
45+
shared_with_input0_qspec = SharedQuantizationSpec((input_act, sum_node))
46+
47+
sum_node.meta["quantization_annotation"] = QuantizationAnnotation(
48+
input_qspec_map=input_qspec_map,
49+
output_qspec=shared_with_input0_qspec,
50+
_annotated=True,
51+
)
52+
annotated_partitions.append([sum_node])
53+
return annotated_partitions

backends/arm/test/ops/test_sum.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from typing import Tuple
10+
11+
import torch
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
14+
from executorch.exir import EdgeCompileConfig
15+
from executorch.exir.backend.compile_spec_schema import CompileSpec
16+
from parameterized import parameterized
17+
18+
exampledata_t = Tuple[torch.Tensor, int | list[int], bool]
19+
"""(data, dim(s), keepdim)"""
20+
21+
22+
class TestSum(unittest.TestCase):
23+
"""Tests sum which sums all elements along some specified dimensions.
24+
keepdim specifies whether the dimension that is summed should
25+
be squeezed or not.
26+
"""
27+
28+
class Sum(torch.nn.Module):
29+
test_parameters: list[Tuple[exampledata_t]] = [
30+
((torch.rand(10), 0, True),),
31+
((torch.rand(10, 10), 1, False),),
32+
((torch.rand(10, 10, 10), [-3, 1], True),),
33+
((torch.rand(2, 1, 5, 8), 1, False),),
34+
((torch.rand(1, 2, 3, 4), 3, True),),
35+
((torch.rand(1, 2, 8, 8), [2, 3, 0], True),),
36+
]
37+
38+
def forward(self, x: torch.Tensor, dim: int, keepdim: bool):
39+
return x.sum(dim=dim, keepdim=keepdim)
40+
41+
_edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
42+
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
43+
)
44+
45+
def _test_sum_tosa_MI_pipeline(
46+
self, module: torch.nn.Module, test_data: tuple[exampledata_t]
47+
):
48+
(
49+
ArmTester(
50+
module,
51+
example_inputs=test_data,
52+
compile_spec=common.get_tosa_compile_spec(),
53+
)
54+
.export()
55+
.check_count({"torch.ops.aten.sum.dim_IntList": 1})
56+
.check_not(["torch.ops.quantized_decomposed"])
57+
.to_edge(config=self._edge_compile_config)
58+
.partition()
59+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
60+
.to_executorch()
61+
.run_method_and_compare_outputs(inputs=test_data)
62+
)
63+
64+
def _test_sum_tosa_BI_pipeline(
65+
self, module: torch.nn.Module, test_data: tuple[exampledata_t]
66+
):
67+
(
68+
ArmTester(
69+
module,
70+
example_inputs=test_data,
71+
compile_spec=common.get_tosa_compile_spec(),
72+
)
73+
.quantize()
74+
.export()
75+
.check_count({"torch.ops.aten.sum.dim_IntList": 1})
76+
.check(["torch.ops.quantized_decomposed"])
77+
.to_edge(config=self._edge_compile_config)
78+
.partition()
79+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
80+
.to_executorch()
81+
.run_method_and_compare_outputs(inputs=test_data, qtol=1)
82+
)
83+
84+
def _test_sum_ethosu_BI_pipeline(
85+
self,
86+
module: torch.nn.Module,
87+
test_data: tuple[exampledata_t],
88+
compile_spec: CompileSpec,
89+
):
90+
(
91+
ArmTester(
92+
module,
93+
example_inputs=test_data,
94+
compile_spec=compile_spec,
95+
)
96+
.quantize()
97+
.export()
98+
.check_count({"torch.ops.aten.sum.dim_IntList": 1})
99+
.check(["torch.ops.quantized_decomposed"])
100+
.to_edge()
101+
.partition()
102+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
103+
.to_executorch()
104+
.serialize()
105+
)
106+
107+
@parameterized.expand(Sum.test_parameters)
108+
def test_sum_tosa_MI(self, test_data: tuple[exampledata_t]):
109+
self._test_sum_tosa_MI_pipeline(self.Sum(), test_data)
110+
111+
@parameterized.expand(Sum.test_parameters)
112+
def test_sum_tosa_BI(self, test_data: tuple[exampledata_t]):
113+
self._test_sum_tosa_BI_pipeline(self.Sum(), test_data)
114+
115+
@parameterized.expand(Sum.test_parameters)
116+
def test_sum_u55_BI(self, test_data: tuple[exampledata_t]):
117+
self._test_sum_ethosu_BI_pipeline(
118+
self.Sum(),
119+
test_data,
120+
common.get_u55_compile_spec(permute_memory_to_nhwc=False),
121+
)
122+
123+
@parameterized.expand(Sum.test_parameters)
124+
def test_sum_u85_BI(self, test_data: tuple[exampledata_t]):
125+
self._test_sum_ethosu_BI_pipeline(
126+
self.Sum(),
127+
test_data,
128+
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
129+
)

0 commit comments

Comments
 (0)