Skip to content

Commit ba9254a

Browse files
authored
Arm backend: Add atan decomposition pass and test (#11998)
Decomposes atan using Padé approximation/lookup table for MI/BI case. Signed-off-by: Teo Bergkvist <[email protected]>
1 parent ecb85ce commit ba9254a

File tree

7 files changed

+209
-0
lines changed

7 files changed

+209
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2323
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2424
from .convert_to_clamp import ConvertToClampPass # noqa
25+
from .decompose_atan_pass import DecomposeAtanPass # noqa
2526
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
2627
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa
2728
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ConvertSplitToSlicePass,
2626
ConvertSqueezesToViewPass,
2727
ConvertToClampPass,
28+
DecomposeAtanPass,
2829
DecomposeAvgPool2d,
2930
DecomposeBatchNormNoStatsPass,
3031
DecomposeCosineSimilarityPass,
@@ -151,6 +152,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
151152
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
152153
self.add_pass(DecomposeRoundPass())
153154
self.add_pass(DecomposeSqrtPass())
155+
self.add_pass(DecomposeAtanPass())
154156
self.add_pass(ConvertIntPowToMuls())
155157
self.add_pass(CastBoolToInt8Pass())
156158
self.add_pass(DecomposeSinhPass())
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
from math import pi
8+
9+
from executorch.backends.arm._passes import ArmPass
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
12+
13+
edge_atan = exir_ops.edge.aten.atan.default # MI case
14+
15+
16+
def _get_atan_ops(op):
    """Return the edge-dialect primitives used to build the atan decomposition.

    The result is a 10-tuple that callers unpack positionally, in this order:
    mul.Tensor, mul.Scalar, add.Tensor, add.Scalar, sub.Tensor, abs.default,
    gt.Scalar, reciprocal.default, where.self, neg.default.

    Raises:
        RuntimeError: if ``op`` is not the edge ``aten.atan.default`` operator.
    """
    if op is edge_atan:
        aten = exir_ops.edge.aten
        # Order matters: consumers destructure this tuple by position.
        return (
            aten.mul.Tensor,
            aten.mul.Scalar,
            aten.add.Tensor,
            aten.add.Scalar,
            aten.sub.Tensor,
            aten.abs.default,
            aten.gt.Scalar,
            aten.reciprocal.default,
            aten.where.self,
            aten.neg.default,
        )

    raise RuntimeError(f"Can't decompose atan for op {op}")
33+
34+
35+
class DecomposeAtanPass(ArmPass):
    """Decomposes the atan operator into a rational (Padé) approximation.

    Every ``aten.atan.default`` edge op is replaced by a small graph of
    elementwise primitives (see ``_get_atan_ops``). All other operators are
    forwarded unchanged.
    """

    def _rational_approximation(self, z, ops, meta):
        """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].

        Emits ``z * (1 + a1*z^2 + a2*z^4) / (1 + b1*z^2)``.

        Args:
            z: graph value holding the range-reduced argument.
            ops: the 10-tuple returned by ``_get_atan_ops``.
            meta: node metadata forwarded to every emitted operator.

        Returns:
            The graph value of the approximation.
        """

        op_mul, op_mul_scalar, op_add, op_add_scalar, _, _, _, op_recip, _, _ = ops

        # Coefficients calculated using minimax on the interval [-1, 1].
        a1 = 0.3529666667
        a2 = -0.0287666667
        b1 = 0.6863

        z2 = super().call_operator(op_mul, (z, z), {}, meta, updated=True)
        z4 = super().call_operator(op_mul, (z2, z2), {}, meta, updated=True)

        # Numerator: 1 + a1*z^2 + a2*z^4
        num1 = super().call_operator(op_mul_scalar, (z2, a1), {}, meta, updated=True)
        num2 = super().call_operator(op_mul_scalar, (z4, a2), {}, meta, updated=True)
        num = super().call_operator(op_add_scalar, (num1, 1.0), {}, meta, updated=True)
        num = super().call_operator(op_add, (num, num2), {}, meta, updated=True)

        # Denominator: 1 + b1*z^2, applied as a multiplication by its reciprocal.
        den1 = super().call_operator(op_mul_scalar, (z2, b1), {}, meta, updated=True)
        den = super().call_operator(op_add_scalar, (den1, 1.0), {}, meta, updated=True)
        inv_den = super().call_operator(op_recip, (den,), {}, meta, updated=True)

        prod = super().call_operator(op_mul, (num, inv_den), {}, meta, updated=True)
        return super().call_operator(op_mul, (z, prod), {}, meta, updated=True)

    def call_operator(self, op, args, kwargs, meta):
        """Replace an edge atan op with its decomposition; pass other ops through.

        Args:
            op: the operator being visited.
            args: positional arguments of the node; ``args[0]`` is the atan input.
            kwargs: keyword arguments of the node.
            meta: node metadata forwarded to every emitted operator.

        Returns:
            The graph value produced by the decomposition, or the result of the
            parent ``call_operator`` for any op other than edge atan.
        """
        if op is not edge_atan:
            return super().call_operator(op, args, kwargs, meta, updated=False)

        logging.info(
            "Approximating atan. This may introduce small numerical errors. "
            "For details, see %s.",
            __file__,
        )

        ops = _get_atan_ops(op)
        # mul.Tensor, mul.Scalar, add.Tensor and sub.Tensor are not needed here.
        (
            _,
            _,
            _,
            op_add_scalar,
            _,
            op_abs,
            op_gt,
            op_recip,
            op_where,
            op_neg,
        ) = ops

        x = args[0]

        # |x| > 1 is reduced to [0, 1] using atan(x) = pi/2 - atan(1/x) and atan(-x) = -atan(x).

        abs_x = super().call_operator(op_abs, (x,), {}, meta, updated=True)
        mask_hi = super().call_operator(op_gt, (abs_x, 1.0), {}, meta, updated=True)

        inv_x = super().call_operator(op_recip, (abs_x,), {}, meta, updated=True)
        z = super().call_operator(
            op_where, (mask_hi, inv_x, abs_x), {}, meta, updated=True
        )

        atan_z = self._rational_approximation(z, ops, meta)

        # pi/2 - atan(z), emitted as (-atan(z)) + pi/2. Deriving a pi/2 tensor
        # from `x * 0.0` would give NaN for x = +/-inf (inf * 0 is NaN) and so
        # turn atan(+/-inf) into NaN instead of +/-pi/2.
        neg_atan_z = super().call_operator(
            op_neg, (atan_z,), {}, meta, updated=True
        )
        diff = super().call_operator(
            op_add_scalar, (neg_atan_z, pi / 2), {}, meta, updated=True
        )
        atan_abs = super().call_operator(
            op_where, (mask_hi, diff, atan_z), {}, meta, updated=True
        )

        # Restore the sign: atan(-x) = -atan(x).
        mask_pos = super().call_operator(op_gt, (x, 0.0), {}, meta, updated=True)
        neg_val = super().call_operator(op_neg, (atan_abs,), {}, meta, updated=True)

        return super().call_operator(
            op_where, (mask_pos, atan_abs, neg_val), {}, meta, updated=True
        )

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class TableOps:
5151
exir_ops.edge.aten.cos.default: torch.cos,
5252
exir_ops.edge.aten.sin.default: torch.sin,
5353
exir_ops.edge.aten.tanh.default: torch.tanh,
54+
exir_ops.edge.aten.atan.default: torch.atan,
5455
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
5556
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
5657
exir_ops.edge.aten.sinh.default: torch.sinh,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def is_node_supported(
244244
exir_ops.edge.aten.gelu.default,
245245
exir_ops.edge.aten.alias_copy.default,
246246
exir_ops.edge.aten.sinh.default,
247+
exir_ops.edge.aten.atan.default,
247248
]
248249

249250
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def _match_pattern(
214214
torch.ops.aten.pow.Tensor_Scalar,
215215
torch.ops.aten.gelu.default,
216216
torch.ops.aten.sinh.default,
217+
torch.ops.aten.atan.default,
217218
]
218219

219220
_one_to_one_shared_input_qspec = [

backends/arm/test/ops/test_atan.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
aten_op = "torch.ops.aten.atan.default"
19+
exir_op = "executorch_exir_dialects_edge__ops_aten__atan_default"
20+
21+
input_t1 = Tuple[torch.Tensor]
22+
23+
# Named input tensors shared by every atan test below. Insertion order (and
# therefore the order in which random tensors are drawn) is significant.
test_data_suite = dict(
    # Constant inputs.
    zeros=torch.zeros(1, 10, 10, 10),
    zeros_alt_shape=torch.zeros(1, 10, 3, 5),
    ones=torch.ones(10, 10, 10),
    # Uniform samples shifted to [-0.5, 0.5).
    rand=torch.rand(10, 10) - 0.5,
    rand_alt_shape=torch.rand(1, 10, 3, 5) - 0.5,
    # Normal samples centred on +10 / -10.
    randn_pos=torch.randn(10) + 10,
    randn_neg=torch.randn(10) - 10,
    # Evenly spaced sweep from -16 towards 16 in steps of 0.2.
    ramp=torch.arange(-16, 16, 0.2),
)
33+
34+
35+
class Atan(torch.nn.Module):
    """Minimal module whose forward pass is a single elementwise atan."""

    def forward(self, x: torch.Tensor):
        # Method form of torch.atan(x).
        return x.atan()
39+
40+
41+
@common.parametrize("test_data", test_data_suite)
def test_atan_tosa_MI(test_data: Tuple):
    """Run the Atan module through TosaPipelineMI for one suite entry."""
    inputs = (test_data,)
    TosaPipelineMI[input_t1](
        Atan(),
        inputs,
        aten_op=aten_op,
        exir_op=exir_op,
    ).run()
50+
51+
52+
@common.parametrize("test_data", test_data_suite)
def test_atan_tosa_BI(test_data: Tuple):
    """Run the Atan module through TosaPipelineBI for one suite entry."""
    inputs = (test_data,)
    TosaPipelineBI[input_t1](
        Atan(),
        inputs,
        aten_op=aten_op,
        exir_op=exir_op,
    ).run()
61+
62+
63+
@common.XfailIfNoCorstone300
@common.parametrize("test_data", test_data_suite)
def test_atan_u55_BI(test_data: Tuple):
    """Run the Atan module through EthosU55PipelineBI for one suite entry."""
    inputs = (test_data,)
    EthosU55PipelineBI[input_t1](
        Atan(),
        inputs,
        aten_ops=aten_op,
        exir_ops=exir_op,
    ).run()
73+
74+
75+
@common.XfailIfNoCorstone320
@common.parametrize("test_data", test_data_suite)
def test_atan_u85_BI(test_data: Tuple):
    """Run the Atan module through EthosU85PipelineBI for one suite entry."""
    inputs = (test_data,)
    EthosU85PipelineBI[input_t1](
        Atan(),
        inputs,
        aten_ops=aten_op,
        exir_ops=exir_op,
    ).run()

0 commit comments

Comments
 (0)