Skip to content

Commit 25e6802

Browse files
committed
Arm backend: Add support for amax/max/amin/min
Max/min can be decomposed, for example, as max(x) = (amax(x), argmax(x)).
For MI, an operator support check is added to support max ops for which the argmax is not used.
For BI, the int64 dtype returned by argmax is currently not supported by the arm_quantizer, and the program will crash. This is the same behaviour as before the patch, but with an improved error message.
- Adds op_amax/op_amin node visitors.
- Renames op_max/op_min -> op_maximum/op_minimum to clearly separate the two ops.
- Adds convert_minmax_pass to make min/max/amin/amax TOSA-compatible.
- Adds unit tests.
Util updates:
- Updates analyze_output_utils to support rank 0.
- Adds quantization to OpNotSupportedPipeline.
Change-Id: I0a7ff126696a9b46568787c40cdf128f0c00f631
Signed-off-by: Adrian Lundell <[email protected]>
1 parent 4df0ade commit 25e6802

16 files changed

+619
-7
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
2222
ConvertFullLikeToFullPass,
2323
)
24+
from executorch.backends.arm._passes.convert_minmax_pass import ConvertMinMaxPass
2425
from executorch.backends.arm._passes.convert_split_to_slice import (
2526
ConvertSplitToSlicePass,
2627
)
@@ -106,6 +107,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
106107
self.add_pass(ConvertMeanDimToAveragePoolPass())
107108
self.add_pass(ConvertFullLikeToFullPass())
108109
self.add_pass(ConvertToClampPass())
110+
self.add_pass(ConvertMinMaxPass())
109111

110112
self.add_pass(ReplaceScalarWithTensorArgPass())
111113
self.add_pass(AnnotateDecomposedMatmulPass())
@@ -147,6 +149,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
147149
self.add_pass(DecomposeSoftmaxesPass())
148150
self.add_pass(ConvertFullLikeToFullPass())
149151
self.add_pass(ConvertToClampPass())
152+
self.add_pass(ConvertMinMaxPass())
150153

151154
self.add_pass(AnnotateDecomposedMatmulPass())
152155
self.add_pass(QuantizeOperatorArguments())
@@ -190,4 +193,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
190193
self.add_pass(DecomposeMeanDimPass())
191194
self.add_pass(DecomposeDivPass())
192195
self.add_pass(DecomposeSoftmaxesPass())
196+
self.add_pass(ConvertMinMaxPass())
193197
return self._transform(graph_module)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
11+
class ConvertMinMaxPass(ExportPass):
    """
    Converts min/max to amin/amax and unrolls multi-dimensional reduction and keep-dims arg to be
    TOSA compliant.

    The difference between max/min and amax/amin is (from pytorch docs):
        - amax/amin supports reducing on multiple dimensions,
        - amax/amin does not return indices,
        - amax/amin evenly distributes gradient between equal values, while max(dim)/min(dim)
          propagates gradient only to a single index in the source tensor.
    Since we do not care about gradients post training, convert min/max ops to amin/amax as long as
    the indices are not used.

    Original:
        amax([dim1, dim2], keepdim = False)
    After pass:
        amax(dim1, keepdim = True)
        amax(dim2, keepdim = True)
        squeeze(dim = [dim1, dim2])
    """

    @staticmethod
    def _is_minmax_dim(node):
        """Returns True if node is a max.dim/min.dim op, i.e. one that returns (values, indices)."""
        return node.target in [
            exir_ops.edge.aten.max.dim,
            exir_ops.edge.aten.min.dim,
            torch.ops.aten.max.dim,
            torch.ops.aten.min.dim,
        ]

    @staticmethod
    def _tuple_index(user):
        """
        Returns the tuple index that `user` reads from its producer if `user` is a getitem node,
        otherwise None.
        """
        import operator

        if user.op == "call_function" and user.target is operator.getitem:
            return user.args[1]
        return None

    def _value_outputs(self, node):
        """
        Returns the nodes that carry the reduced values of `node`: the node itself for amax/amin,
        or every `getitem(node, 0)` user for max.dim/min.dim.
        """
        if not self._is_minmax_dim(node):
            return [node]
        return [user for user in node.users if self._tuple_index(user) == 0]

    def check_argmax(self, node):
        """
        Raises a RuntimeError if the argmax value returned by the min/max op is used in the graph.

        The outputs are identified by their getitem index rather than by their position in
        node.users, so the check also holds when only the indices output has a getitem node.
        """
        if not self._is_minmax_dim(node):
            return

        for user in node.users:
            index = self._tuple_index(user)
            if index == 0:
                # The values output, this is what gets replaced by amax/amin.
                continue
            if index is not None and not user.users:
                # A getitem for the indices that nothing consumes; removed as dead code later.
                continue
            raise RuntimeError("Argmax is not supported by the arm_quantizer")

    def get_variables(self, node):
        """
        Returns variables specific for each op handled by the pass.

        Returns:
            A tuple (replace_node, op, squeeze_op) where replace_node is the node whose uses
            should be redirected to the converted result, op is the amax/amin overload to emit
            and squeeze_op is the squeeze overload of the same dialect.

        Raises:
            RuntimeError: if node.target is not handled by this pass.
        """
        target = node.target
        if target in [
            exir_ops.edge.aten.amax.default,
            exir_ops.edge.aten.amin.default,
        ]:
            op = target
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
        elif target == exir_ops.edge.aten.max.dim:
            op = exir_ops.edge.aten.amax.default
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
        elif target == exir_ops.edge.aten.min.dim:
            op = exir_ops.edge.aten.amin.default
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
        elif target == torch.ops.aten.max.dim:
            op = torch.ops.aten.amax.default
            squeeze_op = torch.ops.aten.squeeze.dims
        elif target == torch.ops.aten.min.dim:
            op = torch.ops.aten.amin.default
            squeeze_op = torch.ops.aten.squeeze.dims
        else:
            raise RuntimeError(
                f"{node.name} is not an accepted target for ConvertMinMaxPass()"
            )

        # Fall back to the node itself when no values getitem exists (the values are unused);
        # replacing its uses is then a no-op instead of an IndexError.
        value_outputs = self._value_outputs(node)
        replace_node = value_outputs[0] if value_outputs else node

        return (replace_node, op, squeeze_op)

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrites every handled min/max/amin/amax node into a chain of single-dimension,
        keepdim=True amax/amin nodes, followed by a squeeze if the original op did not keep dims.

        Returns:
            A PassResult holding the resulting graph module and whether the graph was changed.

        Raises:
            RuntimeError: if the indices of a max.dim/min.dim op are used, or if a handled node
                has an unexpected number of args.
        """
        handled_targets = [
            exir_ops.edge.aten.amax.default,
            exir_ops.edge.aten.amin.default,
            exir_ops.edge.aten.max.dim,
            exir_ops.edge.aten.min.dim,
            torch.ops.aten.max.dim,
            torch.ops.aten.min.dim,
        ]

        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in handled_targets:
                continue

            self.check_argmax(
                node
            )  # TODO: MLETORCH-718 : Quantization of indices in arm_quantizer
            _, op, squeeze_op = self.get_variables(node)
            value_outputs = self._value_outputs(node)

            # Unwrap args
            if len(node.args) == 2:
                input_node, dims = node.args
                keepdims = False
            elif len(node.args) == 3:
                input_node, dims, keepdims = node.args
            else:
                raise RuntimeError(f"Unexpected arg size in {node.name}")

            # dims is either a single int or a sequence of ints; only a non-iterable
            # value is expected to fail here, so catch nothing broader than TypeError.
            try:
                dims = list(dims)
            except TypeError:
                dims = [dims]

            # Unroll multi-dimensional reduction and keep-dims arg
            with graph_module.graph.inserting_before(node):
                for dim in dims:
                    input_node = graph_module.graph.create_node(
                        "call_function", op, (input_node, dim, True), node.kwargs
                    )

                if not keepdims:
                    input_node = graph_module.graph.create_node(
                        "call_function",
                        squeeze_op,
                        (input_node, dims),
                    )

            for value_output in value_outputs:
                value_output.replace_all_uses_with(input_node)
            modified = True

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        # Report the actual outcome so callers can tell whether the graph was touched.
        return PassResult(graph_module, modified)

backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -36,18 +35,18 @@ class KeepDimsFalseToSqueezePass(ExportPass):
3635
"""
3736

3837
# CURRENTLY NOT HANDLED OPS
39-
# exir_ops.edge.aten.amax,
40-
# exir_ops.edge.aten.amin,
4138
# exir_ops.edge.aten.any.dim,
4239
# exir_ops.edge.aten.any.dims,
4340
# exir_ops.edge.aten.argmax,
4441
# exir_ops.edge.aten.argmin,
45-
# exir_ops.edge.aten.max.dim,
46-
# exir_ops.edge.aten.min.dim,
4742
# exir_ops.edge.aten.prod.dim_int,
4843

4944
# HANDLED OPS
5045
# exir_ops.edge.aten.sum.dim_IntList
46+
# exir_ops.edge.aten.max.dim (decomposed in convert_minmax_pass)
47+
# exir_ops.edge.aten.min.dim (decomposed in convert_minmax_pass)
48+
# exir_ops.edge.aten.amin (decomposed in convert_minmax_pass)
49+
# exir_ops.edge.aten.amax (decomposed in convert_minmax_pass)
5150
# exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
5251
# exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
5352
# exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from . import ( # noqa
99
convolution_support,
10+
minmax_support,
1011
pool_2d_support,
1112
reduce_sum_support,
1213
right_shift_support,
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch.fx as fx
7+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
8+
register_tosa_support_check,
9+
SupportedTOSAOperatorCheck,
10+
)
11+
from executorch.backends.arm.tosa_specification import TosaSpecification
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
14+
15+
@register_tosa_support_check
class MinMaxSupported(SupportedTOSAOperatorCheck):
    """
    Support check for max.dim/min.dim, which return a (values, indices) tuple.

    The ops are only reported as supported when the indices output is not used, since they are
    then convertible to amax/amin.
    """

    targets = [
        exir_ops.edge.aten.max.dim,
        exir_ops.edge.aten.min.dim,
    ]

    # TODO : "MLETORCH-718 : Quantization of indices in arm_quantizer"
    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        """
        Returns True if `node` can be lowered, i.e. if its first user is the getitem reading the
        values output (index 0) and no other user of the node is consumed anywhere.

        The values getitem is identified by its index argument instead of only by the number of
        users, so a node whose single user reads the indices output is rejected.
        """
        import operator

        if node.target not in [exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim]:
            return True

        users = list(node.users)
        if not users:
            return False

        values_user, *other_users = users
        reads_values = (
            values_user.op == "call_function"
            and values_user.target is operator.getitem
            and values_user.args[1] == 0
        )
        if not reads_values:
            return False

        # Any remaining user (e.g. the getitem for the indices) must be dead.
        return all(not user.users for user in other_users)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def is_node_supported(
169169
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
170170
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
171171
exir_ops.edge.aten.constant_pad_nd.default,
172+
exir_ops.edge.aten.amax.default,
173+
exir_ops.edge.aten.amin.default,
172174
]
173175

174176
return supported
@@ -191,6 +193,8 @@ def is_node_supported(
191193
exir_ops.edge.aten.bitwise_and.Tensor,
192194
exir_ops.edge.aten.bitwise_or.Tensor,
193195
exir_ops.edge.aten.bitwise_xor.Tensor,
196+
exir_ops.edge.aten.amax.default,
197+
exir_ops.edge.aten.amin.default,
194198
]
195199

196200
if node.target in unsupported_ops:

backends/arm/operators/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
node_visitor,
1010
op_abs,
1111
op_add,
12+
op_amax,
13+
op_amin,
1214
op_avg_pool2d,
1315
op_bmm,
1416
op_cat,
@@ -24,9 +26,9 @@
2426
op_le,
2527
op_log,
2628
op_lt,
27-
op_max,
2829
op_max_pool2d,
29-
op_min,
30+
op_maximum,
31+
op_minimum,
3032
op_mul,
3133
op_permute,
3234
op_reciprocal,

backends/arm/operators/op_amax.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import List
6+
7+
import serializer.tosa_serializer as ts
8+
from executorch.backends.arm.operators.node_visitor import (
9+
NodeVisitor,
10+
register_node_visitor,
11+
)
12+
from executorch.backends.arm.tosa_mapping import TosaArg
13+
from serializer.tosa_serializer import TosaOp
14+
from torch.fx import Node
15+
16+
17+
@register_node_visitor
class MaxVisitor(NodeVisitor):
    """Node visitor that serializes aten.amax.default as a TOSA REDUCE_MAX operator."""

    target = "aten.amax.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """
        Adds a REDUCE_MAX operator reducing inputs[0] along the dimension given by inputs[1].

        Raises:
            RuntimeError: if keepdim is missing or False.
        """
        input_tensor = inputs[0]
        dim = inputs[1].number
        # keepdim is optional for amax; when omitted it means False, which is rejected below.
        keep_dims = inputs[2].number if len(inputs) > 2 else False
        if not keep_dims:
            raise RuntimeError(
                "TOSA only supports keepdims == True; Did you run the convert_minmax pass?"
            )

        # Negative dims count from the end and are never found by dim_order.index().
        # Assumes dim_order holds one entry per input dimension.
        if dim < 0:
            dim += len(input_tensor.dim_order)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(input_tensor.dim_order.index(dim))

        tosa_graph.addOperator(
            TosaOp.Op().REDUCE_MAX, [input_tensor.name], [output.name], attr
        )

backends/arm/operators/op_amin.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import List
6+
7+
import serializer.tosa_serializer as ts
8+
from executorch.backends.arm.operators.node_visitor import (
9+
NodeVisitor,
10+
register_node_visitor,
11+
)
12+
from executorch.backends.arm.tosa_mapping import TosaArg
13+
from serializer.tosa_serializer import TosaOp
14+
from torch.fx import Node
15+
16+
17+
@register_node_visitor
class MinVisitor(NodeVisitor):
    """Node visitor that serializes aten.amin.default as a TOSA REDUCE_MIN operator."""

    target = "aten.amin.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """
        Adds a REDUCE_MIN operator reducing inputs[0] along the dimension given by inputs[1].

        Raises:
            RuntimeError: if keepdim is missing or False.
        """
        input_tensor = inputs[0]
        dim = inputs[1].number
        # keepdim is optional for amin; when omitted it means False, which is rejected below.
        keep_dims = inputs[2].number if len(inputs) > 2 else False
        if not keep_dims:
            raise RuntimeError(
                "TOSA only supports keepdims == True; Did you run the convert_minmax pass?"
            )

        # Negative dims count from the end and are never found by dim_order.index().
        # Assumes dim_order holds one entry per input dimension.
        if dim < 0:
            dim += len(input_tensor.dim_order)

        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(input_tensor.dim_order.index(dim))

        tosa_graph.addOperator(
            TosaOp.Op().REDUCE_MIN, [input_tensor.name], [output.name], attr
        )

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ def _match_pattern(
175175
torch.ops.aten.contiguous.default,
176176
torch.ops.aten.upsample_nearest2d.vec,
177177
torch.ops.aten.pad.default,
178+
torch.ops.aten.amax.default,
179+
torch.ops.aten.amin.default,
178180
]
179181

180182
# Operators that can inherit the quantization specs from its parent node

0 commit comments

Comments
 (0)