Skip to content

Commit 12af535

Browse files
authored
Arm backend: Fix TOSA 1.0 node visitor for sum (#10908)
### Summary

Fixes serialization for the sum.dim_IntList node visitor as well as some rescale handling issues.

### Test plan

Tested with internal and external GitHub CI.

Signed-off-by: Per Åstrand <[email protected]>
1 parent 4b67dc9 commit 12af535

File tree

13 files changed

+37
-34
lines changed

13 files changed

+37
-34
lines changed

backends/arm/operators/op_abs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def define_node(
164164
scale_back = 1.0
165165
if inputs[0].dtype == ts.DType.INT8:
166166
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
167-
tosa_graph, inputs, node, self.tosa_specs
167+
tosa_graph, inputs, node, self.tosa_spec
168168
) # type: ignore[possibly-undefined]
169169
else:
170170
# input[0].dtype == ts.DType.INT32
@@ -192,7 +192,7 @@ def define_node(
192192
# Scale output back to 8 bit
193193
# pyre-ignore
194194
tqutils.insert_rescale_op_to_int8(
195-
tosa_graph, abs_output, scale_back, node, self.tosa_specs
195+
tosa_graph, abs_output, scale_back, node, self.tosa_spec
196196
) # type: ignore[possibly-undefined]
197197

198198

backends/arm/operators/op_add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def define_node(
174174
scale_back = 1.0
175175
if inputs[0].dtype == ts.DType.INT8:
176176
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
177-
tosa_graph, inputs, node, self.tosa_specs
177+
tosa_graph, inputs, node, self.tosa_spec
178178
)
179179
else:
180180
# input[0].dtype == ts.DType.INT32
@@ -202,7 +202,7 @@ def define_node(
202202
# Scale output back to 8 bit
203203
# pyre-ignore
204204
tqutils.insert_rescale_op_to_int8(
205-
tosa_graph, add_output, scale_back, node, self.tosa_specs
205+
tosa_graph, add_output, scale_back, node, self.tosa_spec
206206
) # type: ignore[possibly-undefined]
207207

208208

backends/arm/operators/op_eq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def define_node(
9898
if inputs[0].dtype == ts.DType.INT8:
9999
# Rescale inputs to 32 bit
100100
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
101-
tosa_graph, inputs, node, self.tosa_specs
101+
tosa_graph, inputs, node, self.tosa_spec
102102
)
103103

104104
# Update IO

backends/arm/operators/op_ge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_gt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_le.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_lt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

backends/arm/operators/op_maximum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def define_node(
129129
)
130130

131131
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
132-
tosa_graph, inputs, node, self.tosa_specs
132+
tosa_graph, inputs, node, self.tosa_spec
133133
)
134134

135135
output.shape = tosa_shape(output.shape, output.dim_order)
@@ -155,5 +155,5 @@ def define_node(
155155
if output.dtype == ts.DType.INT8:
156156
# insert RESCALE from int32 back to int8
157157
tqutils.insert_rescale_op_to_int8(
158-
tosa_graph, max_output, scale_back, node, self.tosa_specs
158+
tosa_graph, max_output, scale_back, node, self.tosa_spec
159159
)

backends/arm/operators/op_minimum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def define_node(
128128
)
129129

130130
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
131-
tosa_graph, inputs, node, self.tosa_specs
131+
tosa_graph, inputs, node, self.tosa_spec
132132
)
133133

134134
output.shape = tosa_shape(output.shape, output.dim_order)
@@ -154,5 +154,5 @@ def define_node(
154154
if output.dtype == ts.DType.INT8:
155155
# insert RESCALE from int32 back to int8
156156
tqutils.insert_rescale_op_to_int8(
157-
tosa_graph, min_output, scale_back, node, self.tosa_specs
157+
tosa_graph, min_output, scale_back, node, self.tosa_spec
158158
)

backends/arm/operators/op_mul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,14 @@ def define_node(
189189
input_A,
190190
input_A_qargs.zp,
191191
[1.0],
192-
tosa_spec=self.tosa_specs,
192+
tosa_spec=self.tosa_spec,
193193
)
194194
input_B_rescaled = tqutils.build_rescale_to_int32(
195195
tosa_graph,
196196
input_B,
197197
input_B_qargs.zp,
198198
[1.0],
199-
tosa_spec=self.tosa_specs,
199+
tosa_spec=self.tosa_spec,
200200
)
201201

202202
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
@@ -211,7 +211,7 @@ def define_node(
211211
)
212212
output_scale = input_A_qargs.scale * input_B_qargs.scale
213213
tqutils.insert_rescale_op_to_int8(
214-
tosa_graph, mul_output, output_scale, node, self.tosa_specs
214+
tosa_graph, mul_output, output_scale, node, self.tosa_spec
215215
)
216216

217217

backends/arm/operators/op_sub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def define_node(
168168
scale_back = 1.0
169169
if inputs[0].dtype == ts.DType.INT8:
170170
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
171-
tosa_graph, inputs, node, self.tosa_specs
171+
tosa_graph, inputs, node, self.tosa_spec
172172
)
173173
else:
174174
# input[0].dtype == ts.DType.INT32
@@ -197,7 +197,7 @@ def define_node(
197197
# Scale output back to 8 bit
198198
# pyre-ignore
199199
tqutils.insert_rescale_op_to_int8(
200-
tosa_graph, sub_output, scale_back, node, self.tosa_specs
200+
tosa_graph, sub_output, scale_back, node, self.tosa_spec
201201
) # type: ignore[possibly-undefined]
202202

203203

backends/arm/operators/op_sum.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,11 @@ def define_node(
159159

160160
# Rescale input to 32 bit
161161
rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
162-
tosa_graph,
163-
[tensor],
164-
node,
162+
tosa_graph, [tensor], node, self.tosa_spec
165163
)
166164

167165
attr = ts.TosaSerializerAttribute()
168-
attr.AxisAttribute(tensor.dim_order.index(dim))
166+
attr.ReduceSumAttribute(tensor.dim_order.index(dim))
169167

170168
intermediate = tosa_graph.addIntermediate(
171169
tutils.tosa_shape(output_shape, tensor.dim_order),
@@ -179,7 +177,9 @@ def define_node(
179177
attr,
180178
)
181179

182-
tqutils.insert_rescale_op_to_int8(tosa_graph, intermediate, scale, node)
180+
tqutils.insert_rescale_op_to_int8(
181+
tosa_graph, intermediate, scale, node, self.tosa_spec
182+
)
183183

184184

185185
@register_node_visitor
@@ -212,7 +212,7 @@ def define_node(
212212
output_shape[dim] = 1 # Output shape is input shape with dim reduced
213213

214214
attr = ts.TosaSerializerAttribute()
215-
attr.AxisAttribute(tensor.dim_order.index(dim))
215+
attr.ReduceSumAttribute(tensor.dim_order.index(dim))
216216

217217
tosa_graph.addOperator(
218218
ts.TosaOp.Op().REDUCE_SUM,

backends/arm/tosa_quant_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def insert_rescale_ops_to_int32(
3232
tosa_graph: Any,
3333
inputs: list[TosaArg],
3434
node: Node,
35-
tosa_spec=tosa_specification.Tosa_0_80,
35+
tosa_spec=None,
3636
) -> tuple[list[Any], float]:
3737
"""Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'.
3838
The scales are adjusted using the smallest scale of all 'nodes'.
@@ -79,7 +79,7 @@ def insert_rescale_op_to_int8(
7979
last_tensor: TosaArg,
8080
scale: float,
8181
node: Node,
82-
tosa_spec=tosa_specification.Tosa_0_80,
82+
tosa_spec=None,
8383
) -> None:
8484
"""Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
8585
Parameters:
@@ -323,10 +323,11 @@ def build_rescale_to_int32(
323323
is_scale32: bool = True,
324324
is_double_round: bool = False,
325325
per_channel: bool = False,
326-
tosa_spec=tosa_specification.Tosa_0_80,
326+
tosa_spec=None,
327327
) -> Any:
328328
input_A_rescaled_to_int32 = None
329-
if tosa_spec == tosa_specification.Tosa_0_80:
329+
if not tosa_spec or isinstance(tosa_spec, tosa_specification.Tosa_0_80):
330+
# default to TOSA v0.80 until we switch to v1.0
330331
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
331332

332333
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(
@@ -343,7 +344,7 @@ def build_rescale_to_int32(
343344
output_zp=0,
344345
) # type: ignore[call-arg]
345346

346-
elif isinstance(tosa_spec[0], tosa_specification.Tosa_1_00):
347+
elif isinstance(tosa_spec, tosa_specification.Tosa_1_00):
347348
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
348349
# to the RESCALE op see: https://www.mlplatform.org/tosa/tosa_spec.html#_rescale
349350
import serializer.tosa_serializer as ts # type: ignore
@@ -375,9 +376,10 @@ def build_rescale_from_int32(
375376
is_scale32: bool = True,
376377
is_double_round: bool = False,
377378
per_channel: bool = False,
378-
tosa_spec=tosa_specification.Tosa_0_80,
379+
tosa_spec=None,
379380
) -> None:
380-
if tosa_spec == tosa_specification.Tosa_0_80:
381+
if not tosa_spec or isinstance(tosa_spec, tosa_specification.Tosa_0_80):
382+
# default to TOSA v0.80 until we switch to v1.0
381383
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
382384

383385
build_rescale_v0_80(
@@ -390,7 +392,7 @@ def build_rescale_from_int32(
390392
output_zp=output_zp,
391393
) # type: ignore[call-arg]
392394

393-
elif isinstance(tosa_spec[0], tosa_specification.Tosa_1_00):
395+
elif isinstance(tosa_spec, tosa_specification.Tosa_1_00):
394396
import serializer.tosa_serializer as ts # type: ignore
395397

396398
# For TOSA v1.0 multipliers, shifts, input_zp and output_zp are now inputs
@@ -420,15 +422,16 @@ def build_rescale_conv_output(
420422
weight_scale: list[float],
421423
output_scale: list[float],
422424
output_zp: int,
423-
tosa_spec=tosa_specification.Tosa_0_80,
425+
tosa_spec=None,
424426
):
425427
# TODO add check to verify if this is a Per-channel quantization.
426428
post_conv2d_scale = [
427429
(inp * w) / out for inp, w, out in zip(input_scale, weight_scale, output_scale)
428430
]
429431

430432
# Since we assume the input tensor that is being rescaled is of int32 data type, the zero point must be 0.
431-
if tosa_spec == tosa_specification.Tosa_0_80:
433+
if not tosa_spec or isinstance(tosa_spec, tosa_specification.Tosa_0_80):
434+
# default to TOSA v0.80 until we switch to v1.0
432435
build_rescale_v0_80(
433436
tosa_fb=tosa_fb,
434437
scale=post_conv2d_scale,

0 commit comments

Comments
 (0)