Skip to content

Commit 2d2928b

Browse files
Merge branch 'main' into issue/9131
2 parents 2359c71 + 2bee76e commit 2d2928b

File tree

18 files changed

+390
-368
lines changed

18 files changed

+390
-368
lines changed

.ci/scripts/unittest-macos.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ if [[ "$BUILD_TOOL" == "cmake" ]]; then
3535
.ci/scripts/unittest-macos-cmake.sh
3636
elif [[ "$BUILD_TOOL" == "buck2" ]]; then
3737
.ci/scripts/unittest-buck2.sh
38-
.ci/scripts/unittest-macos-buck2.sh
38+
# .ci/scripts/unittest-macos-buck2.sh
3939
else
4040
echo "Unknown build tool $BUILD_TOOL"
4141
exit 1

backends/arm/_passes/arm_pass_manager.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@
4444
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
4545
DecomposeSelectPass,
4646
)
47-
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
48-
DecomposeSoftmaxesPass,
47+
from executorch.backends.arm._passes.decompose_softmax_pass import DecomposeSoftmaxPass
48+
from executorch.backends.arm._passes.decompose_softmax_unstable_pass import (
49+
DecomposeSoftmaxUnstablePass,
4950
)
5051
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
5152
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
@@ -81,7 +82,7 @@
8182
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
8283
UnsqueezeScalarPlaceholdersPass,
8384
)
84-
from executorch.backends.arm.tosa_specification import TosaSpecification
85+
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
8586
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
8687

8788
from executorch.backends.transforms.replace_scalar_with_tensor import (
@@ -155,7 +156,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
155156
self.add_pass(DecomposeMeanDimPass())
156157
self.add_pass(ConvertMeanDimToAveragePoolPass())
157158
self.add_pass(DecomposeDivPass())
158-
self.add_pass(DecomposeSoftmaxesPass())
159+
self.add_pass(DecomposeSoftmaxPass())
159160
self.add_pass(ConvertFullLikeToFullPass())
160161
self.add_pass(ConvertToClampPass())
161162
self.add_pass(ConvertMinMaxPass())
@@ -204,6 +205,12 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
204205
self.add_pass(DecomposeVarPass())
205206
self.add_pass(DecomposeMeanDimPass())
206207
self.add_pass(DecomposeDivPass())
207-
self.add_pass(DecomposeSoftmaxesPass())
208+
209+
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
210+
# Numerically stable softmax uses amax which is not supported on Ethos-U55
211+
self.add_pass(DecomposeSoftmaxUnstablePass())
212+
else:
213+
self.add_pass(DecomposeSoftmaxPass())
214+
208215
self.add_pass(ConvertMinMaxPass())
209216
return self._transform(graph_module)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.dialects._ops import ops as exir_ops
8+
from executorch.exir.pass_base import ExportPass
9+
10+
# For BI case
11+
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
12+
# For MI case
13+
edge_softmax = (
14+
exir_ops.edge.aten._softmax.default,
15+
exir_ops.edge.aten._log_softmax.default,
16+
)
17+
log_softmax = (torch.ops.aten.log_softmax.int, exir_ops.edge.aten._log_softmax.default)
18+
19+
20+
def _get_logsoftmax_ops(op) -> tuple:
21+
"""
22+
Returns the (log_op, sub_op, amax_op, exp_op, sum_op, reciprocal_op, mul_op), where the ops depend on whether
23+
the softmax op is an aten or edge op.
24+
"""
25+
if op in edge_softmax:
26+
return (
27+
exir_ops.edge.aten.log.default,
28+
exir_ops.edge.aten.sub.Tensor,
29+
exir_ops.edge.aten.amax.default,
30+
exir_ops.edge.aten.exp.default,
31+
exir_ops.edge.aten.sum.dim_IntList,
32+
exir_ops.edge.aten.reciprocal.default,
33+
exir_ops.edge.aten.mul.Tensor,
34+
)
35+
if op in torch_softmax:
36+
return (
37+
torch.ops.aten.log.default,
38+
torch.ops.aten.sub.Tensor,
39+
torch.ops.aten.amax.default,
40+
torch.ops.aten.exp.default,
41+
torch.ops.aten.sum.dim_IntList,
42+
torch.ops.aten.reciprocal.default,
43+
torch.ops.aten.mul.Tensor,
44+
)
45+
raise RuntimeError(f"Can't get logsoftmax decomposition ops for op {op}")
46+
47+
48+
class DecomposeSoftmaxPass(ExportPass):
49+
"""
50+
This pass decomposes log_softmax or softmax into more primitive ops.
51+
Example:
52+
%op1 = amax(x, dim)
53+
%op2 = sub(x, %op1)
54+
%op3 = exp(%op2)
55+
%op4 = sum(%op3, dim)
56+
%op5 = reciprocal(%op4)
57+
%op6 = mul(%op3, %op5)
58+
(in logsoftmax case: %op7 = log(%op6))
59+
"""
60+
61+
def call_operator(self, op, args, kwargs, meta):
62+
if op not in torch_softmax + edge_softmax:
63+
return super().call_operator(op, args, kwargs, meta)
64+
log_op, sub_op, max_op, exp_op, sum_op, reciprocal_op, mul_op = (
65+
_get_logsoftmax_ops(op)
66+
)
67+
_input = args[0]
68+
dim = [args[1]]
69+
op1 = super().call_operator(max_op, (_input, dim, True), {}, meta)
70+
op2 = super().call_operator(sub_op, (_input, op1), {}, meta)
71+
op3 = super().call_operator(exp_op, (op2,), {}, meta)
72+
op4 = super().call_operator(sum_op, (op3, dim, True), {}, meta)
73+
op5 = super().call_operator(reciprocal_op, (op4,), {}, meta)
74+
op6 = super().call_operator(mul_op, (op3, op5), {}, meta)
75+
if op in log_softmax:
76+
op6 = super().call_operator(log_op, (op6,), {}, meta)
77+
return op6

backends/arm/_passes/decompose_softmaxes_pass.py renamed to backends/arm/_passes/decompose_softmax_unstable_pass.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -46,7 +45,7 @@ def get_logsoftmax_ops(op) -> tuple:
4645
raise RuntimeError(f"Can't get softmax decomposition ops for op {op}")
4746

4847

49-
class DecomposeSoftmaxesPass(ExportPass):
48+
class DecomposeSoftmaxUnstablePass(ExportPass):
5049
"""
5150
This pass decomposes log softmax or softmax into more primitive ops.
5251

backends/arm/operators/op_amax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List
66

77
import serializer.tosa_serializer as ts
8+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
89
from executorch.backends.arm.operators.node_visitor import (
910
NodeVisitor,
1011
register_node_visitor,
@@ -31,6 +32,12 @@ def define_node(
3132

3233
input = inputs[0]
3334
dim = inputs[1].number
35+
36+
if dim < 0:
37+
tensor = get_first_fake_tensor(node)
38+
rank = len(tensor.size())
39+
dim = rank + dim
40+
3441
keep_dims = inputs[2].number
3542
if not keep_dims:
3643
raise RuntimeError(

backends/arm/operators/op_amin.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List
66

77
import serializer.tosa_serializer as ts
8+
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
89
from executorch.backends.arm.operators.node_visitor import (
910
NodeVisitor,
1011
register_node_visitor,
@@ -31,6 +32,12 @@ def define_node(
3132

3233
input = inputs[0]
3334
dim = inputs[1].number
35+
36+
if dim < 0:
37+
tensor = get_first_fake_tensor(node)
38+
rank = len(tensor.size())
39+
dim = rank + dim
40+
3441
keep_dims = inputs[2].number
3542
if not keep_dims:
3643
raise RuntimeError(

backends/arm/test/common.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -159,13 +158,38 @@ def get_u85_compile_spec_unbuilt(
159158
not corstone300_installed() or not arm_executor_runner_exists("corstone-300"),
160159
reason="Did not find Corstone-300 FVP or executor_runner on path",
161160
)
162-
"""Skips a test if Corsone300 FVP is not installed, or if the executor runner is not built"""
161+
"""
162+
TO BE DEPRECATED - Use XfailIfNoCorstone300 instead
163+
Skips a test if Corstone300 FVP is not installed, or if the executor runner is not built
164+
"""
163165

164166
SkipIfNoCorstone320 = pytest.mark.skipif(
165167
not corstone320_installed() or not arm_executor_runner_exists("corstone-320"),
166168
reason="Did not find Corstone-320 FVP or executor_runner on path",
167169
)
168-
"""Skips a test if Corsone320 FVP is not installed, or if the executor runner is not built."""
170+
"""
171+
TO BE DEPRECATED - Use XfailIfNoCorstone320 instead
172+
Skips a test if Corstone320 FVP is not installed, or if the executor runner is not built
173+
"""
174+
175+
176+
XfailIfNoCorstone300 = pytest.mark.xfail(
177+
condition=not (
178+
corstone300_installed() and arm_executor_runner_exists("corstone-300")
179+
),
180+
raises=FileNotFoundError,
181+
reason="Did not find Corstone-300 FVP or executor_runner on path",
182+
)
183+
"""Xfails a test if Corstone300 FVP is not installed, or if the executor runner is not built"""
184+
185+
XfailIfNoCorstone320 = pytest.mark.xfail(
186+
condition=not (
187+
corstone320_installed() and arm_executor_runner_exists("corstone-320")
188+
),
189+
raises=FileNotFoundError,
190+
reason="Did not find Corstone-320 FVP or executor_runner on path",
191+
)
192+
"""Xfails a test if Corstone320 FVP is not installed, or if the executor runner is not built"""
169193

170194

171195
def parametrize(

backends/arm/test/ops/test_add.py

Lines changed: 10 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -106,79 +106,47 @@ def test_add_i32_tosa_BI(test_data: input_t1):
106106

107107

108108
@common.parametrize("test_data", Add.test_data)
109+
@common.XfailIfNoCorstone300
109110
def test_add_u55_BI(test_data: input_t1):
110-
pipeline = EthosU55PipelineBI[input_t1](
111-
Add(), test_data, aten_op, exir_op, run_on_fvp=False
112-
)
113-
pipeline.run()
114-
115-
116-
@common.parametrize("test_data", Add.test_data)
117-
def test_add_u85_BI(test_data: input_t1):
118-
pipeline = EthosU85PipelineBI[input_t1](
119-
Add(), test_data, aten_op, exir_op, run_on_fvp=False
120-
)
121-
pipeline.run()
122-
123-
124-
@common.parametrize("test_data", Add.test_data)
125-
@common.SkipIfNoCorstone300
126-
def test_add_u55_BI_on_fvp(test_data: input_t1):
127111
pipeline = EthosU55PipelineBI[input_t1](
128112
Add(), test_data, aten_op, exir_op, run_on_fvp=True
129113
)
130114
pipeline.run()
131115

132116

133117
@common.parametrize("test_data", Add.test_data)
134-
@common.SkipIfNoCorstone320
135-
def test_add_u85_BI_on_fvp(test_data: input_t1):
118+
@common.XfailIfNoCorstone320
119+
def test_add_u85_BI(test_data: input_t1):
136120
pipeline = EthosU85PipelineBI[input_t1](
137121
Add(), test_data, aten_op, exir_op, run_on_fvp=True
138122
)
139123
pipeline.run()
140124

141125

142126
@common.parametrize("test_data", Add2.test_data)
143-
def test_add2_tosa_MI(test_data: input_t2):
127+
def test_add_2_tosa_MI(test_data: input_t2):
144128
pipeline = TosaPipelineMI[input_t2](Add2(), test_data, aten_op, exir_op)
145129
pipeline.run()
146130

147131

148132
@common.parametrize("test_data", Add2.test_data)
149-
def test_add2_tosa_BI(test_data: input_t2):
133+
def test_add_2_tosa_BI(test_data: input_t2):
150134
pipeline = TosaPipelineBI[input_t2](Add2(), test_data, aten_op, exir_op)
151135
pipeline.run()
152136

153137

154138
@common.parametrize("test_data", Add2.test_data)
155-
def test_add2_u55_BI(test_data: input_t2):
156-
pipeline = EthosU55PipelineBI[input_t2](
157-
Add2(), test_data, aten_op, exir_op, run_on_fvp=False
158-
)
159-
pipeline.run()
160-
161-
162-
@common.parametrize("test_data", Add2.test_data)
163-
@common.SkipIfNoCorstone300
164-
def test_add2_u55_BI_on_fvp(test_data: input_t2):
139+
@common.XfailIfNoCorstone300
140+
def test_add_2_u55_BI(test_data: input_t2):
165141
pipeline = EthosU55PipelineBI[input_t2](
166142
Add2(), test_data, aten_op, exir_op, run_on_fvp=True
167143
)
168144
pipeline.run()
169145

170146

171147
@common.parametrize("test_data", Add2.test_data)
172-
def test_add2_u85_BI(test_data: input_t2):
173-
pipeline = EthosU85PipelineBI[input_t2](
174-
Add2(), test_data, aten_op, exir_op, run_on_fvp=False
175-
)
176-
pipeline.run()
177-
178-
179-
@common.parametrize("test_data", Add2.test_data)
180-
@common.SkipIfNoCorstone320
181-
def test_add2_u85_BI_on_fvp(test_data: input_t2):
148+
@common.XfailIfNoCorstone320
149+
def test_add_2_u85_BI(test_data: input_t2):
182150
pipeline = EthosU85PipelineBI[input_t2](
183151
Add2(), test_data, aten_op, exir_op, run_on_fvp=True
184152
)

0 commit comments

Comments
 (0)