8
8
# Utiliy functions for TOSA quantized lowerings
9
9
10
10
import math
11
- from typing import cast , NamedTuple
11
+ from typing import cast , List , NamedTuple , Tuple
12
+
13
+ import executorch .backends .arm .tosa_mapping
12
14
13
15
import serializer .tosa_serializer as ts # type: ignore
14
16
import torch .fx
17
+ import torch .fx .node
15
18
import tosa .Op as TosaOp # type: ignore
16
19
from executorch .backends .arm .tosa_mapping import TosaArg
17
20
from executorch .exir .dialects ._ops import ops as exir_ops
18
- from serializer .tosa_serializer import TosaSerializerTensor
21
+ from serializer .tosa_serializer import TosaSerializer , TosaSerializerTensor
22
+ from torch import Tensor
19
23
from torch .fx import Node
20
24
21
25
@@ -116,7 +120,7 @@ class QuantArgs(NamedTuple):
116
120
qmax : int
117
121
dtype : torch .dtype
118
122
119
- def quantize_value (self , x ) :
123
+ def quantize_value (self , x : torch . Tensor | float ) -> Tensor :
120
124
if not isinstance (x , torch .Tensor ):
121
125
x = torch .Tensor ([x ])
122
126
return torch .clip (
@@ -144,15 +148,15 @@ def from_operator(cls, op, args):
144
148
145
149
146
150
# Check if scale32 mode is used for given output element type
147
- def is_scale32 (type ) :
151
+ def is_scale32 (type : int ) -> ts . DType :
148
152
return type == ts .DType .INT8
149
153
150
154
151
155
# TOSA uses the RESCALE operation to scale between values with differing precision.
152
156
# The RESCALE operator is defined using an integer multiply, add, and shift.
153
157
# This utility function is for calculating the multier and shift given a scale.
154
158
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
155
- def compute_multiplier_and_shift (scale , scaleWidth = 32 ):
159
+ def compute_multiplier_and_shift (scale : float , scaleWidth : int = 32 ) -> Tuple [ int , int ] :
156
160
if scaleWidth == 16 :
157
161
offset = 15
158
162
elif scaleWidth == 32 :
@@ -166,12 +170,12 @@ def compute_multiplier_and_shift(scale, scaleWidth=32):
166
170
shift = exponent
167
171
168
172
const_2_power_15_or_31 = 1 << offset
169
- shifted_mantissa = round (mantissa * const_2_power_15_or_31 )
173
+ shifted_mantissa = int ( round (mantissa * const_2_power_15_or_31 ) )
170
174
171
175
assert shifted_mantissa <= const_2_power_15_or_31
172
176
173
177
if shifted_mantissa == const_2_power_15_or_31 :
174
- shifted_mantissa = shifted_mantissa / 2
178
+ shifted_mantissa = int ( shifted_mantissa / 2 )
175
179
shift += 1
176
180
177
181
# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
@@ -189,15 +193,15 @@ def compute_multiplier_and_shift(scale, scaleWidth=32):
189
193
190
194
191
195
def build_rescale (
192
- tosa_fb ,
193
- scale ,
194
- input_node ,
195
- output_name ,
196
- output_type ,
197
- output_shape ,
198
- input_zp ,
199
- output_zp ,
200
- is_double_round = False ,
196
+ tosa_fb : TosaSerializer ,
197
+ scale : float ,
198
+ input_node : TosaSerializerTensor ,
199
+ output_name : str ,
200
+ output_type : ts . DType ,
201
+ output_shape : List [ int ] ,
202
+ input_zp : int ,
203
+ output_zp : int ,
204
+ is_double_round : bool = False ,
201
205
):
202
206
scale_width = 32 if is_scale32 (output_type ) else 16
203
207
multiplier , shift = compute_multiplier_and_shift (scale , scale_width )
@@ -223,7 +227,12 @@ def build_rescale(
223
227
224
228
225
229
def build_rescale_to_int32 (
226
- tosa_fb , input , input_zp , rescale_scale , is_scale32 = True , is_double_round = False
230
+ tosa_fb : TosaSerializer ,
231
+ input_arg : executorch .backends .arm .tosa_mapping .TosaArg ,
232
+ input_zp : int ,
233
+ rescale_scale : float ,
234
+ is_scale32 : bool = True ,
235
+ is_double_round : bool = False ,
227
236
) -> TosaSerializerTensor :
228
237
multiplier , shift = compute_multiplier_and_shift (rescale_scale )
229
238
attr_rescale = ts .TosaSerializerAttribute ()
@@ -238,10 +247,10 @@ def build_rescale_to_int32(
238
247
input_unsigned = False ,
239
248
output_unsigned = False ,
240
249
)
241
- input_A_rescaled_to_int32 = tosa_fb .addIntermediate (input .shape , ts .DType .INT32 )
250
+ input_A_rescaled_to_int32 = tosa_fb .addIntermediate (input_arg .shape , ts .DType .INT32 )
242
251
tosa_fb .addOperator (
243
252
TosaOp .Op ().RESCALE ,
244
- [input .name ],
253
+ [input_arg .name ],
245
254
[input_A_rescaled_to_int32 .name ],
246
255
attr_rescale ,
247
256
)
@@ -250,13 +259,13 @@ def build_rescale_to_int32(
250
259
251
260
252
261
def build_rescale_from_int32 (
253
- tosa_fb ,
254
- input_name ,
255
- output_name ,
256
- output_zp ,
257
- rescale_scale ,
258
- is_scale32 = True ,
259
- is_double_round = False ,
262
+ tosa_fb : TosaSerializer ,
263
+ input_name : str ,
264
+ output_name : str ,
265
+ output_zp : int ,
266
+ rescale_scale : float ,
267
+ is_scale32 : bool = True ,
268
+ is_double_round : bool = False ,
260
269
) -> None :
261
270
multiplier , shift = compute_multiplier_and_shift (rescale_scale )
262
271
attr_rescale_output = ts .TosaSerializerAttribute ()
@@ -283,14 +292,14 @@ def build_rescale_from_int32(
283
292
284
293
285
294
def build_rescale_conv_output (
286
- tosa_fb ,
287
- op ,
288
- output_name ,
289
- output_type ,
290
- input_scale ,
291
- weight_scale ,
292
- output_scale ,
293
- output_zp ,
295
+ tosa_fb : TosaSerializer ,
296
+ op : TosaSerializerTensor ,
297
+ output_name : str ,
298
+ output_type : ts . DType ,
299
+ input_scale : float ,
300
+ weight_scale : float ,
301
+ output_scale : float ,
302
+ output_zp : int ,
294
303
):
295
304
# TODO add check to verify if this is a Per-channel quantization.
296
305
post_conv2d_scale = (input_scale * weight_scale ) / output_scale
0 commit comments