Skip to content

Commit 54e7c75

Browse files
Arm backend: Add validation steps to op_neg (#10942)
The validation steps replace the raises that previously verified the same thing. This reduces duplicated code. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 2ec8678 commit 54e7c75

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

backends/arm/operators/op_neg.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
NodeVisitor,
1717
register_node_visitor,
1818
)
19-
19+
from executorch.backends.arm.operators.operator_validation_utils import (
20+
validate_num_inputs,
21+
validate_same_dtype,
22+
)
2023
from executorch.backends.arm.tosa_mapping import TosaArg
2124

2225

@@ -60,14 +63,12 @@ def define_node(
6063
ts.DType.FP32,
6164
}
6265

66+
validate_num_inputs(self.target, inputs, 1)
67+
validate_same_dtype(self.target, [*inputs, output])
68+
6369
if inputs[0].dtype not in supported_dtypes:
6470
raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")
6571

66-
if inputs[0].dtype != output.dtype:
67-
raise ValueError(
68-
"All inputs and output need same dtype."
69-
f"Got {inputs[0].dtype=}, {output.dtype=}"
70-
)
7172
input_zp, output_zp = get_negate_zero_points(
7273
node, inputs[0].dtype == ts.DType.INT8
7374
)
@@ -109,14 +110,12 @@ def define_node(
109110
ts.DType.FP32,
110111
}
111112

113+
validate_num_inputs(self.target, inputs, 1)
114+
validate_same_dtype(self.target, [*inputs, output])
115+
112116
if inputs[0].dtype not in supported_dtypes:
113117
raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")
114118

115-
if inputs[0].dtype != output.dtype:
116-
raise ValueError(
117-
"All inputs and output need same dtype."
118-
f"Got {inputs[0].dtype=}, {output.dtype=}"
119-
)
120119
input_zp, output_zp = get_negate_zero_points(
121120
node, inputs[0].dtype == ts.DType.INT8
122121
)

0 commit comments

Comments
 (0)