@@ -33,30 +33,42 @@ def isQuantArg(arg):
33
33
)
34
34
35
35
36
+ # Check if scale32 mode is used for given output element type
37
def isScale32(type):
    """Return whether ``type`` compares equal to ``ts.DType.INT8``.

    Callers use the result to decide between scale32 mode (True) and the
    16-bit scale mode (False) for a RESCALE whose output element type is
    ``type``.
    """
    is_int8_output = type == ts.DType.INT8
    return is_int8_output
39
+
40
+
36
41
# TOSA uses the RESCALE operation to scale between values with differing precision.
37
42
# The RESCALE operator is defined using an integer multiply, add, and shift.
38
43
# This utility function is for calculating the multiplier and shift given a scale.
39
44
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
40
- def computeMultiplierAndShift (scale ):
45
+ def computeMultiplierAndShift (scale , scaleWidth = 32 ):
46
+ if scaleWidth == 16 :
47
+ offset = 15
48
+ elif scaleWidth == 32 :
49
+ offset = 31
50
+ else :
51
+ raise AssertionError ("unsupported scale width" )
52
+
41
53
assert isinstance (scale , float )
42
54
43
55
mantissa , exponent = math .frexp (scale )
44
56
shift = exponent
45
57
46
- const_two_to_31 = 1 << 31
47
- shifted_mantissa = round (mantissa * const_two_to_31 )
58
+ const_2_power_15_or_31 = 1 << offset
59
+ shifted_mantissa = round (mantissa * const_2_power_15_or_31 )
48
60
49
- assert shifted_mantissa <= const_two_to_31
61
+ assert shifted_mantissa <= const_2_power_15_or_31
50
62
51
- if shifted_mantissa == const_two_to_31 :
63
+ if shifted_mantissa == const_2_power_15_or_31 :
52
64
shifted_mantissa = shifted_mantissa / 2
53
65
shift += 1
54
66
55
- # TOSA expects right shift to be positive, and embed (1 << 31 ) into right shift bits.
56
- shift = 31 - shift
67
+ # TOSA expects right shift to be positive, and embed (1 << offset ) into right shift bits.
68
+ shift = offset - shift
57
69
58
70
# INT32_MAX, 2^31 - 1
59
- assert shifted_mantissa <= (const_two_to_31 - 1 )
71
+ assert shifted_mantissa <= (const_2_power_15_or_31 - 1 )
60
72
61
73
multiplier = shifted_mantissa
62
74
@@ -66,8 +78,41 @@ def computeMultiplierAndShift(scale):
66
78
return multiplier , shift
67
79
68
80
81
def buildRescale(
    tosa_fb,
    scale,
    input_node,
    output_type,
    output_shape,
    input_zp,
    output_zp,
    is_double_round,
):
    """Add a per-tensor RESCALE operator to ``tosa_fb`` and return its output.

    The scale width is chosen from ``output_type``: 32 bits when
    ``isScale32(output_type)`` is true, otherwise 16 bits. ``scale`` is
    converted to a multiplier/shift pair for that width and stored in the
    operator's rescale attribute together with the zero points.

    Args:
        tosa_fb: Graph builder; must provide ``addIntermediate`` and
            ``addOperator``.
        scale: Scale forwarded to ``computeMultiplierAndShift``.
        input_node: Producer of the value to rescale; only ``.name`` is read.
        output_type: Element type of the rescaled tensor.
        output_shape: Shape of the rescaled tensor.
        input_zp: Input zero point stored in the attribute.
        output_zp: Output zero point stored in the attribute.
        is_double_round: Forwarded unchanged as the ``double_round`` flag.

    Returns:
        The intermediate tensor created to hold the RESCALE result.

    NOTE(review): ``is_double_round`` is passed through even when scale32 is
    off. The TOSA spec appears to require scale32 whenever double rounding is
    requested; confirm callers never combine the two.
    """
    use_scale32 = isScale32(output_type)
    if use_scale32:
        width = 32
    else:
        width = 16
    multiplier, shift = computeMultiplierAndShift(scale, width)

    rescale_attr = ts.TosaSerializerAttribute()
    rescale_attr.RescaleAttribute(
        input_zp=input_zp,
        output_zp=output_zp,
        multiplier=[multiplier],
        shift=[shift],
        scale32=use_scale32,
        double_round=is_double_round,
        per_channel=False,
    )

    # The operator writes into a fresh intermediate tensor of the requested
    # shape and element type; that tensor is handed back to the caller.
    result = tosa_fb.addIntermediate(output_shape, output_type)
    tosa_fb.addOperator(
        TosaOp.Op().RESCALE,
        [input_node.name],
        [result.name],
        rescale_attr,
    )

    return result
112
+
113
+
69
114
def buildRescaleToInt32 (
70
- tosa_fb , input , input_zp , rescale_scale
115
+ tosa_fb , input , input_zp , rescale_scale , is_scale32 = True , is_double_round = True
71
116
) -> TosaSerializerTensor :
72
117
multiplier , shift = computeMultiplierAndShift (rescale_scale )
73
118
attr_rescale = ts .TosaSerializerAttribute ()
@@ -76,8 +121,8 @@ def buildRescaleToInt32(
76
121
output_zp = 0 ,
77
122
multiplier = [multiplier ],
78
123
shift = [shift ],
79
- scale32 = True ,
80
- double_round = True ,
124
+ scale32 = is_scale32 ,
125
+ double_round = is_double_round ,
81
126
per_channel = False ,
82
127
)
83
128
input_A_rescaled_to_int32 = tosa_fb .addIntermediate (input .shape , ts .DType .INT32 )
@@ -92,7 +137,13 @@ def buildRescaleToInt32(
92
137
93
138
94
139
def buildRescaleFromInt32 (
95
- tosa_fb , input_name , output_name , output_zp , rescale_scale
140
+ tosa_fb ,
141
+ input_name ,
142
+ output_name ,
143
+ output_zp ,
144
+ rescale_scale ,
145
+ is_scale32 = True ,
146
+ is_double_round = True ,
96
147
) -> TosaSerializerTensor :
97
148
multiplier , shift = computeMultiplierAndShift (rescale_scale )
98
149
attr_rescale_output = ts .TosaSerializerAttribute ()
@@ -101,8 +152,8 @@ def buildRescaleFromInt32(
101
152
output_zp = output_zp ,
102
153
multiplier = [multiplier ],
103
154
shift = [shift ],
104
- scale32 = True ,
105
- double_round = True ,
155
+ scale32 = is_scale32 ,
156
+ double_round = is_double_round ,
106
157
per_channel = False ,
107
158
)
108
159
0 commit comments