Skip to content

Commit aacee59

Browse files
Sebastian-Larssonzingo
authored andcommitted
Arm backend: Add validation for same dtype to operators
When applicable, check that the data types of inputs to a given operator are the same. Change-Id: I94edea9433413e4adb49edbd93e247596e0cd0b7 Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 1b593ad commit aacee59

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+193
-251
lines changed

backends/arm/operators/op_abs.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from executorch.backends.arm.operators.operator_validation_utils import (
1717
validate_num_inputs,
18+
validate_same_dtype,
1819
)
1920
from executorch.backends.arm.tosa_mapping import TosaArg
2021
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -43,13 +44,8 @@ def define_node(
4344
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4445

4546
validate_num_inputs(self.target, inputs, 1)
46-
# Specification (0.80) states that input and output types
47-
# should all be the same
48-
if not (inputs[0].dtype == output.dtype):
49-
raise ValueError(
50-
"All inputs and outputs need same dtype."
51-
f"Got {inputs[0].dtype=}, {output.dtype=}"
52-
)
47+
validate_same_dtype(self.target, [*inputs, output])
48+
5349
# Handle int8 (quantized) and int32
5450
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
5551
raise ValueError(
@@ -110,13 +106,7 @@ def define_node(
110106
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
111107

112108
validate_num_inputs(self.target, inputs, 1)
113-
# Specification (0.80) states that input and output types
114-
# should all be the same
115-
if not (inputs[0].dtype == output.dtype):
116-
raise ValueError(
117-
"All inputs and output need same dtype."
118-
f"Got {inputs[0].dtype=}, {output.dtype=}"
119-
)
109+
validate_same_dtype(self.target, [*inputs, output])
120110

121111
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
122112
# Call the inherited define_node for handling integers
@@ -163,14 +153,8 @@ def define_node(
163153
import serializer.tosa_serializer as ts # type: ignore
164154

165155
validate_num_inputs(self.target, inputs, 1)
156+
validate_same_dtype(self.target, [*inputs, output])
166157

167-
# Specification (1.0) states that input and output types
168-
# should all be the same
169-
if not (inputs[0].dtype == output.dtype):
170-
raise ValueError(
171-
"All inputs and outputs need same dtype."
172-
f"Got {inputs[0].dtype=}, {output.dtype=}"
173-
)
174158
# Handle int8 (quantized) and int32
175159
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
176160
raise ValueError(
@@ -232,14 +216,7 @@ def define_node(
232216
import serializer.tosa_serializer as ts # type: ignore
233217

234218
validate_num_inputs(self.target, inputs, 1)
235-
236-
# Specification (1.0) states that input and output types
237-
# should all be the same
238-
if not (inputs[0].dtype == output.dtype):
239-
raise ValueError(
240-
"All inputs and output need same dtype."
241-
f"Got {inputs[0].dtype=}, {output.dtype=}"
242-
)
219+
validate_same_dtype(self.target, [*inputs, output])
243220

244221
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
245222
# Call the inherited define_node for handling integers

backends/arm/operators/op_add.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from executorch.backends.arm.operators.operator_validation_utils import (
1818
validate_num_inputs,
19+
validate_same_dtype,
1920
)
2021
from executorch.backends.arm.tosa_mapping import TosaArg
2122
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -44,14 +45,8 @@ def define_node(
4445
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4546

4647
validate_num_inputs(self.target, inputs, 2)
47-
# Specification (0.80) states that input and output types
48-
# should all be the same
49-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
50-
raise TypeError(
51-
f"All IO needs to have the same data type, got input 1: "
52-
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
53-
f"{output.dtype}"
54-
)
48+
validate_same_dtype(self.target, [*inputs, output])
49+
5550
# Handle int8 (quantized) and int32
5651
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
5752
if inputs[0].dtype not in supported_dtypes:
@@ -123,14 +118,7 @@ def define_node(
123118
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
124119

125120
validate_num_inputs(self.target, inputs, 2)
126-
# Specification (0.80) states that input and output types
127-
# should all be the same
128-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
129-
raise TypeError(
130-
f"All IO needs to have the same data type, got input 1: "
131-
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
132-
f"{output.dtype}"
133-
)
121+
validate_same_dtype(self.target, [*inputs, output])
134122

135123
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
136124
# Call the inherited define_node for handling integers
@@ -175,15 +163,8 @@ def define_node(
175163
import serializer.tosa_serializer as ts # type: ignore
176164

177165
validate_num_inputs(self.target, inputs, 2)
166+
validate_same_dtype(self.target, [*inputs, output])
178167

179-
# Specification (1.0) states that input and output types
180-
# should all be the same
181-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
182-
raise TypeError(
183-
f"All IO needs to have the same data type, got input 1: "
184-
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
185-
f"{output.dtype}"
186-
)
187168
# Handle int8 (quantized) and int32
188169
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
189170
if inputs[0].dtype not in supported_dtypes:
@@ -245,15 +226,7 @@ def define_node(
245226
import serializer.tosa_serializer as ts # type: ignore
246227

247228
validate_num_inputs(self.target, inputs, 2)
248-
249-
# Specification (1.0) states that input and output types
250-
# should all be the same
251-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
252-
raise TypeError(
253-
f"All IO needs to have the same data type, got input 1: "
254-
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
255-
f"{output.dtype}"
256-
)
229+
validate_same_dtype(self.target, [*inputs, output])
257230

258231
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
259232
# Call the inherited define_node for handling integers

backends/arm/operators/op_amax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from executorch.backends.arm.operators.operator_validation_utils import (
1313
validate_num_inputs,
14+
validate_same_dtype,
1415
)
1516
from executorch.backends.arm.tosa_mapping import TosaArg
1617
from torch.fx import Node
@@ -35,6 +36,7 @@ def define_node(
3536
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3637

3738
validate_num_inputs(self.target, inputs, 3)
39+
validate_same_dtype(self.target, [inputs[0], output])
3840

3941
input = inputs[0]
4042
dim = inputs[1].number
@@ -77,6 +79,7 @@ def define_node(
7779
import serializer.tosa_serializer as ts
7880

7981
validate_num_inputs(self.target, inputs, 3)
82+
validate_same_dtype(self.target, [inputs[0], output])
8083

8184
input = inputs[0]
8285
dim = inputs[1].number

backends/arm/operators/op_amin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from executorch.backends.arm.operators.operator_validation_utils import (
1313
validate_num_inputs,
14+
validate_same_dtype,
1415
)
1516
from executorch.backends.arm.tosa_mapping import TosaArg
1617
from torch.fx import Node
@@ -35,6 +36,7 @@ def define_node(
3536
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3637

3738
validate_num_inputs(self.target, inputs, 3)
39+
validate_same_dtype(self.target, [inputs[0], output])
3840

3941
input = inputs[0]
4042
dim = inputs[1].number
@@ -77,6 +79,7 @@ def define_node(
7779
import serializer.tosa_serializer as ts
7880

7981
validate_num_inputs(self.target, inputs, 3)
82+
validate_same_dtype(self.target, [inputs[0], output])
8083

8184
input = inputs[0]
8285
dim = inputs[1].number

backends/arm/operators/op_any.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from executorch.backends.arm.operators.operator_validation_utils import (
1414
validate_num_inputs,
15+
validate_same_dtype,
1516
)
1617

1718
from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
@@ -34,12 +35,8 @@ def define_node(
3435
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3536

3637
validate_num_inputs(self.target, inputs, 3)
38+
validate_same_dtype(self.target, [inputs[0], output])
3739

38-
if not (inputs[0].dtype == output.dtype):
39-
raise ValueError(
40-
"All inputs and outputs need same dtype."
41-
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
42-
)
4340
if not (inputs[0].dtype == ts.DType.BOOL):
4441
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
4542

@@ -75,12 +72,8 @@ def define_node(
7572
import serializer.tosa_serializer as ts
7673

7774
validate_num_inputs(self.target, inputs, 3)
75+
validate_same_dtype(self.target, [inputs[0], output])
7876

79-
if not (inputs[0].dtype == output.dtype):
80-
raise ValueError(
81-
"All inputs and outputs need same dtype."
82-
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
83-
)
8477
if not (inputs[0].dtype == ts.DType.BOOL):
8578
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
8679

backends/arm/operators/op_avg_pool2d.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from executorch.backends.arm.operators.operator_validation_utils import (
2020
validate_num_inputs,
21+
validate_same_dtype,
2122
)
2223
from executorch.backends.arm.tosa_mapping import TosaArg
2324
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -89,6 +90,7 @@ def define_node(
8990
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
9091

9192
validate_num_inputs(self.target, inputs, [3, 4, 6])
93+
validate_same_dtype(self.target, [inputs[0], output])
9294

9395
supported_dtypes = [ts.DType.INT8]
9496
if inputs[0].dtype not in supported_dtypes:
@@ -128,6 +130,7 @@ def define_node(
128130
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
129131

130132
validate_num_inputs(self.target, inputs, [3, 4, 6])
133+
validate_same_dtype(self.target, [inputs[0], output])
131134

132135
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
133136
if inputs[0].dtype not in supported_dtypes:
@@ -220,6 +223,7 @@ def define_node(
220223
import serializer.tosa_serializer as ts # type: ignore
221224

222225
validate_num_inputs(self.target, inputs, [3, 4, 6])
226+
validate_same_dtype(self.target, [inputs[0], output])
223227

224228
supported_dtypes = [ts.DType.INT8]
225229
if inputs[0].dtype not in supported_dtypes:
@@ -262,6 +266,7 @@ def define_node(
262266
import serializer.tosa_serializer as ts # type: ignore
263267

264268
validate_num_inputs(self.target, inputs, [3, 4, 6])
269+
validate_same_dtype(self.target, [inputs[0], output])
265270

266271
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
267272
if inputs[0].dtype not in supported_dtypes:

backends/arm/operators/op_bmm.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
)
2020
from executorch.backends.arm.operators.operator_validation_utils import (
2121
validate_num_inputs,
22+
validate_same_dtype,
2223
)
2324
from executorch.backends.arm.tosa_mapping import TosaArg
2425
from executorch.backends.arm.tosa_quant_utils import build_rescale, build_rescale_v0_80
@@ -49,11 +50,7 @@ def define_node(
4950
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
5051

5152
validate_num_inputs(self.target, inputs, 2)
52-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
53-
raise TypeError(
54-
f"All IO needs to have the same data type, got: "
55-
f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}"
56-
)
53+
validate_same_dtype(self.target, [*inputs, output])
5754

5855
# aten.bmm maps directly to MATMUL
5956
# NOTE: For now, only INT8 & FP32 is supported
@@ -132,12 +129,7 @@ def define_node(
132129
import serializer.tosa_serializer as ts # type: ignore
133130

134131
validate_num_inputs(self.target, inputs, 2)
135-
136-
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
137-
raise TypeError(
138-
f"All IO needs to have the same data type, got: "
139-
f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}"
140-
)
132+
validate_same_dtype(self.target, [*inputs, output])
141133

142134
# aten.bmm maps directly to MATMUL
143135
# NOTE: For now, only INT8 & FP32 is supported

backends/arm/operators/op_clamp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from executorch.backends.arm.operators.operator_validation_utils import (
1919
validate_num_inputs,
20+
validate_same_dtype,
2021
)
2122

2223
from executorch.backends.arm.tosa_mapping import TosaArg
@@ -88,6 +89,7 @@ def define_node(
8889
output: TosaArg,
8990
) -> None:
9091
validate_num_inputs(self.target, inputs, [2, 3])
92+
validate_same_dtype(self.target, [inputs[0], output])
9193

9294
min_int8, max_int8 = self._get_min_max_arguments(
9395
node,
@@ -128,6 +130,7 @@ def define_node(
128130
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
129131

130132
validate_num_inputs(self.target, inputs, [2, 3])
133+
validate_same_dtype(self.target, [inputs[0], output])
131134

132135
if inputs[0].dtype == ts.DType.INT8:
133136
# Call the inherited define_node for handling integers
@@ -194,6 +197,7 @@ def define_node(
194197
import serializer.tosa_serializer as ts # type: ignore
195198

196199
validate_num_inputs(self.target, inputs, [2, 3])
200+
validate_same_dtype(self.target, [inputs[0], output])
197201

198202
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
199203
min_int8, max_int8 = self._get_min_max_arguments(
@@ -236,6 +240,7 @@ def define_node(
236240
import serializer.tosa_serializer as ts # type: ignore
237241

238242
validate_num_inputs(self.target, inputs, [2, 3])
243+
validate_same_dtype(self.target, [inputs[0], output])
239244

240245
min_fp32, max_fp32 = self._get_min_max_arguments(
241246
node,

backends/arm/operators/op_constant_pad_nd.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from executorch.backends.arm.operators.operator_validation_utils import (
2020
validate_num_inputs,
21+
validate_same_dtype,
2122
)
2223
from executorch.backends.arm.tosa_mapping import TosaArg
2324
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -43,6 +44,7 @@ def define_node(
4344
import tosa_tools.v0_80.serializer.tosa_serializer as ts
4445

4546
validate_num_inputs(self.target, inputs, 3)
47+
validate_same_dtype(self.target, [inputs[0], output])
4648

4749
if inputs[0].dtype == ts.DType.INT8:
4850
input_qparams = get_input_qparams(node)
@@ -106,6 +108,7 @@ def define_node(
106108
import serializer.tosa_serializer as ts # type: ignore
107109

108110
validate_num_inputs(self.target, inputs, 3)
111+
validate_same_dtype(self.target, [inputs[0], output])
109112

110113
if inputs[0].dtype == ts.DType.INT8:
111114
input_qparams = get_input_qparams(node)

backends/arm/operators/op_cos.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from executorch.backends.arm.operators.operator_validation_utils import (
1515
validate_num_inputs,
16+
validate_same_dtype,
1617
)
1718
from executorch.backends.arm.tosa_mapping import TosaArg
1819
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -37,6 +38,7 @@ def define_node(
3738
output: TosaArg,
3839
) -> None:
3940
validate_num_inputs(self.target, inputs, 1)
41+
validate_same_dtype(self.target, [*inputs, output])
4042
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
4143
raise ValueError(
4244
f"Input and output for {self.target} need to be FP32, got input_dtype: "

0 commit comments

Comments
 (0)