Skip to content

Commit 28a8954

Browse files
YufengShi-duduper
authored andcommitted
Arm backend: Add TOSA support for any.default, any.dim and any.dims
1. Implement a pass ConvertAnyDefaultDimDimsPass to decompose any.default, any.dim and any.dims into a sequence of any.dim with keepdim=True and a squeeze_copy.dims if needed 2. Implement a NodeVisitor to lower any.dim to REDUCE_ANY in TOSA 3. Fix the failures in #9128 Change-Id: Ifb6672f2c017cd7365e76319795290a36909657c Signed-off-by: Yufeng Shi <[email protected]>
1 parent 42d3952 commit 28a8954

File tree

8 files changed

+360
-4
lines changed

8 files changed

+360
-4
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
)
1616
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
1717
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
18+
from executorch.backends.arm._passes.convert_any_default_dim_dims_pass import (
19+
ConvertAnyDefaultDimDimsPass,
20+
)
1821
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
1922
ConvertExpandCopyToRepeatPass,
2023
)
@@ -110,6 +113,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
110113
self.add_pass(ConvertFullLikeToFullPass())
111114
self.add_pass(ConvertToClampPass())
112115
self.add_pass(ConvertMinMaxPass())
116+
self.add_pass(ConvertAnyDefaultDimDimsPass())
113117

114118
self.add_pass(ReplaceScalarWithTensorArgPass())
115119
self.add_pass(AnnotateDecomposedMatmulPass())
@@ -155,6 +159,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
155159
self.add_pass(ConvertFullLikeToFullPass())
156160
self.add_pass(ConvertToClampPass())
157161
self.add_pass(ConvertMinMaxPass())
162+
self.add_pass(ConvertAnyDefaultDimDimsPass())
158163

159164
self.add_pass(AnnotateDecomposedMatmulPass())
160165
self.add_pass(QuantizeOperatorArguments())
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.dialects._ops import ( # type: ignore[import-not-found]
8+
ops as exir_ops,
9+
)
10+
from executorch.exir.pass_base import ( # type: ignore[import-not-found]
11+
ExportPass,
12+
PassResult,
13+
)
14+
15+
16+
class ConvertAnyDefaultDimDimsPass(ExportPass):
17+
"""
18+
Converts any.default, any.dim and any.dims to a sequence of any.dim by unrolling multi-dimensional reduction.
19+
Please refer to KeepDimsFalseToSqueezePass for an explanation of this coversion.
20+
21+
Example 1
22+
Original:
23+
any() # x.shape: [dim1, dim2, ..., dimn]
24+
After pass:
25+
any.dim(dim1, keepdim = True)
26+
any.dim(dim2, keepdim = True)
27+
...
28+
any.dim(dimn, keepdim = True)
29+
squeeze(dim = [dim1, dim2, ...., dimn])
30+
31+
Example 2
32+
Original:
33+
any.dim(dim1, keepdim = False)
34+
After pass:
35+
any.dim(dim1, keepdim = True)
36+
squeeze(dim = [dim1])
37+
38+
Example 3
39+
Original:
40+
any.dims([dim1, dim2], keepdim = False)
41+
After pass:
42+
any.dim(dim1, keepdim = True)
43+
any.dim(dim2, keepdim = True)
44+
squeeze(dim = [dim1, dim2])
45+
"""
46+
47+
def call(self, graph_module: torch.fx.GraphModule):
48+
modified = False
49+
for node in graph_module.graph.nodes:
50+
if node.op != "call_function":
51+
continue
52+
if node.target not in [
53+
exir_ops.edge.aten.any.default,
54+
exir_ops.edge.aten.any.dim,
55+
exir_ops.edge.aten.any.dims,
56+
]:
57+
continue
58+
59+
if len(node.args) == 1:
60+
# any.default(input)
61+
input_node = (node.args)[0]
62+
dims = range(len(input_node.meta["val"].shape))
63+
keepdim = False
64+
elif len(node.args) == 2:
65+
# any.dim/dims(input, dims=dims)
66+
input_node, dims = node.args
67+
keepdim = False
68+
elif len(node.args) == 3:
69+
# any.dim/dims(input, dims=dims, keepdim=keepdim)
70+
input_node, dims, keepdim = node.args
71+
else:
72+
raise RuntimeError(
73+
f"Unexpected arg size {len(node.args)} in {node.name}"
74+
)
75+
try:
76+
iter(dims)
77+
except:
78+
dims = [dims] # type: ignore[assignment]
79+
else:
80+
dims = list(dims) # type: ignore[assignment]
81+
82+
# Unroll multi-dimensional reduction and keep-dims arg
83+
with graph_module.graph.inserting_before(node):
84+
for dim in dims:
85+
args = (input_node, dim, True)
86+
input_node = graph_module.graph.create_node(
87+
"call_function", exir_ops.edge.aten.any.dim, args, node.kwargs
88+
)
89+
90+
if not keepdim:
91+
args = (input_node, dims) # type: ignore[assignment]
92+
input_node = graph_module.graph.create_node(
93+
"call_function",
94+
exir_ops.edge.aten.squeeze_copy.dims,
95+
args,
96+
)
97+
98+
node.replace_all_uses_with(input_node)
99+
modified = True
100+
101+
if modified:
102+
graph_module.graph.eliminate_dead_code()
103+
graph_module.recompile()
104+
graph_module = super().call(graph_module).graph_module
105+
106+
return PassResult(graph_module, modified)

backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@ class KeepDimsFalseToSqueezePass(ExportPass):
3535
"""
3636

3737
# CURRENTLY NOT HANDLED OPS
38-
# exir_ops.edge.aten.any.dim,
39-
# exir_ops.edge.aten.any.dims,
4038
# exir_ops.edge.aten.argmax,
4139
# exir_ops.edge.aten.argmin,
4240
# exir_ops.edge.aten.prod.dim_int,
4341

4442
# HANDLED OPS
4543
# exir_ops.edge.aten.sum.dim_IntList
44+
# exir_ops.edge.aten.any.default (decomposed in convert_any_default_dim_dims_pass)
45+
# exir_ops.edge.aten.any.dim (decomposed in convert_any_default_dim_dims_pass)
46+
# exir_ops.edge.aten.any.dims (decomposed in convert_any_default_dim_dims_pass)
4647
# exir_ops.edge.aten.max.dim (decomposed in convert_minmax_pass)
4748
# exir_ops.edge.aten.min.dim (decomposed in convert_minmax_pass)
4849
# exir_ops.edge.aten.amin (decomposed in convert_minmax_pass)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ def is_node_supported(
112112
supported = node.op == "call_function" and node.target in [
113113
exir_ops.edge.aten.abs.default,
114114
exir_ops.edge.aten.add.Tensor,
115+
exir_ops.edge.aten.any.default,
116+
exir_ops.edge.aten.any.dim,
117+
exir_ops.edge.aten.any.dims,
115118
exir_ops.edge.aten.logical_and.default,
116119
exir_ops.edge.aten.logical_or.default,
117120
exir_ops.edge.aten.logical_xor.default,
@@ -194,6 +197,9 @@ def is_node_supported(
194197
) -> bool:
195198
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
196199
unsupported_ops = [
200+
exir_ops.edge.aten.any.default,
201+
exir_ops.edge.aten.any.dim,
202+
exir_ops.edge.aten.any.dims,
197203
exir_ops.edge.aten.bitwise_and.Tensor,
198204
exir_ops.edge.aten.bitwise_or.Tensor,
199205
exir_ops.edge.aten.bitwise_xor.Tensor,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
op_add,
1212
op_amax,
1313
op_amin,
14+
op_any,
1415
op_avg_pool2d,
1516
op_bmm,
1617
op_cat,

backends/arm/operators/op_any.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import cast, List
8+
9+
import serializer.tosa_serializer as ts # type: ignore
10+
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
15+
from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
16+
from serializer.tosa_serializer import TosaOp
17+
from torch.fx import Node
18+
19+
20+
@register_node_visitor
21+
class AnyVisitor(NodeVisitor):
22+
target = "aten.any.dim"
23+
24+
def define_node(
25+
self,
26+
node: Node,
27+
tosa_graph: ts.TosaSerializer,
28+
inputs: List[TosaArg],
29+
output: TosaArg,
30+
) -> None:
31+
32+
if not (inputs[0].dtype == output.dtype):
33+
raise ValueError(
34+
"All inputs and outputs need same dtype."
35+
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
36+
)
37+
if not (inputs[0].dtype == ts.DType.BOOL):
38+
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
39+
40+
input_shape = list(inputs[0].shape)
41+
dim = cast(int, inputs[1].number) % len(
42+
input_shape
43+
) # process the negative index
44+
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
45+
if not keep_dim:
46+
raise ValueError("This case should be handled by ConvertAnyDimDimsPass")
47+
48+
attr = ts.TosaSerializerAttribute()
49+
attr.AxisAttribute(inputs[0].dim_order.index(dim))
50+
51+
tosa_graph.addOperator(
52+
TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
53+
)

backends/arm/test/models/test_conformer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,10 @@ class TestConformer(unittest.TestCase):
3434
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
3535
"executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
3636
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
37-
"executorch_exir_dialects_edge__ops_aten_any_dim": 2,
3837
"torch.ops.aten._assert_scalar.default": 10,
3938
"torch.ops.aten._local_scalar_dense.default": 1,
4039
"torch.ops.aten.scalar_tensor.default": 2,
41-
"torch.ops.higher_order.executorch_call_delegate": 4,
40+
"torch.ops.higher_order.executorch_call_delegate": 6,
4241
}
4342

4443
dim = 16

0 commit comments

Comments
 (0)