Skip to content

Commit 42c4cdb

Browse files
martinlsmMartin Lindström
andauthored
Arm backend: Make validate_same_dtype print readable types (#11468)
Prior to this change, it was hard to read which types the function validate_same_dtype referred to upon error prints. Modify this function such that it now prints the types in readable format, i.e., the name of the data type instead of its enum value. Co-authored-by: Martin Lindström <[email protected]>
1 parent e02ca41 commit 42c4cdb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+113
-102
lines changed

backends/arm/operators/op_abs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def define_node(
4444
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4545

4646
validate_num_inputs(self.target, inputs, 1)
47-
validate_same_dtype(self.target, [*inputs, output])
47+
validate_same_dtype(self.target, [*inputs, output], ts)
4848

4949
# Handle int8 (quantized) and int32
5050
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
@@ -106,7 +106,7 @@ def define_node(
106106
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107107

108108
validate_num_inputs(self.target, inputs, 1)
109-
validate_same_dtype(self.target, [*inputs, output])
109+
validate_same_dtype(self.target, [*inputs, output], ts)
110110

111111
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
112112
# Call the inherited define_node for handling integers
@@ -153,7 +153,7 @@ def define_node(
153153
import serializer.tosa_serializer as ts # type: ignore
154154

155155
validate_num_inputs(self.target, inputs, 1)
156-
validate_same_dtype(self.target, [*inputs, output])
156+
validate_same_dtype(self.target, [*inputs, output], ts)
157157

158158
# Handle int8 (quantized) and int32
159159
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
@@ -216,7 +216,7 @@ def define_node(
216216
import serializer.tosa_serializer as ts # type: ignore
217217

218218
validate_num_inputs(self.target, inputs, 1)
219-
validate_same_dtype(self.target, [*inputs, output])
219+
validate_same_dtype(self.target, [*inputs, output], ts)
220220

221221
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
222222
# Call the inherited define_node for handling integers

backends/arm/operators/op_add.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def define_node(
4545
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4646

4747
validate_num_inputs(self.target, inputs, 2)
48-
validate_same_dtype(self.target, [*inputs, output])
48+
validate_same_dtype(self.target, [*inputs, output], ts)
4949

5050
# Handle int8 (quantized) and int32
5151
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
@@ -118,7 +118,7 @@ def define_node(
118118
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
119119

120120
validate_num_inputs(self.target, inputs, 2)
121-
validate_same_dtype(self.target, [*inputs, output])
121+
validate_same_dtype(self.target, [*inputs, output], ts)
122122

123123
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
124124
# Call the inherited define_node for handling integers
@@ -163,7 +163,7 @@ def define_node(
163163
import serializer.tosa_serializer as ts # type: ignore
164164

165165
validate_num_inputs(self.target, inputs, 2)
166-
validate_same_dtype(self.target, [*inputs, output])
166+
validate_same_dtype(self.target, [*inputs, output], ts)
167167

168168
# Handle int8 (quantized) and int32
169169
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
@@ -226,7 +226,7 @@ def define_node(
226226
import serializer.tosa_serializer as ts # type: ignore
227227

228228
validate_num_inputs(self.target, inputs, 2)
229-
validate_same_dtype(self.target, [*inputs, output])
229+
validate_same_dtype(self.target, [*inputs, output], ts)
230230

231231
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
232232
# Call the inherited define_node for handling integers

backends/arm/operators/op_amax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def define_node(
3636
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3737

3838
validate_num_inputs(self.target, inputs, 3)
39-
validate_same_dtype(self.target, [inputs[0], output])
39+
validate_same_dtype(self.target, [inputs[0], output], ts)
4040

4141
input = inputs[0]
4242
dim = inputs[1].number
@@ -79,7 +79,7 @@ def define_node(
7979
import serializer.tosa_serializer as ts
8080

8181
validate_num_inputs(self.target, inputs, 3)
82-
validate_same_dtype(self.target, [inputs[0], output])
82+
validate_same_dtype(self.target, [inputs[0], output], ts)
8383

8484
input = inputs[0]
8585
dim = inputs[1].number

backends/arm/operators/op_amin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def define_node(
3636
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3737

3838
validate_num_inputs(self.target, inputs, 3)
39-
validate_same_dtype(self.target, [inputs[0], output])
39+
validate_same_dtype(self.target, [inputs[0], output], ts)
4040

4141
input = inputs[0]
4242
dim = inputs[1].number
@@ -79,7 +79,7 @@ def define_node(
7979
import serializer.tosa_serializer as ts
8080

8181
validate_num_inputs(self.target, inputs, 3)
82-
validate_same_dtype(self.target, [inputs[0], output])
82+
validate_same_dtype(self.target, [inputs[0], output], ts)
8383

8484
input = inputs[0]
8585
dim = inputs[1].number

backends/arm/operators/op_any.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def define_node(
3535
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3636

3737
validate_num_inputs(self.target, inputs, 3)
38-
validate_same_dtype(self.target, [inputs[0], output])
38+
validate_same_dtype(self.target, [inputs[0], output], ts)
3939

4040
if not (inputs[0].dtype == ts.DType.BOOL):
4141
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
@@ -72,7 +72,7 @@ def define_node(
7272
import serializer.tosa_serializer as ts
7373

7474
validate_num_inputs(self.target, inputs, 3)
75-
validate_same_dtype(self.target, [inputs[0], output])
75+
validate_same_dtype(self.target, [inputs[0], output], ts)
7676

7777
if not (inputs[0].dtype == ts.DType.BOOL):
7878
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")

backends/arm/operators/op_avg_pool2d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def define_node(
105105
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
106106

107107
validate_num_inputs(self.target, inputs, [3, 4, 6])
108-
validate_same_dtype(self.target, [inputs[0], output])
108+
validate_same_dtype(self.target, [inputs[0], output], ts)
109109

110110
supported_dtypes = [ts.DType.INT8]
111111
if inputs[0].dtype not in supported_dtypes:
@@ -145,7 +145,7 @@ def define_node(
145145
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
146146

147147
validate_num_inputs(self.target, inputs, [3, 4, 6])
148-
validate_same_dtype(self.target, [inputs[0], output])
148+
validate_same_dtype(self.target, [inputs[0], output], ts)
149149

150150
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
151151
if inputs[0].dtype not in supported_dtypes:
@@ -252,7 +252,7 @@ def define_node(
252252
import serializer.tosa_serializer as ts # type: ignore
253253

254254
validate_num_inputs(self.target, inputs, [3, 4, 6])
255-
validate_same_dtype(self.target, [inputs[0], output])
255+
validate_same_dtype(self.target, [inputs[0], output], ts)
256256

257257
supported_dtypes = [ts.DType.INT8]
258258
if inputs[0].dtype not in supported_dtypes:
@@ -295,7 +295,7 @@ def define_node(
295295
import serializer.tosa_serializer as ts # type: ignore
296296

297297
validate_num_inputs(self.target, inputs, [3, 4, 6])
298-
validate_same_dtype(self.target, [inputs[0], output])
298+
validate_same_dtype(self.target, [inputs[0], output], ts)
299299

300300
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
301301
if inputs[0].dtype not in supported_dtypes:

backends/arm/operators/op_bmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def define_node(
5050
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
5151

5252
validate_num_inputs(self.target, inputs, 2)
53-
validate_same_dtype(self.target, [*inputs, output])
53+
validate_same_dtype(self.target, [*inputs, output], ts)
5454

5555
# aten.bmm maps directly to MATMUL
5656
# NOTE: For now, only INT8 & FP32 is supported
@@ -129,7 +129,7 @@ def define_node(
129129
import serializer.tosa_serializer as ts # type: ignore
130130

131131
validate_num_inputs(self.target, inputs, 2)
132-
validate_same_dtype(self.target, [*inputs, output])
132+
validate_same_dtype(self.target, [*inputs, output], ts)
133133

134134
# aten.bmm maps directly to MATMUL
135135
# NOTE: For now, only INT8 & FP32 is supported

backends/arm/operators/op_clamp.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,10 @@ def define_node(
8888
inputs: List[TosaArg],
8989
output: TosaArg,
9090
) -> None:
91+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
92+
9193
validate_num_inputs(self.target, inputs, [2, 3])
92-
validate_same_dtype(self.target, [inputs[0], output])
94+
validate_same_dtype(self.target, [inputs[0], output], ts)
9395

9496
min_int8, max_int8 = self._get_min_max_arguments(
9597
node,
@@ -130,7 +132,7 @@ def define_node(
130132
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
131133

132134
validate_num_inputs(self.target, inputs, [2, 3])
133-
validate_same_dtype(self.target, [inputs[0], output])
135+
validate_same_dtype(self.target, [inputs[0], output], ts)
134136

135137
if inputs[0].dtype == ts.DType.INT8:
136138
# Call the inherited define_node for handling integers
@@ -197,7 +199,7 @@ def define_node(
197199
import serializer.tosa_serializer as ts # type: ignore
198200

199201
validate_num_inputs(self.target, inputs, [2, 3])
200-
validate_same_dtype(self.target, [inputs[0], output])
202+
validate_same_dtype(self.target, [inputs[0], output], ts)
201203

202204
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
203205
min_int8, max_int8 = self._get_min_max_arguments(
@@ -240,7 +242,7 @@ def define_node(
240242
import serializer.tosa_serializer as ts # type: ignore
241243

242244
validate_num_inputs(self.target, inputs, [2, 3])
243-
validate_same_dtype(self.target, [inputs[0], output])
245+
validate_same_dtype(self.target, [inputs[0], output], ts)
244246

245247
min_fp32, max_fp32 = self._get_min_max_arguments(
246248
node,

backends/arm/operators/op_constant_pad_nd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def define_node(
4444
import tosa_tools.v0_80.serializer.tosa_serializer as ts
4545

4646
validate_num_inputs(self.target, inputs, 3)
47-
validate_same_dtype(self.target, [inputs[0], output])
47+
validate_same_dtype(self.target, [inputs[0], output], ts)
4848

4949
if inputs[0].dtype == ts.DType.INT8:
5050
input_qparams = get_input_qparams(node)
@@ -108,7 +108,7 @@ def define_node(
108108
import serializer.tosa_serializer as ts # type: ignore
109109

110110
validate_num_inputs(self.target, inputs, 3)
111-
validate_same_dtype(self.target, [inputs[0], output])
111+
validate_same_dtype(self.target, [inputs[0], output], ts)
112112

113113
if inputs[0].dtype == ts.DType.INT8:
114114
input_qparams = get_input_qparams(node)

backends/arm/operators/op_cos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def define_node(
3838
output: TosaArg,
3939
) -> None:
4040
validate_num_inputs(self.target, inputs, 1)
41-
validate_same_dtype(self.target, [*inputs, output])
41+
validate_same_dtype(self.target, [*inputs, output], ts)
4242
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
4343
raise ValueError(
4444
f"Input and output for {self.target} need to be FP32, got input_dtype: "

backends/arm/operators/op_eq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4747

4848
validate_num_inputs(self.target, inputs, 2)
49-
validate_same_dtype(self.target, inputs)
49+
validate_same_dtype(self.target, inputs, ts)
5050

5151
input_nodes = inputs
5252
# Handle quantization
@@ -91,7 +91,7 @@ def define_node(
9191
import serializer.tosa_serializer as ts # type: ignore
9292

9393
validate_num_inputs(self.target, inputs, 2)
94-
validate_same_dtype(self.target, inputs)
94+
validate_same_dtype(self.target, inputs, ts)
9595

9696
input_nodes = inputs
9797
# Handle quantization

backends/arm/operators/op_erf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def define_node(
3838
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3939

4040
validate_num_inputs(self.target, inputs, 1)
41-
validate_same_dtype(self.target, [*inputs, output])
41+
validate_same_dtype(self.target, [*inputs, output], ts)
4242

4343
if not (inputs[0].dtype == ts.DType.FP32):
4444
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
@@ -66,7 +66,7 @@ def define_node(
6666
import serializer.tosa_serializer as ts
6767

6868
validate_num_inputs(self.target, inputs, 1)
69-
validate_same_dtype(self.target, [*inputs, output])
69+
validate_same_dtype(self.target, [*inputs, output], ts)
7070

7171
if not (inputs[0].dtype == ts.DType.FP32):
7272
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")

backends/arm/operators/op_exp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def define_node(
3939
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4040

4141
validate_num_inputs(self.target, inputs, 1)
42-
validate_same_dtype(self.target, [*inputs, output])
42+
validate_same_dtype(self.target, [*inputs, output], ts)
4343

4444
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
4545
raise ValueError(
@@ -70,7 +70,7 @@ def define_node(
7070
import serializer.tosa_serializer as ts
7171

7272
validate_num_inputs(self.target, inputs, 1)
73-
validate_same_dtype(self.target, [*inputs, output])
73+
validate_same_dtype(self.target, [*inputs, output], ts)
7474

7575
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
7676
raise ValueError(

backends/arm/operators/op_ge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4747

4848
validate_num_inputs(self.target, inputs, 2)
49-
validate_same_dtype(self.target, inputs)
49+
validate_same_dtype(self.target, inputs, ts)
5050

5151
input_nodes = inputs
5252
# Handle quantization
@@ -90,7 +90,7 @@ def define_node(
9090
import serializer.tosa_serializer as ts # type: ignore
9191

9292
validate_num_inputs(self.target, inputs, 2)
93-
validate_same_dtype(self.target, inputs)
93+
validate_same_dtype(self.target, inputs, ts)
9494

9595
input_nodes = inputs
9696
# Handle quantization

backends/arm/operators/op_gt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4747

4848
validate_num_inputs(self.target, inputs, 2)
49-
validate_same_dtype(self.target, inputs)
49+
validate_same_dtype(self.target, inputs, ts)
5050

5151
input_nodes = inputs
5252
# Handle quantization
@@ -90,7 +90,7 @@ def define_node(
9090
import serializer.tosa_serializer as ts # type: ignore
9191

9292
validate_num_inputs(self.target, inputs, 2)
93-
validate_same_dtype(self.target, inputs)
93+
validate_same_dtype(self.target, inputs, ts)
9494

9595
input_nodes = inputs
9696
# Handle quantization

backends/arm/operators/op_le.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4747

4848
validate_num_inputs(self.target, inputs, 2)
49-
validate_same_dtype(self.target, inputs)
49+
validate_same_dtype(self.target, inputs, ts)
5050

5151
input_nodes = inputs
5252
# Handle quantization
@@ -90,7 +90,7 @@ def define_node(
9090
import serializer.tosa_serializer as ts # type: ignore
9191

9292
validate_num_inputs(self.target, inputs, 2)
93-
validate_same_dtype(self.target, inputs)
93+
validate_same_dtype(self.target, inputs, ts)
9494

9595
input_nodes = inputs
9696
# Handle quantization

backends/arm/operators/op_log.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def define_node(
3939
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4040

4141
validate_num_inputs(self.target, inputs, 1)
42-
validate_same_dtype(self.target, [*inputs, output])
42+
validate_same_dtype(self.target, [*inputs, output], ts)
4343

4444
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
4545
raise ValueError(
@@ -70,7 +70,7 @@ def define_node(
7070
import serializer.tosa_serializer as ts
7171

7272
validate_num_inputs(self.target, inputs, 1)
73-
validate_same_dtype(self.target, [*inputs, output])
73+
validate_same_dtype(self.target, [*inputs, output], ts)
7474

7575
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
7676
raise ValueError(

backends/arm/operators/op_lt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4747

4848
validate_num_inputs(self.target, inputs, 2)
49-
validate_same_dtype(self.target, inputs)
49+
validate_same_dtype(self.target, inputs, ts)
5050

5151
input_nodes = inputs
5252
# Handle quantization
@@ -90,7 +90,7 @@ def define_node(
9090
import serializer.tosa_serializer as ts # type: ignore
9191

9292
validate_num_inputs(self.target, inputs, 2)
93-
validate_same_dtype(self.target, inputs)
93+
validate_same_dtype(self.target, inputs, ts)
9494

9595
input_nodes = inputs
9696
# Handle quantization

0 commit comments

Comments
 (0)