Skip to content

Arm backend: Add support for aten.round #11813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
from .decompose_round_pass import DecomposeRoundPass # noqa
from .decompose_select import DecomposeSelectPass # noqa
from .decompose_silu_pass import DecomposeSiluPass # noqa
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
DecomposeMaxPool2DPass,
DecomposeMeanDimPass,
DecomposeNotEqualPass,
DecomposeRoundPass,
DecomposeSelectPass,
DecomposeSiluPass,
DecomposeSoftmaxPass,
Expand Down Expand Up @@ -139,6 +140,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
return self._transform(exported_program.graph_module)

def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeRoundPass())
self.add_pass(DecomposeSqrtPass())
self.add_pass(ConvertIntPowToMuls())
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
Expand Down Expand Up @@ -219,6 +221,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(InsertCastForOpsWithInt64InputPass())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(DecomposeScaledDotProductAttention())
self.add_pass(DecomposeRoundPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeGroupNormPass())
Expand Down
84 changes: 84 additions & 0 deletions backends/arm/_passes/decompose_round_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._ops import OpOverload


# Type of a single operator overload handled by this module: either a plain torch
# (aten) overload or an edge-dialect overload.
Op = OpOverload | EdgeOpOverload


def _get_round_decomposition_ops(op) -> tuple[Op, Op, Op, Op, Op, Op, Op]:
    """
    Select the primitive ops used to decompose the given round op.

    The result is the tuple (full_op, ge_op, add_op, sub_op, floor_op, ceil_op,
    where_op), taken from the same dialect (edge or aten) as the round op passed in.
    Any op other than the edge or aten round.default raises a RuntimeError.
    """
    # Pick the op namespace once; the overload names are the same in both dialects.
    if op == exir_ops.edge.aten.round.default:
        aten = exir_ops.edge.aten
    elif op == torch.ops.aten.round.default:
        aten = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get round decomposition ops for op {op}")

    return (
        aten.full.default,
        aten.ge.Tensor,
        aten.add.Scalar,
        aten.sub.Scalar,
        aten.floor.default,
        aten.ceil.default,
        aten.where.self,
    )


class DecomposeRoundPass(ArmPass):
    """
    Rewrites round.default (aten or edge dialect) as a sequence of simpler ops.

    Non-negative inputs are rounded with floor(x + 0.5) and negative inputs with
    ceil(x - 0.5); a where() selects between the two results based on x >= 0.

    NOTE(review): for inputs that lie exactly on .5 this rounds away from zero
    (0.5 -> 1, 2.5 -> 3, -0.5 -> -1), whereas torch.round is documented as rounding
    half to even (0.5 -> 0, 2.5 -> 2). Confirm that this difference is acceptable.

    Emitted sequence:
        %zero            = full((1,), 0.0, dtype=torch.float32)
        %is_non_negative = ge(x, %zero)
        %plus_half       = add(x, 0.5)
        %minus_half      = sub(x, 0.5)
        %floor           = floor(%plus_half)
        %ceil            = ceil(%minus_half)
        %result          = where(%is_non_negative, %floor, %ceil)
    """

    def call_operator(self, op, args, kwargs, meta, updated=False):
        round_ops = (exir_ops.edge.aten.round.default, torch.ops.aten.round.default)
        if op not in round_ops:
            # Not a round op: forward the call unchanged.
            return super().call_operator(op, args, kwargs, meta, updated)

        (
            full_op,
            ge_op,
            add_op,
            sub_op,
            floor_op,
            ceil_op,
            where_op,
        ) = _get_round_decomposition_ops(op)
        emit = super().call_operator
        x = args[0]

        # Single-element zero tensor used as the right-hand side of the sign test.
        zero = emit(
            full_op, ((1,), 0.0), {"dtype": torch.float32}, meta, updated=True
        )
        is_non_negative = emit(ge_op, (x, zero), kwargs, meta, updated=True)

        # Both candidates are always computed; where() picks one per element.
        plus_half = emit(add_op, (x, 0.5), kwargs, meta, updated=True)
        minus_half = emit(sub_op, (x, 0.5), kwargs, meta, updated=True)
        rounded_non_negative = emit(floor_op, (plus_half,), kwargs, meta, updated=True)
        rounded_negative = emit(ceil_op, (minus_half,), kwargs, meta, updated=True)

        return emit(
            where_op,
            (is_non_negative, rounded_non_negative, rounded_negative),
            kwargs,
            meta,
            updated=True,
        )
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def is_node_supported(
exir_ops.edge.aten.leaky_relu.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.round.default,
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten._log_softmax.default,
Expand Down Expand Up @@ -281,6 +282,7 @@ def is_node_supported(
exir_ops.edge.aten.ne.Scalar: None,
exir_ops.edge.aten.div.Scalar: None,
exir_ops.edge.aten.leaky_relu.default: None,
exir_ops.edge.aten.round.default: None,
}

if node.target in needs_decomp_dict:
Expand Down
84 changes: 84 additions & 0 deletions backends/arm/test/ops/test_round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import pytest
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

# Positional inputs of the module under test: a single tensor x.
input_t1 = Tuple[torch.Tensor]

# aten-level and edge-level identifiers of the round op, passed to the pipelines.
aten_op = "torch.ops.aten.round.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_round_default"

# Maps a test name to a zero-argument factory, so each input tensor is only built
# when the corresponding test case calls it.
test_data_suite = {
    "zeros": lambda: torch.zeros((1, 10, 10, 10)),
    "ones": lambda: torch.ones((10, 10, 10)),
    "rand": lambda: torch.rand((10, 10)) - 0.5,
    "randn_pos": lambda: 10 + torch.randn(10),
    "randn_neg": lambda: torch.randn(10) - 10,
    "ramp": lambda: torch.arange(-16, 16, 0.2),
}


class Round(torch.nn.Module):
    """Minimal module whose forward pass is a single element-wise round."""

    def forward(self, x: torch.Tensor):
        return torch.round(x)


@common.parametrize("test_data", test_data_suite)
def test_round_tosa_MI(test_data: torch.Tensor):
    """
    Run Round through TosaPipelineMI for every entry of test_data_suite.

    Despite the annotation, test_data is the zero-argument factory stored in
    test_data_suite; it is called here to build the input tensor.
    """
    TosaPipelineMI[input_t1](
        Round(),
        (test_data(),),
        aten_op,
        exir_op,
    ).run()


@common.parametrize("test_data", test_data_suite)
def test_round_tosa_BI(test_data: torch.Tensor):
    """
    Run Round through TosaPipelineBI for every entry of test_data_suite.

    Despite the annotation, test_data is the zero-argument factory stored in
    test_data_suite; it is called here to build the input tensor. An empty list is
    passed in the aten-op position.
    """
    TosaPipelineBI[input_t1](
        Round(),
        (test_data(),),
        [],
        exir_op,
    ).run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(reason="where.self not supported on U55")
def test_round_u55_BI(test_data: torch.Tensor):
    """
    Run Round through EthosU55PipelineBI for every entry of test_data_suite.

    Marked xfail: per the marker's reason, where.self is not supported on U55.
    Despite the annotation, test_data is the zero-argument factory stored in
    test_data_suite; it is called here to build the input tensor.
    """
    EthosU55PipelineBI[input_t1](
        Round(),
        (test_data(),),
        [],
        exir_op,
    ).run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
def test_round_u85_BI(test_data: torch.Tensor):
    """
    Run Round through EthosU85PipelineBI for every entry of test_data_suite.

    Despite the annotation, test_data is the zero-argument factory stored in
    test_data_suite; it is called here to build the input tensor.
    """
    EthosU85PipelineBI[input_t1](
        Round(),
        (test_data(),),
        [],
        exir_op,
    ).run()
Loading