Skip to content

Arm backend: Add validation for same dtype to operators #10872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 6 additions & 29 deletions backends/arm/operators/op_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand Down Expand Up @@ -43,13 +44,8 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 1)
# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)
validate_same_dtype(self.target, [*inputs, output])

# Handle int8 (quantized) and int32
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
raise ValueError(
Expand Down Expand Up @@ -110,13 +106,7 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 1)
# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and output need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)
validate_same_dtype(self.target, [*inputs, output])

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
# Call the inherited define_node for handling integers
Expand Down Expand Up @@ -163,14 +153,8 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 1)
validate_same_dtype(self.target, [*inputs, output])

# Specification (1.0) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)
# Handle int8 (quantized) and int32
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
raise ValueError(
Expand Down Expand Up @@ -232,14 +216,7 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 1)

# Specification (1.0) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and output need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)
validate_same_dtype(self.target, [*inputs, output])

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
# Call the inherited define_node for handling integers
Expand Down
39 changes: 6 additions & 33 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand Down Expand Up @@ -44,14 +45,8 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)
# Specification (0.80) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got input 1: "
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
f"{output.dtype}"
)
validate_same_dtype(self.target, [*inputs, output])

# Handle int8 (quantized) and int32
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
if inputs[0].dtype not in supported_dtypes:
Expand Down Expand Up @@ -123,14 +118,7 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)
# Specification (0.80) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got input 1: "
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
f"{output.dtype}"
)
validate_same_dtype(self.target, [*inputs, output])

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
# Call the inherited define_node for handling integers
Expand Down Expand Up @@ -175,15 +163,8 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)
validate_same_dtype(self.target, [*inputs, output])

# Specification (1.0) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got input 1: "
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
f"{output.dtype}"
)
# Handle int8 (quantized) and int32
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
if inputs[0].dtype not in supported_dtypes:
Expand Down Expand Up @@ -245,15 +226,7 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)

# Specification (1.0) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got input 1: "
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
f"{output.dtype}"
)
validate_same_dtype(self.target, [*inputs, output])

if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
# Call the inherited define_node for handling integers
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/operators/op_amax.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from torch.fx import Node
Expand All @@ -35,6 +36,7 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output])

input = inputs[0]
dim = inputs[1].number
Expand Down Expand Up @@ -77,6 +79,7 @@ def define_node(
import serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output])

input = inputs[0]
dim = inputs[1].number
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/operators/op_amin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from torch.fx import Node
Expand All @@ -35,6 +36,7 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output])

input = inputs[0]
dim = inputs[1].number
Expand Down Expand Up @@ -77,6 +79,7 @@ def define_node(
import serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output])

input = inputs[0]
dim = inputs[1].number
Expand Down
13 changes: 3 additions & 10 deletions backends/arm/operators/op_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)

from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
Expand All @@ -34,12 +35,8 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output])

if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
)
if not (inputs[0].dtype == ts.DType.BOOL):
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")

Expand Down Expand Up @@ -75,12 +72,8 @@ def define_node(
import serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output])

if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
)
if not (inputs[0].dtype == ts.DType.BOOL):
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")

Expand Down
5 changes: 5 additions & 0 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand Down Expand Up @@ -89,6 +90,7 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])
validate_same_dtype(self.target, [inputs[0], output])

supported_dtypes = [ts.DType.INT8]
if inputs[0].dtype not in supported_dtypes:
Expand Down Expand Up @@ -128,6 +130,7 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])
validate_same_dtype(self.target, [inputs[0], output])

supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
if inputs[0].dtype not in supported_dtypes:
Expand Down Expand Up @@ -220,6 +223,7 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])
validate_same_dtype(self.target, [inputs[0], output])

supported_dtypes = [ts.DType.INT8]
if inputs[0].dtype not in supported_dtypes:
Expand Down Expand Up @@ -262,6 +266,7 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [3, 4, 6])
validate_same_dtype(self.target, [inputs[0], output])

supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
if inputs[0].dtype not in supported_dtypes:
Expand Down
14 changes: 3 additions & 11 deletions backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80
Expand Down Expand Up @@ -49,11 +50,7 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got: "
f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}"
)
validate_same_dtype(self.target, [*inputs, output])

# aten.bmm maps directly to MATMUL
# NOTE: For now, only INT8 & FP32 is supported
Expand Down Expand Up @@ -132,12 +129,7 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 2)

if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
raise TypeError(
f"All IO needs to have the same data type, got: "
f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}"
)
validate_same_dtype(self.target, [*inputs, output])

# aten.bmm maps directly to MATMUL
# NOTE: For now, only INT8 & FP32 is supported
Expand Down
5 changes: 5 additions & 0 deletions backends/arm/operators/op_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)

from executorch.backends.arm.tosa_mapping import TosaArg
Expand Down Expand Up @@ -88,6 +89,7 @@ def define_node(
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, [2, 3])
validate_same_dtype(self.target, [inputs[0], output])

min_int8, max_int8 = self._get_min_max_arguments(
node,
Expand Down Expand Up @@ -128,6 +130,7 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [2, 3])
validate_same_dtype(self.target, [inputs[0], output])

if inputs[0].dtype == ts.DType.INT8:
# Call the inherited define_node for handling integers
Expand Down Expand Up @@ -194,6 +197,7 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [2, 3])
validate_same_dtype(self.target, [inputs[0], output])

# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
min_int8, max_int8 = self._get_min_max_arguments(
Expand Down Expand Up @@ -236,6 +240,7 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, [2, 3])
validate_same_dtype(self.target, [inputs[0], output])

min_fp32, max_fp32 = self._get_min_max_arguments(
node,
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/operators/op_constant_pad_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand All @@ -43,6 +44,7 @@ def define_node(
import tosa_tools.v0_80.serializer.tosa_serializer as ts

validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output])

if inputs[0].dtype == ts.DType.INT8:
input_qparams = get_input_qparams(node)
Expand Down Expand Up @@ -106,6 +108,7 @@ def define_node(
import serializer.tosa_serializer as ts # type: ignore

validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output])

if inputs[0].dtype == ts.DType.INT8:
input_qparams = get_input_qparams(node)
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operators/op_cos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand All @@ -37,6 +38,7 @@ def define_node(
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, 1)
validate_same_dtype(self.target, [*inputs, output])
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
raise ValueError(
f"Input and output for {self.target} need to be FP32, got input_dtype: "
Expand Down
Loading
Loading