Skip to content

Arm backend: Create op utility function for num input verification #10713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions backends/arm/operators/op_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from torch.fx import Node
Expand All @@ -39,6 +42,7 @@ def define_node(

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 1)
# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
Expand Down Expand Up @@ -105,6 +109,7 @@ def define_node(

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 1)
# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
Expand Down Expand Up @@ -157,6 +162,8 @@ def define_node(

import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 1)

# Specification (1.0) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
Expand Down Expand Up @@ -224,6 +231,8 @@ def define_node(

import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 1)

# Specification (1.0) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
Expand Down
9 changes: 9 additions & 0 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from torch.fx import Node
Expand All @@ -40,6 +43,7 @@ def define_node(

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)
# Specification (0.80) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
Expand Down Expand Up @@ -118,6 +122,7 @@ def define_node(

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)
# Specification (0.80) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
Expand Down Expand Up @@ -169,6 +174,8 @@ def define_node(

import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)

# Specification (1.0) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
Expand Down Expand Up @@ -237,6 +244,8 @@ def define_node(

import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)

# Specification (1.0) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
Expand Down
7 changes: 7 additions & 0 deletions backends/arm/operators/op_amax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from torch.fx import Node

Expand All @@ -31,6 +34,8 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)

input = inputs[0]
dim = inputs[1].number

Expand Down Expand Up @@ -71,6 +76,8 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)

input = inputs[0]
dim = inputs[1].number

Expand Down
7 changes: 7 additions & 0 deletions backends/arm/operators/op_amin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from torch.fx import Node

Expand All @@ -31,6 +34,8 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)

input = inputs[0]
dim = inputs[1].number

Expand Down Expand Up @@ -71,6 +76,8 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)

input = inputs[0]
dim = inputs[1].number

Expand Down
7 changes: 7 additions & 0 deletions backends/arm/operators/op_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)

from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
from torch.fx import Node
Expand All @@ -30,6 +33,8 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 3)

if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
Expand Down Expand Up @@ -69,6 +74,8 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)

if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
Expand Down
11 changes: 11 additions & 0 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification

Expand Down Expand Up @@ -85,6 +88,8 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])

supported_dtypes = [ts.DType.INT8]
if inputs[0].dtype not in supported_dtypes:
raise TypeError(
Expand Down Expand Up @@ -122,6 +127,8 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])

supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
if inputs[0].dtype not in supported_dtypes:
raise TypeError(
Expand Down Expand Up @@ -212,6 +219,8 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])

supported_dtypes = [ts.DType.INT8]
if inputs[0].dtype not in supported_dtypes:
raise TypeError(
Expand Down Expand Up @@ -252,6 +261,8 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])

supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
if inputs[0].dtype not in supported_dtypes:
raise TypeError(
Expand Down
7 changes: 6 additions & 1 deletion backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
NodeVisitor,
register_node_visitor,
)

from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand Down Expand Up @@ -46,6 +48,7 @@ def define_node(

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got: "
Expand Down Expand Up @@ -128,6 +131,8 @@ def define_node(

import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)

if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got: "
Expand Down
7 changes: 7 additions & 0 deletions backends/arm/operators/op_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from torch.fx import Node

Expand All @@ -33,6 +36,8 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [1, 2])

tensors = inputs[0].special
dim = 0 if len(inputs) < 2 else inputs[1].number
rank = len(output.shape)
Expand Down Expand Up @@ -68,6 +73,8 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, [1, 2])

tensors = inputs[0].special
dim = 0 if len(inputs) < 2 else inputs[1].number
rank = len(output.shape)
Expand Down
29 changes: 7 additions & 22 deletions backends/arm/operators/op_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)

from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand Down Expand Up @@ -65,9 +68,6 @@ def cast_type(value: Any) -> int | float:
# Attempt to cast to float
return float(value)

if len(node.args) != 2 and len(node.args) != 3:
raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}")

min_arg = dtype_min
max_arg = dtype_max

Expand All @@ -87,10 +87,7 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:
if len(node.all_input_nodes) != 1:
raise ValueError(
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
)
validate_num_inputs(self.target, inputs, [2, 3])

min_int8, max_int8 = self._get_min_max_arguments(
node,
Expand Down Expand Up @@ -130,10 +127,7 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

if len(node.all_input_nodes) != 1:
raise ValueError(
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
)
validate_num_inputs(self.target, inputs, [2, 3])

if inputs[0].dtype == ts.DType.INT8:
# Call the inherited define_node for handling integers
Expand Down Expand Up @@ -178,9 +172,6 @@ def cast_type(value: Any) -> int | float:
# Attempt to cast to float
return float(value)

if len(node.args) != 2 and len(node.args) != 3:
raise ValueError(f"Expected len(node.args) to be 2 or 3, got {node.args}")

min_arg = dtype_min
max_arg = dtype_max

Expand All @@ -202,10 +193,7 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts # type: ignore

if len(node.all_input_nodes) != 1:
raise ValueError(
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
)
validate_num_inputs(self.target, inputs, [2, 3])

# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
min_int8, max_int8 = self._get_min_max_arguments(
Expand Down Expand Up @@ -247,10 +235,7 @@ def define_node(
) -> None:
import serializer.tosa_serializer as ts # type: ignore

if len(node.all_input_nodes) != 1:
raise ValueError(
f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}"
)
validate_num_inputs(self.target, inputs, [2, 3])

min_fp32, max_fp32 = self._get_min_max_arguments(
node,
Expand Down
8 changes: 7 additions & 1 deletion backends/arm/operators/op_constant_pad_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification

Expand All @@ -39,6 +42,8 @@ def define_node(
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)

if inputs[0].dtype == ts.DType.INT8:
input_qparams = get_input_qparams(node)
qargs = input_qparams[0]
Expand Down Expand Up @@ -98,9 +103,10 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 3)

if inputs[0].dtype == ts.DType.INT8:
input_qparams = get_input_qparams(node)
qargs = input_qparams[0]
Expand Down
Loading
Loading