Skip to content

Commit d0b4ed6

Browse files
Arm backend: Decompose sum in pass (#10852)
Moves the unrolling of reducing multiple indices from the sum node visitor to a new DecomposeSumPass. KeepDimsFalseToSqueezePass is merged into the new pass to decompose the sum op fully in one pass. This change introduces new rescales for each reduced dim, requiring decomposition before quantization to get proper quantization parameters. Signed-off-by: Adrian Lundell <[email protected]>
1 parent 518324f commit d0b4ed6

File tree

5 files changed

+156
-205
lines changed

5 files changed

+156
-205
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
3333
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
3434
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
35+
from .decompose_sum_pass import DecomposeSumPass # noqa
3536
from .decompose_var_pass import DecomposeVarPass # noqa
3637
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
3738
FoldAndAnnotateQParamsPass,
@@ -44,7 +45,6 @@
4445
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
4546
from .insert_rescales_pass import InsertRescalePass # noqa
4647
from .insert_table_ops import InsertTableOpsPass # noqa
47-
from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass # noqa
4848
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
4949
from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass # noqa
5050
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
DecomposeSoftmaxPass,
3737
DecomposeSoftmaxUnstablePass,
3838
DecomposeSqrtPass,
39+
DecomposeSumPass,
3940
DecomposeVarPass,
4041
FoldAndAnnotateQParamsPass,
4142
FuseBatchnorm2DPass,
@@ -44,7 +45,6 @@
4445
FuseQuantizedActivationPass,
4546
InsertRescalePass,
4647
InsertTableOpsPass,
47-
KeepDimsFalseToSqueezePass,
4848
MatchArgRanksPass,
4949
MatchWhereSelfDtypePass,
5050
QuantizeOperatorArguments,
@@ -109,7 +109,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
109109
self.add_pass(ConvertExpandCopyToRepeatPass())
110110
self.add_pass(UnsqueezeBeforeRepeatPass())
111111
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
112-
self.add_pass(KeepDimsFalseToSqueezePass())
112+
self.add_pass(DecomposeSumPass())
113113
self.add_pass(Conv1dUnsqueezePass(exported_program))
114114
self.add_pass(DecomposeSelectPass())
115115
self.add_pass(ConvertSqueezesToViewPass())
@@ -161,7 +161,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
161161
self.add_pass(ConvertExpandCopyToRepeatPass())
162162
self.add_pass(UnsqueezeBeforeRepeatPass())
163163
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
164-
self.add_pass(KeepDimsFalseToSqueezePass())
164+
self.add_pass(DecomposeSumPass())
165165
self.add_pass(Conv1dUnsqueezePass(exported_program))
166166
self.add_pass(DecomposeSelectPass())
167167
self.add_pass(ConvertSqueezesToViewPass())
@@ -218,4 +218,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
218218

219219
self.add_pass(ConvertMinMaxPass())
220220
self.add_pass(ReplaceInfValues())
221+
self.add_pass(DecomposeSumPass())
222+
221223
return self._transform(graph_module)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
from executorch.exir.pass_base import ExportPass
9+
10+
11+
def _get_sum_decomp(op):
12+
match op:
13+
case exir_ops.edge.aten.sum.dim_IntList:
14+
return (
15+
exir_ops.edge.aten.view_copy.default,
16+
exir_ops.edge.aten.sum.dim_IntList,
17+
)
18+
case torch.ops.aten.sum.dim_IntList:
19+
return (torch.ops.aten.view_copy.default, torch.ops.aten.sum.dim_IntList)
20+
case _:
21+
raise RuntimeError("Unvalid op in DecomposeSumPass")
22+
23+
24+
class DecomposeSumPass(ExportPass):
25+
"""
26+
In Pytorch, the default behaviour of for example Tensor.sum is to squeeze the
27+
dimension that is summed (keep_dim = False). However, in TOSA, REDUCE_SUM always
28+
preserves the rank of the input (keep_dim = True). To get a 1-1 mapping in the sum
29+
lowering, normalize the keep_dim = False case to keep_dim = True and lower the rank
30+
with a view op.
31+
32+
Since TOSA can only reduce one dimension at a time, multiple dims are additionally
33+
unrolled into multiple ops.
34+
35+
Original:
36+
sum((dim_1, dim_2), keep_dim = False) -> squeezed_shape
37+
After pass:
38+
sum(dim_1, keep_dim = True) -> unsqueezed_shape
39+
sum(dim_2, keep_dim = True) -> unsqueezed_shape
40+
view(shape = squeezed_shape) -> squeezed_shape
41+
"""
42+
43+
def call_operator(self, op, args, kwargs, meta):
44+
if op not in [
45+
exir_ops.edge.aten.sum.dim_IntList,
46+
torch.ops.aten.sum.dim_IntList,
47+
]:
48+
return super().call_operator(op, args, kwargs, meta)
49+
50+
match len(args):
51+
case 3:
52+
(
53+
input_node,
54+
dims,
55+
keepdims,
56+
) = args
57+
case 2:
58+
(
59+
input_node,
60+
dims,
61+
) = args
62+
keepdims = False
63+
case _:
64+
raise ValueError(f"Invalid number of arguments ({len(args)}) provided.")
65+
66+
view_op, sum_op = _get_sum_decomp(op)
67+
68+
for dim in dims:
69+
input_node = super().call_operator(
70+
sum_op, (input_node, dim, True), kwargs, meta
71+
)
72+
73+
if not keepdims:
74+
shape = list(meta["val"].size())
75+
input_node = super().call_operator(
76+
view_op, (input_node, shape), kwargs, meta
77+
)
78+
79+
return input_node

backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

Lines changed: 0 additions & 92 deletions
This file was deleted.

0 commit comments

Comments
 (0)