Skip to content

Commit 4f90ce4

Browse files
Arm: Add FLOOR operator (#8563)
Implement an unary operator factory for creating one input NodeVisitors. Change-Id: I59ba0407b763e9e0cb79f214b7679465eda94825
1 parent 43efc37 commit 4f90ce4

File tree

6 files changed

+143
-0
lines changed

6 files changed

+143
-0
lines changed

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class InsertTableOpsPass(ExportPass):
3939

4040
table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
4141
exir_ops.edge.aten.exp.default: torch.exp,
42+
exir_ops.edge.aten.floor.default: torch.floor,
4243
exir_ops.edge.aten.log.default: torch.log,
4344
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
4445
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
107107
exir_ops.edge.aten.log.default,
108108
exir_ops.edge.aten.linear.default,
109109
exir_ops.edge.aten.split_with_sizes_copy.default,
110+
exir_ops.edge.aten.floor.default,
110111
exir_ops.edge.aten.full.default,
111112
exir_ops.edge.aten.full_like.default,
112113
exir_ops.edge.aten.ge.Tensor,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,5 @@
4646
op_upsample_nearest2d,
4747
op_view,
4848
ops_binary,
49+
ops_unary,
4950
)

backends/arm/operators/ops_unary.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts # type: ignore
10+
import torch.fx
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
16+
from executorch.backends.arm.tosa_mapping import TosaArg
17+
from executorch.backends.arm.tosa_specification import TosaSpecification
18+
from serializer.tosa_serializer import TosaOp
19+
20+
21+
def unary_operator_factory(unary_target: str, tosa_op):
22+
"Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op."
23+
24+
class UnaryOperator_080_MI(NodeVisitor):
25+
target = unary_target
26+
27+
tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")]
28+
29+
def __init__(self, *args):
30+
super().__init__(*args)
31+
32+
def define_node(
33+
self,
34+
node: torch.fx.Node,
35+
tosa_graph: ts.TosaSerializer,
36+
inputs: List[TosaArg],
37+
output: TosaArg,
38+
) -> None:
39+
40+
if not (inputs[0].dtype == output.dtype):
41+
raise ValueError(
42+
"All inputs and output need same dtype."
43+
f"Got {inputs[0].dtype=}, {output.dtype=}"
44+
)
45+
46+
if not (inputs[0].dtype == ts.DType.FP32):
47+
raise ValueError(
48+
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
49+
)
50+
51+
# MI lowering
52+
tosa_graph.addOperator(tosa_op, [inputs[0].name], [output.name])
53+
54+
register_node_visitor(UnaryOperator_080_MI)
55+
56+
57+
unary_operator_factory("aten.floor.default", TosaOp.Op().FLOOR)

backends/arm/quantizer/quantization_annotator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def _match_pattern(
127127
_one_to_one = [
128128
torch.ops.aten.abs.default,
129129
torch.ops.aten.exp.default,
130+
torch.ops.aten.floor.default,
130131
torch.ops.aten.log.default,
131132
torch.ops.aten.reciprocal.default,
132133
torch.ops.aten.rsqrt.default,

backends/arm/test/ops/test_floor.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
19+
aten_op = "torch.ops.aten.floor.default"
20+
exir_op = "executorch_exir_dialects_edge__ops_aten_floor_default"
21+
22+
input_t1 = Tuple[torch.Tensor] # Input x
23+
24+
25+
class Floor(torch.nn.Module):
26+
def forward(self, x: torch.Tensor):
27+
return torch.floor(x)
28+
29+
test_data: dict[str, input_t1] = {
30+
"zeros": (torch.zeros(1, 10, 10, 10),),
31+
"ones": (torch.ones(10, 10, 10),),
32+
"rand": ((torch.rand(10, 10) - 0.5),),
33+
"randn_pos": ((torch.randn(1, 4, 4, 4) + 10),),
34+
"randn_neg": ((torch.randn(1, 4, 4, 4) - 10),),
35+
"ramp": (torch.arange(-16, 16, 0.2),),
36+
}
37+
38+
39+
@common.parametrize("test_data", Floor.test_data)
40+
def test_floor_tosa_MI(test_data: input_t1):
41+
pipeline = TosaPipelineMI[input_t1](Floor(), test_data, aten_op, exir_op)
42+
pipeline.run()
43+
44+
45+
@common.parametrize("test_data", Floor.test_data)
46+
def test_floor_tosa_BI(test_data: input_t1):
47+
pipeline = TosaPipelineBI[input_t1](Floor(), test_data, aten_op, exir_op)
48+
pipeline.run()
49+
50+
51+
@common.parametrize("test_data", Floor.test_data)
52+
def test_floor_u55_BI(test_data: input_t1):
53+
pipeline = EthosU55PipelineBI[input_t1](
54+
Floor(), test_data, aten_op, exir_op, run_on_fvp=False
55+
)
56+
pipeline.run()
57+
58+
59+
@common.parametrize("test_data", Floor.test_data)
60+
def test_floor_u85_BI(test_data: input_t1):
61+
pipeline = EthosU85PipelineBI[input_t1](
62+
Floor(), test_data, aten_op, exir_op, run_on_fvp=False
63+
)
64+
pipeline.run()
65+
66+
67+
@common.parametrize("test_data", Floor.test_data)
68+
@common.SkipIfNoCorstone300
69+
def test_floor_u55_BI_on_fvp(test_data: input_t1):
70+
pipeline = EthosU55PipelineBI[input_t1](
71+
Floor(), test_data, aten_op, exir_op, run_on_fvp=True
72+
)
73+
pipeline.run()
74+
75+
76+
@common.parametrize("test_data", Floor.test_data)
77+
@common.SkipIfNoCorstone320
78+
def test_floor_u85_BI_on_fvp(test_data: input_t1):
79+
pipeline = EthosU85PipelineBI[input_t1](
80+
Floor(), test_data, aten_op, exir_op, run_on_fvp=True
81+
)
82+
pipeline.run()

0 commit comments

Comments
 (0)