Skip to content

Commit f0a7d10

Browse files
authored
Arm backend: Add acosh decomposition pass and test (#12105)
Decomposes acosh into TOSA-operations. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 Signed-off-by: Emma Kujala <[email protected]>
1 parent 1f6c465 commit f0a7d10

File tree

7 files changed

+172
-0
lines changed

7 files changed

+172
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2323
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2424
from .convert_to_clamp import ConvertToClampPass # noqa
25+
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
2526
from .decompose_atan_pass import DecomposeAtanPass # noqa
2627
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
2728
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ConvertSplitToSlicePass,
2626
ConvertSqueezesToViewPass,
2727
ConvertToClampPass,
28+
DecomposeAcoshPass,
2829
DecomposeAtanPass,
2930
DecomposeAvgPool2d,
3031
DecomposeBatchNormNoStatsPass,
@@ -151,6 +152,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
151152

152153
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
153154
self.add_pass(DecomposeRoundPass())
155+
self.add_pass(DecomposeAcoshPass())
154156
self.add_pass(DecomposeSqrtPass())
155157
self.add_pass(DecomposeAtanPass())
156158
self.add_pass(ConvertIntPowToMuls())
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from executorch.backends.arm._passes import ArmPass
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
11+
# For MI case
12+
edge_acosh_op = exir_ops.edge.aten.acosh.default
13+
14+
15+
class DecomposeAcoshPass(ArmPass):
16+
"""
17+
Decomposes acosh to supported TOSA-operations.
18+
This decomposition is based on the mathematical identity:
19+
acosh(x) = log(x + sqrt((x-1)(x+1))
20+
"""
21+
22+
def call_operator(self, op, args, kwargs, meta, updated=False):
23+
24+
if op is not edge_acosh_op:
25+
return super().call_operator(op, args, kwargs, meta, updated)
26+
27+
log_op, sqrt_op, mul_op, sub_op, add_op, add_op_scalar = (
28+
exir_ops.edge.aten.log.default,
29+
exir_ops.edge.aten.sqrt.default,
30+
exir_ops.edge.aten.mul.Tensor,
31+
exir_ops.edge.aten.sub.Scalar,
32+
exir_ops.edge.aten.add.Tensor,
33+
exir_ops.edge.aten.add.Scalar,
34+
)
35+
36+
x = args[0]
37+
38+
# (x-1)(x+1)
39+
sub = super().call_operator(sub_op, (x, 1.0), {}, meta, True)
40+
add = super().call_operator(add_op_scalar, (x, 1.0), {}, meta, True)
41+
mul = super().call_operator(mul_op, (sub, add), {}, meta, True)
42+
43+
# sqrt((x-1)(x+1))
44+
sqrt = super().call_operator(sqrt_op, (mul,), {}, meta, True)
45+
46+
# x + sqrt((x-1)(x+1))
47+
add = super().call_operator(add_op, (x, sqrt), {}, meta, True)
48+
49+
# out = ln(x + sqrt((x-1)(x+1))
50+
out = super().call_operator(log_op, (add,), {}, meta, True)
51+
52+
return out

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class TableOps:
5555
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
5656
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
5757
exir_ops.edge.aten.sinh.default: torch.sinh,
58+
exir_ops.edge.aten.acosh.default: torch.acosh,
5859
}
5960

6061
# Targets that must be treated explicitly

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def is_node_supported(
245245
exir_ops.edge.aten.alias_copy.default,
246246
exir_ops.edge.aten.sinh.default,
247247
exir_ops.edge.aten.atan.default,
248+
exir_ops.edge.aten.acosh.default,
248249
]
249250

250251
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _match_pattern(
215215
torch.ops.aten.gelu.default,
216216
torch.ops.aten.sinh.default,
217217
torch.ops.aten.atan.default,
218+
torch.ops.aten.acosh.default,
218219
]
219220

220221
_one_to_one_shared_input_qspec = [

backends/arm/test/ops/test_acosh.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import Tuple
6+
7+
import pytest
8+
9+
import torch
10+
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineBI,
14+
EthosU85PipelineBI,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
input_t = Tuple[torch.Tensor] # Input x
20+
aten_op = "torch.ops.aten.acosh.default"
21+
22+
23+
test_data_suite = {
24+
# Valid input cases
25+
"ones": lambda: torch.ones(1, 7, 10, 12),
26+
"just_above_one": lambda: torch.tensor([1.0001, 1.01, 1.1, 2.0]),
27+
"rand_valid": lambda: torch.rand(10, 10) * 10 + 1, # [1, 11)
28+
"ramp_valid": lambda: torch.linspace(1.0, 20.0, steps=160),
29+
"large": lambda: torch.tensor([10.0, 100.0, 1000.0, 1e6]),
30+
"mixed_valid": lambda: torch.tensor([1.0, 2.0, 10.0, 100.0]),
31+
}
32+
33+
test_data_suite_xfails = {
34+
# Invalid input cases (should return nan or error)
35+
"zeros": lambda: torch.zeros(1, 5, 3, 2),
36+
"neg_ones": lambda: -torch.ones(10, 10, 10),
37+
"rand_invalid": lambda: torch.rand(10, 10), # [0, 1)
38+
"ramp_invalid": lambda: torch.linspace(-10.0, 0.99, steps=160),
39+
"near_zero": lambda: torch.tensor([-1e-6, 0.0, 1e-6]),
40+
"large_negative": lambda: torch.tensor([-100.0, -10.0, 0.0]),
41+
}
42+
43+
44+
class Acosh(torch.nn.Module):
45+
46+
def forward(self, x: torch.Tensor):
47+
return torch.acosh(x)
48+
49+
50+
@common.parametrize("test_data", test_data_suite)
51+
def test_acosh_tosa_MI(test_data: Tuple):
52+
pipeline = TosaPipelineMI[input_t](
53+
Acosh(),
54+
(test_data(),),
55+
aten_op,
56+
exir_op=[],
57+
)
58+
pipeline.run()
59+
60+
61+
@common.parametrize("test_data", test_data_suite)
62+
def test_acosh_tosa_BI(test_data: Tuple):
63+
pipeline = TosaPipelineBI[input_t](
64+
Acosh(),
65+
(test_data(),),
66+
aten_op=[],
67+
)
68+
pipeline.run()
69+
70+
71+
@common.parametrize("test_data", test_data_suite)
72+
@common.XfailIfNoCorstone300
73+
def test_acosh_u55_BI(test_data: Tuple):
74+
pipeline = EthosU55PipelineBI[input_t](
75+
Acosh(),
76+
(test_data(),),
77+
aten_ops=[],
78+
)
79+
pipeline.run()
80+
81+
82+
@common.parametrize("test_data", test_data_suite_xfails)
83+
@pytest.mark.xfail(reason="Invalid inputs are currently not handled")
84+
def test_acosh_u55_BI_xfail(test_data: Tuple):
85+
pipeline = EthosU55PipelineBI[input_t](
86+
Acosh(),
87+
(test_data(),),
88+
aten_ops=[],
89+
run_on_fvp=False,
90+
)
91+
pipeline.run()
92+
93+
94+
@common.parametrize("test_data", test_data_suite)
95+
@common.XfailIfNoCorstone320
96+
def test_acosh_u85_BI(test_data: Tuple):
97+
pipeline = EthosU85PipelineBI[input_t](
98+
Acosh(),
99+
(test_data(),),
100+
aten_ops=[],
101+
)
102+
pipeline.run()
103+
104+
105+
@common.parametrize("test_data", test_data_suite_xfails)
106+
@pytest.mark.xfail(reason="Invalid inputs are currently not handled")
107+
def test_acosh_u85_BI_xfail(test_data: Tuple):
108+
pipeline = EthosU55PipelineBI[input_t](
109+
Acosh(),
110+
(test_data(),),
111+
aten_ops=[],
112+
run_on_fvp=False,
113+
)
114+
pipeline.run()

0 commit comments

Comments
 (0)