Skip to content

Commit 12fbaeb

Browse files
committed
Update base for Update on "[ExecuTorch] Arm Ethos: Do not depend on torch.testing._internal "
This can cause issues with `disable_global_flags` and the internal state of the library, which is set as a side effect of importing this module. Differential Revision: [D70402061](https://our.internmc.facebook.com/intern/diff/D70402061/) [ghstack-poisoned]
2 parents 4df0ade + 7ce47fc commit 12fbaeb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1364
-688
lines changed

.buckconfig

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
shim_et = shim_et
1212

1313
[repository_aliases]
14+
bazel_skylib = shim
1415
config = prelude
1516
ovr_config = prelude
1617
toolchains = shim_et
1718
fbcode = shim_et
18-
fbcode_macros = shim_et
19+
fbcode_macros = shim
1920
fbsource = shim_et
2021
buck = shim
22+
gh_facebook_buck2_shims_meta = shim
2123

2224
[cxx]
2325
cxxflags = -g -std=c++17

.ci/scripts/unittest-buck2.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
set -eux
88

99
# TODO: expand this to //...
10-
buck2 query //runtime/...
10+
# TODO: can't query cadence & vulkan backends
11+
buck2 query "//backends/apple/... + //backends/example/... + \
12+
//backends/mediatek/... + //backends/test/... + //backends/transforms/... + \
13+
//backends/xnnpack/... + //configurations/... + //kernels/portable/cpu/... + \
14+
//runtime/... + //schema/... + //test/... + //util/..."
1115

1216
# TODO: expand the covered scope of Buck targets.
1317
buck2 build //runtime/core/portable_type/...

.github/workflows/android-perf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ on:
2020
description: Models to be benchmarked
2121
required: false
2222
type: string
23-
default: stories110M
23+
default: llama
2424
devices:
2525
description: Target devices to run benchmark
2626
required: false
@@ -36,7 +36,7 @@ on:
3636
description: Models to be benchmarked
3737
required: false
3838
type: string
39-
default: stories110M
39+
default: llama
4040
devices:
4141
description: Target devices to run benchmark
4242
required: false

.github/workflows/apple-perf.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ on:
2020
description: Models to be benchmarked
2121
required: false
2222
type: string
23-
default: stories110M
23+
default: llama
2424
devices:
2525
description: Target devices to run benchmark
2626
required: false
@@ -36,7 +36,7 @@ on:
3636
description: Models to be benchmarked
3737
required: false
3838
type: string
39-
default: stories110M
39+
default: llama
4040
devices:
4141
description: Target devices to run benchmark
4242
required: false

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
__pycache__/
99

1010
# Build and tool-generated files
11+
arm_test/
1112
buck-out/
1213
buck2-bin/
1314
cmake-android-out/
@@ -33,6 +34,7 @@ pip-out/
3334

3435
# Xcode
3536
xcuserdata/
37+
.build/
3638
.swiftpm/
3739
*.xcworkspace/
3840
*.xcframework/

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,6 @@ endif()
724724

725725
if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
726726
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/flat_tensor)
727-
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/flat_tensor/serialize)
728727
endif()
729728

730729
if(EXECUTORCH_BUILD_EXTENSION_LLM)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
2222
ConvertFullLikeToFullPass,
2323
)
24+
from executorch.backends.arm._passes.convert_minmax_pass import ConvertMinMaxPass
2425
from executorch.backends.arm._passes.convert_split_to_slice import (
2526
ConvertSplitToSlicePass,
2627
)
@@ -106,6 +107,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
106107
self.add_pass(ConvertMeanDimToAveragePoolPass())
107108
self.add_pass(ConvertFullLikeToFullPass())
108109
self.add_pass(ConvertToClampPass())
110+
self.add_pass(ConvertMinMaxPass())
109111

110112
self.add_pass(ReplaceScalarWithTensorArgPass())
111113
self.add_pass(AnnotateDecomposedMatmulPass())
@@ -147,6 +149,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
147149
self.add_pass(DecomposeSoftmaxesPass())
148150
self.add_pass(ConvertFullLikeToFullPass())
149151
self.add_pass(ConvertToClampPass())
152+
self.add_pass(ConvertMinMaxPass())
150153

151154
self.add_pass(AnnotateDecomposedMatmulPass())
152155
self.add_pass(QuantizeOperatorArguments())
@@ -190,4 +193,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
190193
self.add_pass(DecomposeMeanDimPass())
191194
self.add_pass(DecomposeDivPass())
192195
self.add_pass(DecomposeSoftmaxesPass())
196+
self.add_pass(ConvertMinMaxPass())
193197
return self._transform(graph_module)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
11+
class ConvertMinMaxPass(ExportPass):
    """
    Converts min/max to amin/amax and unrolls multi-dimensional reduction and keep-dims arg to be
    TOSA compliant.

    The difference between max/min and amax/amin is (from pytorch docs):
        - amax/amin supports reducing on multiple dimensions,
        - amax/amin does not return indices,
        - amax/amin evenly distributes gradient between equal values, while max(dim)/min(dim)
          propagates gradient only to a single index in the source tensor.
    Since we do not care about gradients post training, convert min/max ops to amin/amax as long as
    the indices are not used.

    Original:
        amax([dim1, dim2], keepdim = False)
    After pass:
        amax(dim1, keepdim = True)
        amax(dim2, keepdim = True)
        squeeze(dim = [dim1, dim2])
    """

    def check_argmax(self, node):
        """
        Raises a RuntimeError if the argmax value returned by the min/max op is used in the graph.

        Only the aten-dialect `max.dim` / `min.dim` targets are checked; every other target
        returns without raising.
        """
        if node.target not in (torch.ops.aten.max.dim, torch.ops.aten.min.dim):
            return

        users = list(node.users)
        # Accepted: exactly one user, or two users where the second one has no users itself.
        # NOTE(review): this presumes the second user is the getitem that extracts the
        # indices output - confirm that user ordering is guaranteed.
        no_argmax = len(users) == 1
        no_argmax_users = len(users) == 2 and len(users[1].users) == 0
        if not (no_argmax or no_argmax_users):
            raise RuntimeError("Argmax is not supported by the arm_quantizer")

    def get_variables(self, node):
        """
        Returns variables specific for each op handled by the pass.

        Returns:
            A tuple `(replace_node, op, squeeze_op)` where `replace_node` is the node whose uses
            are to be redirected, `op` is the amax/amin overload to emit and `squeeze_op` is the
            squeeze overload of the same dialect.

        Raises:
            RuntimeError: If `node.target` is not one of the ops handled by this pass.
        """
        if node.target in (
            exir_ops.edge.aten.amax.default,
            exir_ops.edge.aten.amin.default,
        ):
            # amax/amin are re-emitted with the same target; the node itself is replaced.
            replace_node = node
            op = node.target
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
        elif node.target == exir_ops.edge.aten.max.dim:
            # For max.dim/min.dim the first user of the node is the one that gets replaced.
            replace_node = list(node.users)[0]
            op = exir_ops.edge.aten.amax.default
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
        elif node.target == exir_ops.edge.aten.min.dim:
            replace_node = list(node.users)[0]
            op = exir_ops.edge.aten.amin.default
            squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
        elif node.target == torch.ops.aten.max.dim:
            replace_node = list(node.users)[0]
            op = torch.ops.aten.amax.default
            squeeze_op = torch.ops.aten.squeeze.dims
        elif node.target == torch.ops.aten.min.dim:
            replace_node = list(node.users)[0]
            op = torch.ops.aten.amin.default
            squeeze_op = torch.ops.aten.squeeze.dims
        else:
            raise RuntimeError(
                f"{node.name} is not an accepted target for ConvertMinMaxPass()"
            )

        return (replace_node, op, squeeze_op)

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Rewrites every handled min/max/amin/amax node into a chain of single-dim, keepdim=True
        reductions, followed by a squeeze when the original op did not keep dims.

        Returns:
            A PassResult holding the (possibly retraced) graph module and a flag telling whether
            any node was rewritten.

        Raises:
            RuntimeError: If a handled node has an unexpected number of positional args, or if
                `check_argmax` / `get_variables` reject the node.
        """
        handled_targets = (
            exir_ops.edge.aten.amax.default,
            exir_ops.edge.aten.amin.default,
            exir_ops.edge.aten.max.dim,
            exir_ops.edge.aten.min.dim,
            torch.ops.aten.max.dim,
            torch.ops.aten.min.dim,
        )

        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target not in handled_targets:
                continue

            self.check_argmax(
                node
            )  # TODO: MLETORCH-718 : Quantization of indices in arm_quantizer
            replace_node, op, squeeze_op = self.get_variables(node)

            # Unwrap args
            if len(node.args) == 2:
                input_node, dims = node.args
                keepdims = False
            elif len(node.args) == 3:
                input_node, dims, keepdims = node.args
            else:
                raise RuntimeError(f"Unexpected arg size in {node.name}")

            # A single dim is given as a non-iterable; wrap it so both cases are handled as a
            # list. Only TypeError (raised by iter() on a non-iterable) is caught, so unrelated
            # errors and interrupts are not swallowed.
            try:
                iter(dims)
            except TypeError:
                dims = [dims]
            else:
                dims = list(dims)

            # Unroll multi-dimensional reduction and keep-dims arg
            with graph_module.graph.inserting_before(node):

                for dim in dims:
                    args = (input_node, dim, True)
                    input_node = graph_module.graph.create_node(
                        "call_function", op, args, node.kwargs
                    )

                if not keepdims:
                    input_node = graph_module.graph.create_node(
                        "call_function",
                        squeeze_op,
                        (input_node, dims),
                    )

            replace_node.replace_all_uses_with(input_node)
            modified = True

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        # Report the real outcome instead of unconditionally claiming a modification.
        return PassResult(graph_module, modified)

backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -36,18 +35,18 @@ class KeepDimsFalseToSqueezePass(ExportPass):
3635
"""
3736

3837
# CURRENTLY NOT HANDLED OPS
39-
# exir_ops.edge.aten.amax,
40-
# exir_ops.edge.aten.amin,
4138
# exir_ops.edge.aten.any.dim,
4239
# exir_ops.edge.aten.any.dims,
4340
# exir_ops.edge.aten.argmax,
4441
# exir_ops.edge.aten.argmin,
45-
# exir_ops.edge.aten.max.dim,
46-
# exir_ops.edge.aten.min.dim,
4742
# exir_ops.edge.aten.prod.dim_int,
4843

4944
# HANDLED OPS
5045
# exir_ops.edge.aten.sum.dim_IntList
46+
# exir_ops.edge.aten.max.dim (decomposed in convert_minmax_pass)
47+
# exir_ops.edge.aten.min.dim (decomposed in convert_minmax_pass)
48+
# exir_ops.edge.aten.amin (decomposed in convert_minmax_pass)
49+
# exir_ops.edge.aten.amax (decomposed in convert_minmax_pass)
5150
# exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
5251
# exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
5352
# exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from . import ( # noqa
99
convolution_support,
10+
minmax_support,
1011
pool_2d_support,
1112
reduce_sum_support,
1213
right_shift_support,
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch.fx as fx
7+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
8+
register_tosa_support_check,
9+
SupportedTOSAOperatorCheck,
10+
)
11+
from executorch.backends.arm.tosa_specification import TosaSpecification
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
14+
15+
@register_tosa_support_check
class MinMaxSupported(SupportedTOSAOperatorCheck):
    """
    Support check for edge-dialect `max.dim` / `min.dim` nodes.

    A node is rejected when it has more than two users, or when it has exactly two users and
    the second of them is itself used. Presumably that second user is the indices (argmax /
    argmin) output - TODO confirm the user ordering.
    """

    targets = [
        exir_ops.edge.aten.max.dim,
        exir_ops.edge.aten.min.dim,
    ]

    # TODO : "MLETORCH-718 : Quantization of indices in arm_quantizer"
    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        """Return True unless the node is a max.dim/min.dim whose second user is consumed."""
        reductions = (exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim)
        if node.target not in reductions:
            return True

        consumers = list(node.users)
        if len(consumers) == 1:
            # Only one output of the op is referenced.
            return True

        # Two consumers are acceptable only when the second one is dead (has no users).
        return len(consumers) == 2 and len(consumers[1].users) == 0

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def is_node_supported(
169169
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
170170
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
171171
exir_ops.edge.aten.constant_pad_nd.default,
172+
exir_ops.edge.aten.amax.default,
173+
exir_ops.edge.aten.amin.default,
172174
]
173175

174176
return supported
@@ -191,6 +193,13 @@ def is_node_supported(
191193
exir_ops.edge.aten.bitwise_and.Tensor,
192194
exir_ops.edge.aten.bitwise_or.Tensor,
193195
exir_ops.edge.aten.bitwise_xor.Tensor,
196+
exir_ops.edge.aten.amax.default,
197+
exir_ops.edge.aten.amin.default,
198+
exir_ops.edge.aten.eq.Tensor,
199+
exir_ops.edge.aten.ge.Tensor,
200+
exir_ops.edge.aten.gt.Tensor,
201+
exir_ops.edge.aten.le.Tensor,
202+
exir_ops.edge.aten.lt.Tensor,
194203
]
195204

196205
if node.target in unsupported_ops:

backends/arm/operators/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
node_visitor,
1010
op_abs,
1111
op_add,
12+
op_amax,
13+
op_amin,
1214
op_avg_pool2d,
1315
op_bmm,
1416
op_cat,
@@ -24,9 +26,9 @@
2426
op_le,
2527
op_log,
2628
op_lt,
27-
op_max,
2829
op_max_pool2d,
29-
op_min,
30+
op_maximum,
31+
op_minimum,
3032
op_mul,
3133
op_permute,
3234
op_reciprocal,

0 commit comments

Comments
 (0)