@@ -69,7 +69,7 @@ def insert_rescale_ops_to_int32(
69
69
tosa_graph ,
70
70
tensor ,
71
71
qarg .zp ,
72
- scale ,
72
+ [ scale ] ,
73
73
)
74
74
)
75
75
return rescaled_nodes , min_scale
@@ -109,7 +109,7 @@ def insert_rescale_op_to_int8(
109
109
last_tensor .name ,
110
110
node .name ,
111
111
qargs_out .zp ,
112
- output_rescale_scale ,
112
+ [ output_rescale_scale ] ,
113
113
)
114
114
115
115
@@ -156,65 +156,73 @@ def is_scale32(type: int) -> ts.DType:
156
156
# The RESCALE operator is defined using an integer multiply, add, and shift.
157
157
# This utility function is for calculating the multier and shift given a scale.
158
158
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
159
- def compute_multiplier_and_shift (scale : float , scaleWidth : int = 32 ) -> Tuple [int , int ]:
159
+ def compute_multiplier_and_shift (
160
+ scales : list [float ], scaleWidth : int = 32
161
+ ) -> Tuple [list [int ], list [int ]]:
160
162
if scaleWidth == 16 :
161
163
offset = 15
162
164
elif scaleWidth == 32 :
163
165
offset = 31
164
166
else :
165
- raise AssertionError ( "unsupported scale width" )
166
-
167
- assert isinstance ( scale , float )
167
+ raise ValueError (
168
+ f"Unsupported scale width: { scaleWidth } , only 16 and 32 are valid values."
169
+ )
168
170
169
- mantissa , exponent = math .frexp (scale )
170
- shift = exponent
171
+ multipliers = []
172
+ shifts = []
173
+ for scale in scales :
174
+ mantissa , exponent = math .frexp (scale )
175
+ shift = exponent
171
176
172
- const_2_power_15_or_31 = 1 << offset
173
- shifted_mantissa = int ( round (mantissa * const_2_power_15_or_31 ) )
177
+ const_2_power_15_or_31 = 1 << offset
178
+ shifted_mantissa = round (mantissa * const_2_power_15_or_31 )
174
179
175
- assert shifted_mantissa <= const_2_power_15_or_31
180
+ assert shifted_mantissa <= const_2_power_15_or_31
176
181
177
- if shifted_mantissa == const_2_power_15_or_31 :
178
- shifted_mantissa = int ( shifted_mantissa / 2 )
179
- shift += 1
182
+ if shifted_mantissa == const_2_power_15_or_31 :
183
+ shifted_mantissa = shifted_mantissa // 2
184
+ shift += 1
180
185
181
- # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
182
- shift = offset - shift
186
+ # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
187
+ shift = offset - shift
183
188
184
- # INT32_MAX, 2^31 - 1
185
- assert shifted_mantissa <= (const_2_power_15_or_31 - 1 )
189
+ # INT32_MAX, 2^31 - 1
190
+ assert shifted_mantissa <= (const_2_power_15_or_31 - 1 )
186
191
187
- multiplier = shifted_mantissa
192
+ multiplier = shifted_mantissa
188
193
189
- if shift > 62 :
190
- multiplier = multiplier >> min (31 , shift - 62 )
191
- shift = 62
192
- return multiplier , shift
194
+ if shift > 62 :
195
+ multiplier = multiplier >> min (31 , shift - 62 )
196
+ shift = 62
197
+ multipliers .append (multiplier )
198
+ shifts .append (shift )
199
+ return multipliers , shifts
193
200
194
201
195
202
def build_rescale (
196
203
tosa_fb : TosaSerializer ,
197
- scale : float ,
204
+ scale : list [ float ] ,
198
205
input_node : TosaSerializerTensor ,
199
206
output_name : str ,
200
207
output_type : ts .DType ,
201
208
output_shape : List [int ],
202
209
input_zp : int ,
203
210
output_zp : int ,
204
211
is_double_round : bool = False ,
212
+ per_channel = False ,
205
213
):
206
214
scale_width = 32 if is_scale32 (output_type ) else 16
207
- multiplier , shift = compute_multiplier_and_shift (scale , scale_width )
215
+ multipliers , shifts = compute_multiplier_and_shift (scale , scale_width )
208
216
209
217
attr_rescale = ts .TosaSerializerAttribute ()
210
218
attr_rescale .RescaleAttribute (
211
219
input_zp = input_zp ,
212
220
output_zp = output_zp ,
213
- multiplier = [ multiplier ] ,
214
- shift = [ shift ] ,
221
+ multiplier = multipliers ,
222
+ shift = shifts ,
215
223
scale32 = is_scale32 (output_type ),
216
224
double_round = is_double_round ,
217
- per_channel = False ,
225
+ per_channel = per_channel ,
218
226
input_unsigned = False ,
219
227
output_unsigned = False ,
220
228
)
@@ -230,20 +238,21 @@ def build_rescale_to_int32(
230
238
tosa_fb : TosaSerializer ,
231
239
input_arg : executorch .backends .arm .tosa_mapping .TosaArg ,
232
240
input_zp : int ,
233
- rescale_scale : float ,
241
+ rescale_scale : list [ float ] ,
234
242
is_scale32 : bool = True ,
235
243
is_double_round : bool = False ,
244
+ per_channel : bool = False ,
236
245
) -> TosaSerializerTensor :
237
- multiplier , shift = compute_multiplier_and_shift (rescale_scale )
246
+ multipliers , shifts = compute_multiplier_and_shift (rescale_scale )
238
247
attr_rescale = ts .TosaSerializerAttribute ()
239
248
attr_rescale .RescaleAttribute (
240
249
input_zp = input_zp ,
241
250
output_zp = 0 ,
242
- multiplier = [ multiplier ] ,
243
- shift = [ shift ] ,
251
+ multiplier = multipliers ,
252
+ shift = shifts ,
244
253
scale32 = is_scale32 ,
245
254
double_round = is_double_round ,
246
- per_channel = False ,
255
+ per_channel = per_channel ,
247
256
input_unsigned = False ,
248
257
output_unsigned = False ,
249
258
)
@@ -263,20 +272,21 @@ def build_rescale_from_int32(
263
272
input_name : str ,
264
273
output_name : str ,
265
274
output_zp : int ,
266
- rescale_scale : float ,
275
+ rescale_scale : list [ float ] ,
267
276
is_scale32 : bool = True ,
268
277
is_double_round : bool = False ,
278
+ per_channel : bool = False ,
269
279
) -> None :
270
- multiplier , shift = compute_multiplier_and_shift (rescale_scale )
280
+ multipliers , shifts = compute_multiplier_and_shift (rescale_scale )
271
281
attr_rescale_output = ts .TosaSerializerAttribute ()
272
282
attr_rescale_output .RescaleAttribute (
273
283
input_zp = 0 ,
274
284
output_zp = output_zp ,
275
- multiplier = [ multiplier ] ,
276
- shift = [ shift ] ,
285
+ multiplier = multipliers ,
286
+ shift = shifts ,
277
287
scale32 = is_scale32 ,
278
288
double_round = is_double_round ,
279
- per_channel = False ,
289
+ per_channel = per_channel ,
280
290
input_unsigned = False ,
281
291
output_unsigned = False ,
282
292
)
@@ -296,13 +306,15 @@ def build_rescale_conv_output(
296
306
op : TosaSerializerTensor ,
297
307
output_name : str ,
298
308
output_type : ts .DType ,
299
- input_scale : float ,
300
- weight_scale : float ,
301
- output_scale : float ,
309
+ input_scale : list [ float ] ,
310
+ weight_scale : list [ float ] ,
311
+ output_scale : list [ float ] ,
302
312
output_zp : int ,
303
313
):
304
314
# TODO add check to verify if this is a Per-channel quantization.
305
- post_conv2d_scale = (input_scale * weight_scale ) / output_scale
315
+ post_conv2d_scale = [
316
+ (inp * w ) / out for inp , w , out in zip (input_scale , weight_scale , output_scale )
317
+ ]
306
318
307
319
# Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0.
308
320
build_rescale (
@@ -314,5 +326,7 @@ def build_rescale_conv_output(
314
326
op .shape ,
315
327
0 ,
316
328
output_zp ,
329
+ False ,
330
+ isinstance (weight_scale , torch .Tensor ),
317
331
)
318
332
return
0 commit comments