Skip to content

Arm backend: Add support to neg.default #10653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def is_node_supported(
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.ne.Tensor,
exir_ops.edge.aten.ne.Scalar,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.add.Scalar,
exir_ops.edge.aten.sub.Scalar,
exir_ops.edge.aten.mul.Scalar,
Expand Down Expand Up @@ -311,6 +312,7 @@ class CheckProperQuantization(OperatorSupportBase):
exir_ops.edge.aten.max_pool2d_with_indices.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.upsample_bilinear2d.vec,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
op_maximum,
op_minimum,
op_mul,
op_neg,
op_permute,
op_pow,
op_reciprocal,
Expand Down
78 changes: 78 additions & 0 deletions backends/arm/operators/op_neg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import torch.fx

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)

from executorch.backends.arm.tosa_mapping import TosaArg


def get_negate_zero_points(node: torch.fx.Node, dtype: ts.DType) -> tuple[int, int]:
"""
Returns (input1_zp, output_zp) for TOSA NEGATE.
Must be zero for non-int8 types.
"""
if dtype == ts.DType.INT8:
return (
get_input_qparams(node)[0].zp,
get_output_qparams(node)[0].zp,
)
return (0, 0)


@register_node_visitor
class NegVisitor(NodeVisitor):
target = "aten.neg.default"

supported_dtypes = {
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
ts.DType.FP16,
ts.DType.BF16,
ts.DType.FP32,
}

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

if inputs[0].dtype not in self.supported_dtypes:
raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")

if inputs[0].dtype != output.dtype:
raise ValueError(
"All inputs and output need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)
input_zp, output_zp = get_negate_zero_points(node, inputs[0].dtype)

attr = ts.TosaSerializerAttribute()
attr.NegateAttribute(input1_zp=input_zp, output_zp=output_zp)
tosa_graph.addOperator(
ts.TosaOp.Op().NEGATE,
[inputs[0].name],
[output.name],
attributes=attr,
)
3 changes: 3 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ def any_or_hardtanh_min_zero(n: Node):
)
]
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
elif node.target in (torch.ops.aten.neg.default,):
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
quant_properties.quant_output = _QuantProperty(0, input_act_qspec)
elif node.target in _one_to_one:
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
Expand Down
66 changes: 66 additions & 0 deletions backends/arm/test/ops/test_neg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Dict, Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

input_t1 = Tuple[torch.Tensor]


class Neg(torch.nn.Module):

aten_op = "torch.ops.aten.neg.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_neg_default"

test_data: Dict[str, input_t1] = {
"rank_1_ramp": (torch.arange(-16, 16, 0.2),),
"rank_2_rand_uniform": (torch.rand(10, 10) - 0.5,),
"rank_3_all_ones": (torch.ones(10, 10, 10),),
"rank_4_all_zeros": (torch.zeros(1, 10, 10, 10),),
"rank_4_randn_pos": (torch.randn(1, 4, 4, 4) + 10,),
"rank_4_randn_neg": (torch.randn(1, 4, 4, 4) - 10,),
}

def forward(self, x: torch.Tensor):
return torch.neg(x)


@common.parametrize("test_data", Neg.test_data)
def test_neg_tosa_MI(test_data: input_t1):
pipeline = TosaPipelineMI[input_t1](Neg(), test_data, Neg.aten_op, Neg.exir_op)
pipeline.run()


@common.parametrize("test_data", Neg.test_data)
def test_neg_tosa_BI(test_data: input_t1):
pipeline = TosaPipelineBI[input_t1](Neg(), test_data, Neg.aten_op, Neg.exir_op)
pipeline.run()


@common.parametrize("test_data", Neg.test_data)
@common.XfailIfNoCorstone300
def test_neg_u55_BI(test_data: input_t1):
pipeline = EthosU55PipelineBI[input_t1](
Neg(), test_data, Neg.aten_op, Neg.exir_op, run_on_fvp=True
)
pipeline.run()


@common.parametrize("test_data", Neg.test_data)
@common.XfailIfNoCorstone320
def test_neg_u85_BI(test_data: input_t1):
pipeline = EthosU85PipelineBI[input_t1](
Neg(), test_data, Neg.aten_op, Neg.exir_op, run_on_fvp=True
)
pipeline.run()
Loading