Skip to content

Arm backend: Add support for amax/max/amin/min #8829

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
ConvertFullLikeToFullPass,
)
from executorch.backends.arm._passes.convert_minmax_pass import ConvertMinMaxPass
from executorch.backends.arm._passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
Expand Down Expand Up @@ -106,6 +107,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())

self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(AnnotateDecomposedMatmulPass())
Expand Down Expand Up @@ -147,6 +149,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
Expand Down Expand Up @@ -190,4 +193,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeMeanDimPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(ConvertMinMaxPass())
return self._transform(graph_module)
136 changes: 136 additions & 0 deletions backends/arm/_passes/convert_minmax_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertMinMaxPass(ExportPass):
    """
    Converts min/max to amin/amax and unrolls multi-dimensional reduction and keep-dims arg to be
    TOSA compliant.

    The difference between max/min and amax/amin is (from pytorch docs):
        - amax/amin supports reducing on multiple dimensions,
        - amax/amin does not return indices,
        - amax/amin evenly distributes gradient between equal values, while max(dim)/min(dim)
            propagates gradient only to a single index in the source tensor.
    Since we do not care about gradients post training, convert min/max ops to amin/amax as long as
    the indices are not used.

    Original:
        amax([dim1, dim2], keepdim = False)
    After pass:
        amax(dim1, keepdim = True)
        amax(dim2, keepdim = True)
        squeeze(dim = [dim1, dim2])
    """

    def check_argmax(self, node):
        """
        Raises a RuntimeError if the argmax value returned by the min/max op is used in the graph.
        """
        if node.target in [torch.ops.aten.max.dim, torch.ops.aten.min.dim]:
            # max.dim/min.dim return (values, indices). The conversion is only valid when
            # the indices output is unused: either only the values are extracted (one user),
            # or the indices getitem exists but has no consumers of its own.
            no_argmax = len(node.users) == 1
            no_argmax_users = (len(node.users) == 2) and (
                len(list(node.users)[1].users) == 0
            )
            if not (no_argmax or no_argmax_users):
                raise RuntimeError("Argmax is not supported by the arm_quantizer")

    def get_variables(self, node):
        """Returns variables specific for each op handled by the pass.

        Returns a tuple (replace_node, op, squeeze_op) where replace_node is the
        node whose uses should be rerouted to the unrolled chain, op is the
        single-dim reduction op to emit, and squeeze_op removes the kept dims.
        """
        if node.target in [
            exir_ops.edge.aten.amax.default,
            exir_ops.edge.aten.amin.default,
        ]:
            # amax/amin produce the values directly; the node itself is replaced.
            replace_node = node
            op = node.target
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
        elif node.target == exir_ops.edge.aten.max.dim:
            # max.dim returns a tuple; reroute the getitem that extracts the values.
            replace_node = list(node.users)[0]
            op = exir_ops.edge.aten.amax.default
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
        elif node.target == exir_ops.edge.aten.min.dim:
            replace_node = list(node.users)[0]
            op = exir_ops.edge.aten.amin.default
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
        elif node.target == torch.ops.aten.max.dim:
            # aten (pre-edge) variants use the aten squeeze op.
            replace_node = list(node.users)[0]
            op = torch.ops.aten.amax.default
            squeeze_op = torch.ops.aten.squeeze.dims
        elif node.target == torch.ops.aten.min.dim:
            replace_node = list(node.users)[0]
            op = torch.ops.aten.amin.default
            squeeze_op = torch.ops.aten.squeeze.dims
        else:
            raise RuntimeError(
                f"{node.name} is not an accepted target for ConvertMinMaxPass()"
            )

        return (replace_node, op, squeeze_op)

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite each handled reduction into a chain of single-dim keepdim=True
        amin/amax nodes, followed by one squeeze when keepdim was False."""
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target not in [
                exir_ops.edge.aten.amax.default,
                exir_ops.edge.aten.amin.default,
                exir_ops.edge.aten.max.dim,
                exir_ops.edge.aten.min.dim,
                torch.ops.aten.max.dim,
                torch.ops.aten.min.dim,
            ]:
                continue

            self.check_argmax(
                node
            )  # TODO: MLETORCH-718 : Quantization of indices in arm_quantizer
            replace_node, op, squeeze_op = self.get_variables(node)

            # Unwrap args: (input, dims) with keepdims defaulting to False,
            # or (input, dims, keepdims).
            if len(node.args) == 2:
                input_node, dims = node.args
                keepdims = False
            elif len(node.args) == 3:
                input_node, dims, keepdims = node.args
            else:
                raise RuntimeError(f"Unexpected arg size in {node.name}")

            # Normalize dims to a list; a single int dim is not iterable.
            # Catch only TypeError (raised by iter() on a non-iterable) — a bare
            # except here would also swallow KeyboardInterrupt/SystemExit.
            try:
                iter(dims)
            except TypeError:
                dims = [dims]
            else:
                dims = list(dims)

            # Unroll multi-dimensional reduction and keep-dims arg
            with graph_module.graph.inserting_before(node):

                for dim in dims:
                    args = (input_node, dim, True)
                    input_node = graph_module.graph.create_node(
                        "call_function", op, args, node.kwargs
                    )

                if not keepdims:
                    # Every unrolled reduction kept its dim; squeeze all reduced
                    # dims once at the end to restore the expected output shape.
                    input_node = graph_module.graph.create_node(
                        "call_function",
                        squeeze_op,
                        (input_node, dims),
                    )

            replace_node.replace_all_uses_with(input_node)
            modified = True

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        # Report whether the graph was actually changed; the original returned
        # True unconditionally, misreporting "modified" for untouched graphs.
        return PassResult(graph_module, modified)
9 changes: 4 additions & 5 deletions backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -36,18 +35,18 @@ class KeepDimsFalseToSqueezePass(ExportPass):
"""

# CURRENTLY NOT HANDLED OPS
# exir_ops.edge.aten.amax,
# exir_ops.edge.aten.amin,
# exir_ops.edge.aten.any.dim,
# exir_ops.edge.aten.any.dims,
# exir_ops.edge.aten.argmax,
# exir_ops.edge.aten.argmin,
# exir_ops.edge.aten.max.dim,
# exir_ops.edge.aten.min.dim,
# exir_ops.edge.aten.prod.dim_int,

# HANDLED OPS
# exir_ops.edge.aten.sum.dim_IntList
# exir_ops.edge.aten.max.dim (decomposed in convert_minmax_pass)
# exir_ops.edge.aten.min.dim (decomposed in convert_minmax_pass)
# exir_ops.edge.aten.amin (decomposed in convert_minmax_pass)
# exir_ops.edge.aten.amax (decomposed in convert_minmax_pass)
# exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
# exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
# exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from . import ( # noqa
convolution_support,
minmax_support,
pool_2d_support,
reduce_sum_support,
right_shift_support,
Expand Down
37 changes: 37 additions & 0 deletions backends/arm/operator_support/minmax_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class MinMaxSupported(SupportedTOSAOperatorCheck):
    """Support check for max.dim/min.dim: these ops also yield an indices
    tensor, which the backend cannot lower, so they are supported only when
    that indices output has no consumers in the graph."""

    targets = [
        exir_ops.edge.aten.max.dim,
        exir_ops.edge.aten.min.dim,
    ]

    # TODO : "MLETORCH-718 : Quantization of indices in arm_quantizer"
    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        if node.target not in (exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim):
            return True

        # Only the values output is extracted -> indices are unused.
        users = list(node.users)
        if len(users) == 1:
            return True

        # Indices getitem exists but is itself dead -> still fine.
        return len(users) == 2 and len(users[1].users) == 0
4 changes: 4 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def is_node_supported(
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
]

return supported
Expand All @@ -191,6 +193,8 @@ def is_node_supported(
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.amax.default,
exir_ops.edge.aten.amin.default,
]

if node.target in unsupported_ops:
Expand Down
6 changes: 4 additions & 2 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
node_visitor,
op_abs,
op_add,
op_amax,
op_amin,
op_avg_pool2d,
op_bmm,
op_cat,
Expand All @@ -24,9 +26,9 @@
op_le,
op_log,
op_lt,
op_max,
op_max_pool2d,
op_min,
op_maximum,
op_minimum,
op_mul,
op_permute,
op_reciprocal,
Expand Down
45 changes: 45 additions & 0 deletions backends/arm/operators/op_amax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class MaxVisitor(NodeVisitor):
    """Serializes aten.amax.default as a TOSA REDUCE_MAX over a single axis."""

    target = "aten.amax.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:

        in_tensor = inputs[0]
        reduce_dim = inputs[1].number
        keep_dims = inputs[2].number

        # TOSA REDUCE_MAX always keeps the reduced dimension; keepdim=False
        # reductions must have been rewritten by the convert_minmax pass.
        if not keep_dims:
            raise RuntimeError(
                "TOSA only supports keepdims == True; Did you run the convert_minmax pass?"
            )

        # Translate the graph dim into a TOSA axis via the tensor's dim order.
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(in_tensor.dim_order.index(reduce_dim))

        tosa_graph.addOperator(
            TosaOp.Op().REDUCE_MAX, [in_tensor.name], [output.name], attr
        )
45 changes: 45 additions & 0 deletions backends/arm/operators/op_amin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import serializer.tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class MinVisitor(NodeVisitor):
    """Serializes aten.amin.default as a TOSA REDUCE_MIN over a single axis."""

    target = "aten.amin.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:

        in_tensor = inputs[0]
        reduce_dim = inputs[1].number
        keep_dims = inputs[2].number

        # TOSA REDUCE_MIN always keeps the reduced dimension; keepdim=False
        # reductions must have been rewritten by the convert_minmax pass.
        if not keep_dims:
            raise RuntimeError(
                "TOSA only supports keepdims == True; Did you run the convert_minmax pass?"
            )

        # Translate the graph dim into a TOSA axis via the tensor's dim order.
        attr = ts.TosaSerializerAttribute()
        attr.AxisAttribute(in_tensor.dim_order.index(reduce_dim))

        tosa_graph.addOperator(
            TosaOp.Op().REDUCE_MIN, [in_tensor.name], [output.name], attr
        )
2 changes: 2 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ def _match_pattern(
torch.ops.aten.contiguous.default,
torch.ops.aten.upsample_nearest2d.vec,
torch.ops.aten.pad.default,
torch.ops.aten.amax.default,
torch.ops.aten.amin.default,
]

# Operators that can inherit the quantization specs from its parent node
Expand Down
Loading
Loading