Skip to content

Commit 60d9de6

Browse files
perzingo
authored andcommitted
Arm backend: Allow lists as input for rescales
In order to support per_channel operations, switch to using lists for the arguments to rescale operation creators. Signed-off-by: Per Åstrand <[email protected]> Change-Id: If67826df631af2540a80e74584fcd6500398ceff
1 parent 19b82bb commit 60d9de6

File tree

5 files changed

+65
-51
lines changed

5 files changed

+65
-51
lines changed

backends/arm/operators/op_bmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def define_node(
8080

8181
build_rescale(
8282
tosa_fb=tosa_graph,
83-
scale=final_output_scale,
83+
scale=[final_output_scale],
8484
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
8585
input_node=bmm_result, # type: ignore[possibly-undefined]
8686
output_name=output.name,

backends/arm/operators/op_conv2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ def define_node(
176176
conv2d_res, # type: ignore[possibly-undefined]
177177
output.name,
178178
output.dtype,
179-
input_scale,
180-
weight_scale,
181-
output_qargs[0].scale,
179+
[input_scale],
180+
[weight_scale],
181+
[output_qargs[0].scale],
182182
output_qargs[0].zp,
183183
)

backends/arm/operators/op_mul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ def define_node(
6363
tosa_graph,
6464
input_A,
6565
input_A_qargs.zp,
66-
rescale_scale=1.0,
66+
[1.0],
6767
)
6868
input_B_rescaled = tqutils.build_rescale_to_int32(
6969
tosa_graph,
7070
input_B,
7171
input_B_qargs.zp,
72-
rescale_scale=1.0,
72+
[1.0],
7373
)
7474

7575
output_shape = tutils.tosa_shape(output.shape, output.dim_order)

backends/arm/operators/op_rescale.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ def define_node(
5050

5151
scale_width = 32 if output_dtype == torch.int32 else 16
5252
multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
53-
scale, scale_width
53+
[scale], scale_width
5454
)
5555
attr_rescale = ts.TosaSerializerAttribute()
5656
attr_rescale.RescaleAttribute(
5757
input_zp=input_zp,
5858
output_zp=output_zp,
59-
multiplier=[multiplier],
60-
shift=[shift],
59+
multiplier=multiplier,
60+
shift=shift,
6161
scale32=output_dtype == torch.int32,
6262
double_round=False,
6363
per_channel=False,

backends/arm/tosa_quant_utils.py

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def insert_rescale_ops_to_int32(
6969
tosa_graph,
7070
tensor,
7171
qarg.zp,
72-
scale,
72+
[scale],
7373
)
7474
)
7575
return rescaled_nodes, min_scale
@@ -109,7 +109,7 @@ def insert_rescale_op_to_int8(
109109
last_tensor.name,
110110
node.name,
111111
qargs_out.zp,
112-
output_rescale_scale,
112+
[output_rescale_scale],
113113
)
114114

115115

@@ -156,65 +156,73 @@ def is_scale32(type: int) -> ts.DType:
156156
# The RESCALE operator is defined using an integer multiply, add, and shift.
157157
# This utility function is for calculating the multier and shift given a scale.
158158
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
159-
def compute_multiplier_and_shift(scale: float, scaleWidth: int = 32) -> Tuple[int, int]:
159+
def compute_multiplier_and_shift(
160+
scales: list[float], scaleWidth: int = 32
161+
) -> Tuple[list[int], list[int]]:
160162
if scaleWidth == 16:
161163
offset = 15
162164
elif scaleWidth == 32:
163165
offset = 31
164166
else:
165-
raise AssertionError("unsupported scale width")
166-
167-
assert isinstance(scale, float)
167+
raise ValueError(
168+
f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values."
169+
)
168170

169-
mantissa, exponent = math.frexp(scale)
170-
shift = exponent
171+
multipliers = []
172+
shifts = []
173+
for scale in scales:
174+
mantissa, exponent = math.frexp(scale)
175+
shift = exponent
171176

172-
const_2_power_15_or_31 = 1 << offset
173-
shifted_mantissa = int(round(mantissa * const_2_power_15_or_31))
177+
const_2_power_15_or_31 = 1 << offset
178+
shifted_mantissa = round(mantissa * const_2_power_15_or_31)
174179

175-
assert shifted_mantissa <= const_2_power_15_or_31
180+
assert shifted_mantissa <= const_2_power_15_or_31
176181

177-
if shifted_mantissa == const_2_power_15_or_31:
178-
shifted_mantissa = int(shifted_mantissa / 2)
179-
shift += 1
182+
if shifted_mantissa == const_2_power_15_or_31:
183+
shifted_mantissa = shifted_mantissa // 2
184+
shift += 1
180185

181-
# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
182-
shift = offset - shift
186+
# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
187+
shift = offset - shift
183188

184-
# INT32_MAX, 2^31 - 1
185-
assert shifted_mantissa <= (const_2_power_15_or_31 - 1)
189+
# INT32_MAX, 2^31 - 1
190+
assert shifted_mantissa <= (const_2_power_15_or_31 - 1)
186191

187-
multiplier = shifted_mantissa
192+
multiplier = shifted_mantissa
188193

189-
if shift > 62:
190-
multiplier = multiplier >> min(31, shift - 62)
191-
shift = 62
192-
return multiplier, shift
194+
if shift > 62:
195+
multiplier = multiplier >> min(31, shift - 62)
196+
shift = 62
197+
multipliers.append(multiplier)
198+
shifts.append(shift)
199+
return multipliers, shifts
193200

194201

195202
def build_rescale(
196203
tosa_fb: TosaSerializer,
197-
scale: float,
204+
scale: list[float],
198205
input_node: TosaSerializerTensor,
199206
output_name: str,
200207
output_type: ts.DType,
201208
output_shape: List[int],
202209
input_zp: int,
203210
output_zp: int,
204211
is_double_round: bool = False,
212+
per_channel=False,
205213
):
206214
scale_width = 32 if is_scale32(output_type) else 16
207-
multiplier, shift = compute_multiplier_and_shift(scale, scale_width)
215+
multipliers, shifts = compute_multiplier_and_shift(scale, scale_width)
208216

209217
attr_rescale = ts.TosaSerializerAttribute()
210218
attr_rescale.RescaleAttribute(
211219
input_zp=input_zp,
212220
output_zp=output_zp,
213-
multiplier=[multiplier],
214-
shift=[shift],
221+
multiplier=multipliers,
222+
shift=shifts,
215223
scale32=is_scale32(output_type),
216224
double_round=is_double_round,
217-
per_channel=False,
225+
per_channel=per_channel,
218226
input_unsigned=False,
219227
output_unsigned=False,
220228
)
@@ -230,20 +238,21 @@ def build_rescale_to_int32(
230238
tosa_fb: TosaSerializer,
231239
input_arg: executorch.backends.arm.tosa_mapping.TosaArg,
232240
input_zp: int,
233-
rescale_scale: float,
241+
rescale_scale: list[float],
234242
is_scale32: bool = True,
235243
is_double_round: bool = False,
244+
per_channel: bool = False,
236245
) -> TosaSerializerTensor:
237-
multiplier, shift = compute_multiplier_and_shift(rescale_scale)
246+
multipliers, shifts = compute_multiplier_and_shift(rescale_scale)
238247
attr_rescale = ts.TosaSerializerAttribute()
239248
attr_rescale.RescaleAttribute(
240249
input_zp=input_zp,
241250
output_zp=0,
242-
multiplier=[multiplier],
243-
shift=[shift],
251+
multiplier=multipliers,
252+
shift=shifts,
244253
scale32=is_scale32,
245254
double_round=is_double_round,
246-
per_channel=False,
255+
per_channel=per_channel,
247256
input_unsigned=False,
248257
output_unsigned=False,
249258
)
@@ -263,20 +272,21 @@ def build_rescale_from_int32(
263272
input_name: str,
264273
output_name: str,
265274
output_zp: int,
266-
rescale_scale: float,
275+
rescale_scale: list[float],
267276
is_scale32: bool = True,
268277
is_double_round: bool = False,
278+
per_channel: bool = False,
269279
) -> None:
270-
multiplier, shift = compute_multiplier_and_shift(rescale_scale)
280+
multipliers, shifts = compute_multiplier_and_shift(rescale_scale)
271281
attr_rescale_output = ts.TosaSerializerAttribute()
272282
attr_rescale_output.RescaleAttribute(
273283
input_zp=0,
274284
output_zp=output_zp,
275-
multiplier=[multiplier],
276-
shift=[shift],
285+
multiplier=multipliers,
286+
shift=shifts,
277287
scale32=is_scale32,
278288
double_round=is_double_round,
279-
per_channel=False,
289+
per_channel=per_channel,
280290
input_unsigned=False,
281291
output_unsigned=False,
282292
)
@@ -296,13 +306,15 @@ def build_rescale_conv_output(
296306
op: TosaSerializerTensor,
297307
output_name: str,
298308
output_type: ts.DType,
299-
input_scale: float,
300-
weight_scale: float,
301-
output_scale: float,
309+
input_scale: list[float],
310+
weight_scale: list[float],
311+
output_scale: list[float],
302312
output_zp: int,
303313
):
304314
# TODO add check to verify if this is a Per-channel quantization.
305-
post_conv2d_scale = (input_scale * weight_scale) / output_scale
315+
post_conv2d_scale = [
316+
(inp * w) / out for inp, w, out in zip(input_scale, weight_scale, output_scale)
317+
]
306318

307319
# Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0.
308320
build_rescale(
@@ -314,5 +326,7 @@ def build_rescale_conv_output(
314326
op.shape,
315327
0,
316328
output_zp,
329+
False,
330+
isinstance(weight_scale, torch.Tensor),
317331
)
318332
return

0 commit comments

Comments
 (0)