Skip to content

Commit a7b9697

Browse files
authored
Arm: Support bitwise and, xor, and or ops (#8518)
Support bitwise and, xor, and or ops in Arm backend Ops are very similar and thus clumped together. No quantization since that doesn't make sense for bitwise ops. Add a factory for creating simple two input NodeVisitors. This can be extended for future such ops. Signed-off-by: Erik Lundell <[email protected]>
1 parent 583d408 commit a7b9697

File tree

5 files changed

+281
-0
lines changed

5 files changed

+281
-0
lines changed

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# pyre-unsafe
77

88
from . import ( # noqa
9+
bitwise_support,
910
convolution_support,
1011
pool_2d_support,
1112
reduce_sum_support,
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch.fx as fx
7+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
8+
register_tosa_support_check,
9+
SupportedTOSAOperatorCheck,
10+
)
11+
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
14+
15+
@register_tosa_support_check
16+
class BitwiseSupported(SupportedTOSAOperatorCheck):
17+
targets = [
18+
exir_ops.edge.aten.bitwise_and.Tensor,
19+
exir_ops.edge.aten.bitwise_or.Tensor,
20+
exir_ops.edge.aten.bitwise_xor.Tensor,
21+
]
22+
23+
tosa_specs = [
24+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
25+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
26+
]
27+
28+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
29+
# U55 case, Vela 4.2.0 (25.02 release)
30+
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
31+
return False
32+
33+
return True

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,5 @@
4646
op_transpose,
4747
op_upsample_nearest2d,
4848
op_view,
49+
ops_binary,
4950
)

backends/arm/operators/ops_binary.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import serializer.tosa_serializer as ts
11+
import torch
12+
import torch.fx
13+
14+
from executorch.backends.arm.operators.node_visitor import (
15+
NodeVisitor,
16+
register_node_visitor,
17+
)
18+
from executorch.backends.arm.tosa_mapping import TosaArg
19+
from serializer.tosa_serializer import TosaOp
20+
21+
22+
def binary_operator_factory(bw_target: str, tosa_op):
23+
"""Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op."""
24+
25+
class BinaryOperator(NodeVisitor):
26+
target = bw_target
27+
28+
def define_node(
29+
self,
30+
node: torch.fx.Node,
31+
tosa_graph: ts.TosaSerializer,
32+
inputs: List[TosaArg],
33+
output: TosaArg,
34+
) -> None:
35+
36+
if not (inputs[0].dtype == inputs[1].dtype == output.dtype):
37+
raise ValueError(
38+
"All inputs and outputs need same dtype."
39+
f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}."
40+
)
41+
42+
tosa_graph.addOperator(
43+
tosa_op, [inputs[0].name, inputs[1].name], [output.name]
44+
)
45+
46+
register_node_visitor(BinaryOperator)
47+
48+
49+
binary_operator_factory("aten.bitwise_and.Tensor", TosaOp.Op().BITWISE_AND)
50+
binary_operator_factory("aten.bitwise_xor.Tensor", TosaOp.Op().BITWISE_XOR)
51+
binary_operator_factory("aten.bitwise_or.Tensor", TosaOp.Op().BITWISE_OR)

backends/arm/test/ops/test_bitwise.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import unittest
7+
8+
from typing import Callable, NamedTuple, Tuple
9+
10+
import torch
11+
from executorch.backends.arm.test import common, conftest
12+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
13+
from parameterized import parameterized
14+
15+
16+
class DataTuple(NamedTuple):
17+
name: str
18+
tensor1: torch.Tensor
19+
tensor2: torch.Tensor
20+
21+
22+
class OpTuple(NamedTuple):
23+
name: str
24+
operator: torch.nn.Module
25+
26+
27+
class And(torch.nn.Module):
28+
def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
29+
return tensor1.bitwise_and(tensor2)
30+
31+
32+
class Xor(torch.nn.Module):
33+
def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
34+
return tensor1.bitwise_xor(tensor2)
35+
36+
37+
class Or(torch.nn.Module):
38+
def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
39+
return tensor1.bitwise_or(tensor2)
40+
41+
42+
test_data_suite: list[DataTuple] = [
43+
DataTuple(
44+
"zeros",
45+
torch.zeros(1, 10, 10, 10, dtype=torch.int32),
46+
torch.zeros(1, 10, 10, 10, dtype=torch.int32),
47+
),
48+
DataTuple(
49+
"ones",
50+
torch.ones(10, 10, 10, dtype=torch.int8),
51+
torch.ones(10, 10, 10, dtype=torch.int8),
52+
),
53+
DataTuple(
54+
"rand_rank2",
55+
torch.randint(-128, 127, (10, 10), dtype=torch.int8),
56+
torch.randint(-128, 127, (10, 10), dtype=torch.int8),
57+
),
58+
DataTuple(
59+
"rand_rank4",
60+
torch.randint(-128, -127, (1, 10, 10, 10), dtype=torch.int8),
61+
torch.randint(-128, 127, (1, 10, 10, 10), dtype=torch.int8),
62+
),
63+
]
64+
65+
66+
ops: list[OpTuple] = [
67+
OpTuple("and", And()),
68+
OpTuple("or", Or()),
69+
OpTuple("xor", Xor()),
70+
]
71+
72+
full_test_suite = []
73+
for op in ops:
74+
for test_data in test_data_suite:
75+
full_test_suite.append(
76+
(
77+
f"{op.name}_{test_data.name}",
78+
op.operator,
79+
test_data.tensor1,
80+
test_data.tensor2,
81+
)
82+
)
83+
84+
del test_data
85+
del ops
86+
87+
88+
class TestBitwise(unittest.TestCase):
89+
90+
def _test_bitwise_tosa_MI_pipeline(
91+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor, torch.tensor]
92+
):
93+
(
94+
ArmTester(
95+
module,
96+
example_inputs=test_data,
97+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
98+
)
99+
.export()
100+
.to_edge_transform_and_lower()
101+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
102+
.to_executorch()
103+
.run_method_and_compare_outputs(inputs=test_data)
104+
)
105+
106+
def _test_bitwise_tosa_BI_pipeline(
107+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor, torch.tensor]
108+
):
109+
(
110+
ArmTester(
111+
module,
112+
example_inputs=test_data,
113+
compile_spec=common.get_tosa_compile_spec(
114+
"TOSA-0.80+BI", custom_path="local_bin/bitwise"
115+
),
116+
)
117+
.export()
118+
.to_edge_transform_and_lower()
119+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
120+
.to_executorch()
121+
.run_method_and_compare_outputs(inputs=test_data)
122+
)
123+
124+
def _test_bitwise_tosa_u55_BI_pipeline(
125+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
126+
):
127+
# Tests that we don't delegate these ops since they are not supported on U55.
128+
(
129+
ArmTester(
130+
module,
131+
example_inputs=test_data,
132+
compile_spec=common.get_u55_compile_spec(),
133+
)
134+
.export()
135+
.to_edge_transform_and_lower()
136+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
137+
)
138+
139+
def _test_bitwise_tosa_u85_BI_pipeline(
140+
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
141+
):
142+
tester = (
143+
ArmTester(
144+
module,
145+
example_inputs=test_data,
146+
compile_spec=common.get_u85_compile_spec(),
147+
)
148+
.export()
149+
.to_edge_transform_and_lower()
150+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
151+
.to_executorch()
152+
.serialize()
153+
)
154+
if conftest.is_option_enabled("corstone_fvp"):
155+
tester.run_method_and_compare_outputs(inputs=test_data)
156+
157+
@parameterized.expand(full_test_suite)
158+
def test_tosa_MI(
159+
self,
160+
test_name: str,
161+
operator: Callable,
162+
tensor1: torch.Tensor,
163+
tensor2: torch.Tensor,
164+
):
165+
self._test_bitwise_tosa_MI_pipeline(operator, (tensor1, tensor2))
166+
167+
@parameterized.expand(full_test_suite)
168+
def test_tosa_BI(
169+
self,
170+
test_name: str,
171+
operator: Callable,
172+
tensor1: torch.Tensor,
173+
tensor2: torch.Tensor,
174+
):
175+
self._test_bitwise_tosa_BI_pipeline(operator, (tensor1, tensor2))
176+
177+
@parameterized.expand(full_test_suite)
178+
def test_tosa_u55_BI(
179+
self,
180+
test_name: str,
181+
operator: Callable,
182+
tensor1: torch.Tensor,
183+
tensor2: torch.Tensor,
184+
):
185+
self._test_bitwise_tosa_u55_BI_pipeline(operator, (tensor1, tensor2))
186+
187+
@parameterized.expand(full_test_suite)
188+
def test_tosa_u85_BI(
189+
self,
190+
test_name: str,
191+
operator: Callable,
192+
tensor1: torch.Tensor,
193+
tensor2: torch.Tensor,
194+
):
195+
self._test_bitwise_tosa_u85_BI_pipeline(operator, (tensor1, tensor2))

0 commit comments

Comments
 (0)