@@ -33,30 +33,42 @@ def isQuantArg(arg):
33
33
)
34
34
35
35
36
+ # Check if scale32 mode is used for given output element type
37
+ def isScale32 (type ):
38
+ return type == ts .DType .INT8
39
+
40
+
36
41
# TOSA uses the RESCALE operation to scale between values with differing precision.
37
42
# The RESCALE operator is defined using an integer multiply, add, and shift.
38
43
# This utility function is for calculating the multier and shift given a scale.
39
44
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
40
- def computeMultiplierAndShift (scale ):
45
+ def computeMultiplierAndShift (scale , scaleWidth = 32 ):
46
+ if scaleWidth == 16 :
47
+ offset = 15
48
+ elif scaleWidth == 32 :
49
+ offset = 31
50
+ else :
51
+ raise AssertionError ("unsupported scale width" )
52
+
41
53
assert isinstance (scale , float )
42
54
43
55
mantissa , exponent = math .frexp (scale )
44
56
shift = exponent
45
57
46
- const_two_to_31 = 1 << 31
47
- shifted_mantissa = round (mantissa * const_two_to_31 )
58
+ const_2_power_15_or_31 = 1 << offset
59
+ shifted_mantissa = round (mantissa * const_2_power_15_or_31 )
48
60
49
- assert shifted_mantissa <= const_two_to_31
61
+ assert shifted_mantissa <= const_2_power_15_or_31
50
62
51
- if shifted_mantissa == const_two_to_31 :
63
+ if shifted_mantissa == const_2_power_15_or_31 :
52
64
shifted_mantissa = shifted_mantissa / 2
53
65
shift += 1
54
66
55
- # TOSA expects right shift to be positive, and embed (1 << 31 ) into right shift bits.
56
- shift = 31 - shift
67
+ # TOSA expects right shift to be positive, and embed (1 << offset ) into right shift bits.
68
+ shift = offset - shift
57
69
58
70
# INT32_MAX, 2^31 - 1
59
- assert shifted_mantissa <= (const_two_to_31 - 1 )
71
+ assert shifted_mantissa <= (const_2_power_15_or_31 - 1 )
60
72
61
73
multiplier = shifted_mantissa
62
74
@@ -67,7 +79,7 @@ def computeMultiplierAndShift(scale):
67
79
68
80
69
81
def buildRescaleToInt32 (
70
- tosa_fb , input , input_zp , rescale_scale
82
+ tosa_fb , input , input_zp , rescale_scale , is_scale32 = True , is_double_round = True
71
83
) -> TosaSerializerTensor :
72
84
multiplier , shift = computeMultiplierAndShift (rescale_scale )
73
85
attr_rescale = ts .TosaSerializerAttribute ()
@@ -76,8 +88,8 @@ def buildRescaleToInt32(
76
88
output_zp = 0 ,
77
89
multiplier = [multiplier ],
78
90
shift = [shift ],
79
- scale32 = True ,
80
- double_round = True ,
91
+ scale32 = is_scale32 ,
92
+ double_round = is_double_round ,
81
93
per_channel = False ,
82
94
)
83
95
input_A_rescaled_to_int32 = tosa_fb .addIntermediate (input .shape , ts .DType .INT32 )
@@ -92,7 +104,13 @@ def buildRescaleToInt32(
92
104
93
105
94
106
def buildRescaleFromInt32 (
95
- tosa_fb , input_name , output_name , output_zp , rescale_scale
107
+ tosa_fb ,
108
+ input_name ,
109
+ output_name ,
110
+ output_zp ,
111
+ rescale_scale ,
112
+ is_scale32 = True ,
113
+ is_double_round = True ,
96
114
) -> TosaSerializerTensor :
97
115
multiplier , shift = computeMultiplierAndShift (rescale_scale )
98
116
attr_rescale_output = ts .TosaSerializerAttribute ()
@@ -101,8 +119,8 @@ def buildRescaleFromInt32(
101
119
output_zp = output_zp ,
102
120
multiplier = [multiplier ],
103
121
shift = [shift ],
104
- scale32 = True ,
105
- double_round = True ,
122
+ scale32 = is_scale32 ,
123
+ double_round = is_double_round ,
106
124
per_channel = False ,
107
125
)
108
126
0 commit comments