
Commit 03f064b

Arm backend: Add TOSA support for logical not (#9128)
Adds TOSA support for logical not in the Arm backend.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

Signed-off-by: Måns Nilsson <[email protected]>
Co-authored-by: Yufeng Shi <[email protected]>
1 parent e86c9c9 commit 03f064b

File tree

4 files changed: +25 -8 lines

backends/arm/operator_support/tosa_supported_operators.py
backends/arm/operators/ops_unary.py
backends/arm/test/models/test_conformer.py
backends/arm/test/ops/test_logical.py

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -115,6 +115,7 @@ def is_node_supported(
             exir_ops.edge.aten.logical_and.default,
             exir_ops.edge.aten.logical_or.default,
             exir_ops.edge.aten.logical_xor.default,
+            exir_ops.edge.aten.logical_not.default,
             exir_ops.edge.aten.bitwise_and.Tensor,
             exir_ops.edge.aten.bitwise_or.Tensor,
             exir_ops.edge.aten.bitwise_xor.Tensor,
@@ -199,6 +200,7 @@ def is_node_supported(
             exir_ops.edge.aten.logical_and.default,
             exir_ops.edge.aten.logical_or.default,
             exir_ops.edge.aten.logical_xor.default,
+            exir_ops.edge.aten.logical_not.default,
             exir_ops.edge.aten.amax.default,
             exir_ops.edge.aten.amin.default,
             exir_ops.edge.aten.eq.Tensor,
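
For orientation, these entries are edge-dialect operator overloads that the partitioner's is_node_supported check compares against each node's target. The sketch below is illustrative only, not the actual implementation in tosa_supported_operators.py; the set name SUPPORTED_TARGETS and the bare node argument are invented here.

from executorch.exir.dialects._ops import ops as exir_ops

# Illustrative allow-list only; the real check in is_node_supported differs in detail.
SUPPORTED_TARGETS = {
    exir_ops.edge.aten.logical_and.default,
    exir_ops.edge.aten.logical_or.default,
    exir_ops.edge.aten.logical_xor.default,
    exir_ops.edge.aten.logical_not.default,  # added by this commit
}

def is_node_supported(node) -> bool:
    # A node can be delegated to the Arm/TOSA backend only if its target
    # appears in the allow-list.
    return node.target in SUPPORTED_TARGETS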

backends/arm/operators/ops_unary.py

Lines changed: 7 additions & 7 deletions
@@ -14,17 +14,17 @@
 )
 
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_specification import TosaSpecification
 from serializer.tosa_serializer import TosaOp
 
 
 def unary_operator_factory(unary_target: str, tosa_op):
     "Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op."
 
-    class UnaryOperator_080_MI(NodeVisitor):
-        target = unary_target
+    # Some TOSA unary operators only support float
+    fp_only_ops = ["aten.floor.default"]
 
-        tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")]
+    class UnaryOperator(NodeVisitor):
+        target = unary_target
 
         def __init__(self, *args):
             super().__init__(*args)
@@ -43,15 +43,15 @@ def define_node(
                 f"Got {inputs[0].dtype=}, {output.dtype=}"
             )
 
-            if not (inputs[0].dtype == ts.DType.FP32):
+            if self.target in fp_only_ops and not (inputs[0].dtype == ts.DType.FP32):
                 raise ValueError(
                     "All inputs need to be FP32." f"Got {inputs[0].dtype=}"
                 )
 
-            # MI lowering
             tosa_graph.addOperator(tosa_op, [inputs[0].name], [output.name])
 
-    register_node_visitor(UnaryOperator_080_MI)
+    register_node_visitor(UnaryOperator)
 
 
 unary_operator_factory("aten.floor.default", TosaOp.Op().FLOOR)
+unary_operator_factory("aten.logical_not.default", TosaOp.Op().LOGICAL_NOT)
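
Two points about this change are worth noting. Because LOGICAL_NOT consumes boolean tensors, the previously unconditional FP32 input check becomes conditional: it now fires only for targets listed in fp_only_ops (currently just aten.floor.default). And the factory pattern keeps new mappings cheap: one call registers a NodeVisitor for any one-input op that maps directly onto a TOSA operator. A hypothetical example, not part of this commit, assuming TOSA's CEIL operator and the corresponding aten overload:

# Hypothetical illustration only: registering another direct one-input mapping.
unary_operator_factory("aten.ceil.default", TosaOp.Op().CEIL)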

backends/arm/test/models/test_conformer.py

Lines changed: 0 additions & 1 deletion
@@ -34,7 +34,6 @@ class TestConformer(unittest.TestCase):
         "executorch_exir_dialects_edge__ops_aten_max_default": 1,
         "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
         "executorch_exir_dialects_edge__ops_aten_where_self": 4,
-        "executorch_exir_dialects_edge__ops_aten_logical_not_default": 4,
         "executorch_exir_dialects_edge__ops_aten_any_dim": 2,
         "torch.ops.aten._assert_scalar.default": 10,
         "torch.ops.aten._local_scalar_dense.default": 1,

backends/arm/test/ops/test_logical.py

Lines changed: 16 additions & 0 deletions
@@ -40,6 +40,14 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
         return tensor1.logical_or(tensor2)
 
 
+class Not(torch.nn.Module):
+    aten_op = "torch.ops.aten.logical_not.default"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_logical_not_default"
+
+    def forward(self, tensor: torch.Tensor):
+        return torch.logical_not(tensor)
+
+
 input_t2 = Tuple[torch.Tensor, torch.Tensor]  # Input x, y
 
 
@@ -64,6 +72,10 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
 
 
 test_data = {
+    "not_rank1": (Not(), test_input["rank1"][:-1]),
+    "not_rand_rank2": (Not(), test_input["rand_rank2"][:-1]),
+    "not_rand_rank3": (Not(), test_input["rand_rank3"][:-1]),
+    "not_rand_rank4": (Not(), test_input["rand_rank4"][:-1]),
     "and_rank1": (And(), test_input["rank1"]),
     "and_rand_rank2": (And(), test_input["rand_rank2"]),
     "and_rand_rank3": (And(), test_input["rand_rank3"]),
@@ -80,6 +92,10 @@ def forward(self, tensor1: torch.Tensor, tensor2: torch.Tensor):
 
 
 fvp_xfails = {
+    "not_rank1": "MLETORCH-706 Support ScalarType::Bool in EthosUBackend.",
+    "not_rand_rank2": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
+    "not_rand_rank3": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
+    "not_rand_rank4": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
     "and_rank1": "MLETORCH-706 Support ScalarType::Bool in EthosUBackend.",
     "and_rand_rank2": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
     "and_rand_rank3": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
