|
16 | 16 | NodeVisitor,
|
17 | 17 | register_node_visitor,
|
18 | 18 | )
|
19 |
| - |
| 19 | +from executorch.backends.arm.operators.operator_validation_utils import ( |
| 20 | + validate_num_inputs, |
| 21 | + validate_same_dtype, |
| 22 | +) |
20 | 23 | from executorch.backends.arm.tosa_mapping import TosaArg
|
21 | 24 |
|
22 | 25 |
|
@@ -60,14 +63,12 @@ def define_node(
|
60 | 63 | ts.DType.FP32,
|
61 | 64 | }
|
62 | 65 |
|
| 66 | + validate_num_inputs(self.target, inputs, 1) |
| 67 | + validate_same_dtype(self.target, [*inputs, output]) |
| 68 | + |
63 | 69 | if inputs[0].dtype not in supported_dtypes:
|
64 | 70 | raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")
|
65 | 71 |
|
66 |
| - if inputs[0].dtype != output.dtype: |
67 |
| - raise ValueError( |
68 |
| - "All inputs and output need same dtype." |
69 |
| - f"Got {inputs[0].dtype=}, {output.dtype=}" |
70 |
| - ) |
71 | 72 | input_zp, output_zp = get_negate_zero_points(
|
72 | 73 | node, inputs[0].dtype == ts.DType.INT8
|
73 | 74 | )
|
@@ -109,14 +110,12 @@ def define_node(
|
109 | 110 | ts.DType.FP32,
|
110 | 111 | }
|
111 | 112 |
|
| 113 | + validate_num_inputs(self.target, inputs, 1) |
| 114 | + validate_same_dtype(self.target, [*inputs, output]) |
| 115 | + |
112 | 116 | if inputs[0].dtype not in supported_dtypes:
|
113 | 117 | raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")
|
114 | 118 |
|
115 |
| - if inputs[0].dtype != output.dtype: |
116 |
| - raise ValueError( |
117 |
| - "All inputs and output need same dtype." |
118 |
| - f"Got {inputs[0].dtype=}, {output.dtype=}" |
119 |
| - ) |
120 | 119 | input_zp, output_zp = get_negate_zero_points(
|
121 | 120 | node, inputs[0].dtype == ts.DType.INT8
|
122 | 121 | )
|
|
0 commit comments