Skip to content

Arm backend: Add atan decomposition pass and test #11998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
from .convert_to_clamp import ConvertToClampPass # noqa
from .decompose_atan_pass import DecomposeAtanPass # noqa
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ConvertSplitToSlicePass,
ConvertSqueezesToViewPass,
ConvertToClampPass,
DecomposeAtanPass,
DecomposeAvgPool2d,
DecomposeBatchNormNoStatsPass,
DecomposeCosineSimilarityPass,
Expand Down Expand Up @@ -151,6 +152,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeRoundPass())
self.add_pass(DecomposeSqrtPass())
self.add_pass(DecomposeAtanPass())
self.add_pass(ConvertIntPowToMuls())
self.add_pass(CastBoolToInt8Pass())
self.add_pass(DecomposeSinhPass())
Expand Down
119 changes: 119 additions & 0 deletions backends/arm/_passes/decompose_atan_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from math import pi

from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops


# Edge-dialect atan overload; the only op that the helpers and pass in this
# module accept for decomposition.
edge_atan = exir_ops.edge.aten.atan.default  # MI case


def _get_atan_ops(op):
    """Look up the edge-dialect primitives used to assemble the atan decomposition.

    The result is a positional tuple; callers unpack it in exactly this order:
    (mul.Tensor, mul.Scalar, add.Tensor, add.Scalar, sub.Tensor, abs.default,
    gt.Scalar, reciprocal.default, where.self, neg.default).

    Raises:
        RuntimeError: if ``op`` is not the edge atan operator.
    """
    if op is edge_atan:
        aten = exir_ops.edge.aten
        return (
            aten.mul.Tensor,
            aten.mul.Scalar,
            aten.add.Tensor,
            aten.add.Scalar,
            aten.sub.Tensor,
            aten.abs.default,
            aten.gt.Scalar,
            aten.reciprocal.default,
            aten.where.self,
            aten.neg.default,
        )

    raise RuntimeError(f"Can't decompose atan for op {op}")


class DecomposeAtanPass(ArmPass):
    """Decomposes the atan operator into a rational (Padé) approximation.

    The input is first folded onto [0, 1] using the identities
    atan(-x) = -atan(x) and, for |x| > 1, atan(|x|) = pi/2 - atan(1/|x|).
    The folded value is then evaluated with a small rational polynomial built
    from mul/add/reciprocal ops, and the sign and range reduction are undone
    with where/neg ops.
    """

    def _rational_approximation(self, z, ops, meta):
        """Creates a (2,1) Padé approximation for atan(x) on [-1, 1].

        Emits the graph for
            z * (1 + a1*z^2 + a2*z^4) / (1 + b1*z^2)

        Args:
            z: Graph value holding the range-reduced input.
            ops: Operator tuple in the order returned by ``_get_atan_ops``.
            meta: Node metadata forwarded to every emitted operator.

        Returns:
            The graph value holding the approximation of atan(z).
        """

        op_mul, op_mul_scalar, op_add, op_add_scalar, _, _, _, op_recip, _, _ = ops

        # Coefficients calculated using minimax on the interval [-1, 1].
        a1 = 0.3529666667
        a2 = -0.0287666667
        b1 = 0.6863

        z2 = super().call_operator(op_mul, (z, z), {}, meta, updated=True)
        z4 = super().call_operator(op_mul, (z2, z2), {}, meta, updated=True)

        # Numerator: 1 + a1*z^2 + a2*z^4
        num1 = super().call_operator(op_mul_scalar, (z2, a1), {}, meta, updated=True)
        num2 = super().call_operator(op_mul_scalar, (z4, a2), {}, meta, updated=True)
        num = super().call_operator(op_add_scalar, (num1, 1.0), {}, meta, updated=True)
        num = super().call_operator(op_add, (num, num2), {}, meta, updated=True)

        # Denominator: 1 + b1*z^2
        den1 = super().call_operator(op_mul_scalar, (z2, b1), {}, meta, updated=True)
        den = super().call_operator(op_add_scalar, (den1, 1.0), {}, meta, updated=True)

        # The division is expressed as a multiplication by the reciprocal.
        inv_den = super().call_operator(op_recip, (den,), {}, meta, updated=True)

        prod = super().call_operator(op_mul, (num, inv_den), {}, meta, updated=True)
        return super().call_operator(op_mul, (z, prod), {}, meta, updated=True)

    def call_operator(self, op, args, kwargs, meta):
        """Replace an edge atan node with its decomposition.

        Any op other than the edge atan op is forwarded unchanged.

        Args:
            op: Operator of the node being visited.
            args: Positional arguments of the node; ``args[0]`` is the input.
            kwargs: Keyword arguments of the node.
            meta: Node metadata, reused for every emitted operator.

        Returns:
            The graph value produced by the last emitted operator (or by the
            untouched original op).
        """
        if op is not edge_atan:
            return super().call_operator(op, args, kwargs, meta, updated=False)

        logging.info(
            f"Approximating atan. This may introduce small numerical errors. For details, see {__file__}."
        )

        ops = _get_atan_ops(op)
        (
            _,
            _,
            _,
            op_add_scalar,
            _,
            op_abs,
            op_gt,
            op_recip,
            op_where,
            op_neg,
        ) = ops

        x = args[0]

        # |x| > 1 is reduced to [0, 1] using atan(x) = pi/2 - atan(1/x) and atan(-x) = -atan(x).

        abs_x = super().call_operator(op_abs, (x,), {}, meta, updated=True)
        mask_hi = super().call_operator(op_gt, (abs_x, 1.0), {}, meta, updated=True)

        inv_x = super().call_operator(op_recip, (abs_x,), {}, meta, updated=True)
        z = super().call_operator(
            op_where, (mask_hi, inv_x, abs_x), {}, meta, updated=True
        )

        atan_z = self._rational_approximation(z, ops, meta)

        # pi/2 - atan_z, written as (-atan_z) + pi/2. Building the constant as
        # x * 0.0 + pi/2 would turn x = +/-inf into NaN (inf * 0 is NaN) and
        # poison exactly the |x| > 1 branch that infinities select.
        neg_atan_z = super().call_operator(op_neg, (atan_z,), {}, meta, updated=True)
        diff = super().call_operator(
            op_add_scalar, (neg_atan_z, pi / 2), {}, meta, updated=True
        )
        atan_abs = super().call_operator(
            op_where, (mask_hi, diff, atan_z), {}, meta, updated=True
        )

        # Restore the sign of the original input.
        mask_pos = super().call_operator(op_gt, (x, 0.0), {}, meta, updated=True)
        neg_val = super().call_operator(op_neg, (atan_abs,), {}, meta, updated=True)

        return super().call_operator(
            op_where, (mask_pos, atan_abs, neg_val), {}, meta, updated=True
        )
1 change: 1 addition & 0 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class TableOps:
exir_ops.edge.aten.cos.default: torch.cos,
exir_ops.edge.aten.sin.default: torch.sin,
exir_ops.edge.aten.tanh.default: torch.tanh,
exir_ops.edge.aten.atan.default: torch.atan,
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
exir_ops.edge.aten.sinh.default: torch.sinh,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def is_node_supported(
exir_ops.edge.aten.gelu.default,
exir_ops.edge.aten.alias_copy.default,
exir_ops.edge.aten.sinh.default,
exir_ops.edge.aten.atan.default,
]

return supported
Expand Down
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def _match_pattern(
torch.ops.aten.pow.Tensor_Scalar,
torch.ops.aten.gelu.default,
torch.ops.aten.sinh.default,
torch.ops.aten.atan.default,
]

_one_to_one_shared_input_qspec = [
Expand Down
84 changes: 84 additions & 0 deletions backends/arm/test/ops/test_atan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

# Operator names handed to the test pipelines below.
aten_op = "torch.ops.aten.atan.default"
exir_op = "executorch_exir_dialects_edge__ops_aten__atan_default"

# Input signature of the module under test: a single tensor.
input_t1 = Tuple[torch.Tensor]

# Named inputs used to parametrize every test in this file.
# NOTE: the random entries are drawn once at import time, without a fixed seed.
test_data_suite = {
    # Constant inputs.
    "zeros": torch.zeros(1, 10, 10, 10),
    "zeros_alt_shape": torch.zeros(1, 10, 3, 5),
    "ones": torch.ones(10, 10, 10),
    # Uniform samples shifted to [-0.5, 0.5).
    "rand": torch.rand(10, 10) - 0.5,
    "rand_alt_shape": torch.rand(1, 10, 3, 5) - 0.5,
    # Normal samples centred on +10 and -10.
    "randn_pos": torch.randn(10) + 10,
    "randn_neg": torch.randn(10) - 10,
    # Evenly spaced values from -16 up to (not including) 16 in steps of 0.2.
    "ramp": torch.arange(-16, 16, 0.2),
}


class Atan(torch.nn.Module):
    """Minimal module whose forward pass is a single element-wise atan."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the element-wise arctangent of ``x``."""
        return x.atan()


@common.parametrize("test_data", test_data_suite)
def test_atan_tosa_MI(test_data: Tuple):
    """Run the Atan module through TosaPipelineMI for one parametrized input."""
    module = Atan()
    inputs = (test_data,)
    TosaPipelineMI[input_t1](
        module,
        inputs,
        aten_op=aten_op,
        exir_op=exir_op,
    ).run()


@common.parametrize("test_data", test_data_suite)
def test_atan_tosa_BI(test_data: Tuple):
    """Run the Atan module through TosaPipelineBI for one parametrized input."""
    module = Atan()
    inputs = (test_data,)
    TosaPipelineBI[input_t1](
        module,
        inputs,
        aten_op=aten_op,
        exir_op=exir_op,
    ).run()


@common.XfailIfNoCorstone300
@common.parametrize("test_data", test_data_suite)
def test_atan_u55_BI(test_data: Tuple):
    """Run the Atan module through EthosU55PipelineBI for one parametrized input."""
    module = Atan()
    inputs = (test_data,)
    EthosU55PipelineBI[input_t1](
        module,
        inputs,
        aten_ops=aten_op,
        exir_ops=exir_op,
    ).run()


@common.XfailIfNoCorstone320
@common.parametrize("test_data", test_data_suite)
def test_atan_u85_BI(test_data: Tuple):
    """Run the Atan module through EthosU85PipelineBI for one parametrized input."""
    module = Atan()
    inputs = (test_data,)
    EthosU85PipelineBI[input_t1](
        module,
        inputs,
        aten_ops=aten_op,
        exir_ops=exir_op,
    ).run()
Loading