@@ -32,30 +32,42 @@ def isQuantArg(arg):
32
32
return consumer_node .target == q_op
33
33
34
34
35
+ # Check if scale32 mode is used for given output element type
36
+ def isScale32 (type ):
37
+ return type == ts .DType .INT8
38
+
39
+
35
40
# TOSA uses the RESCALE operation to scale between values with differing precision.
36
41
# The RESCALE operator is defined using an integer multiply, add, and shift.
37
42
# This utility function is for calculating the multier and shift given a scale.
38
43
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
39
- def computeMultiplierAndShift (scale ):
44
+ def computeMultiplierAndShift (scale , scaleWidth = 32 ):
45
+ if scaleWidth == 16 :
46
+ offset = 15
47
+ elif scaleWidth == 32 :
48
+ offset = 31
49
+ else :
50
+ raise AssertionError ("unsupported scale width" )
51
+
40
52
assert isinstance (scale , float )
41
53
42
54
mantissa , exponent = math .frexp (scale )
43
55
shift = exponent
44
56
45
- const_two_to_31 = 1 << 31
46
- shifted_mantissa = round (mantissa * const_two_to_31 )
57
+ const_2_power_15_or_31 = 1 << offset
58
+ shifted_mantissa = round (mantissa * const_2_power_15_or_31 )
47
59
48
- assert shifted_mantissa <= const_two_to_31
60
+ assert shifted_mantissa <= const_2_power_15_or_31
49
61
50
- if shifted_mantissa == const_two_to_31 :
62
+ if shifted_mantissa == const_2_power_15_or_31 :
51
63
shifted_mantissa = shifted_mantissa / 2
52
64
shift += 1
53
65
54
- # TOSA expects right shift to be positive, and embed (1 << 31 ) into right shift bits.
55
- shift = 31 - shift
66
+ # TOSA expects right shift to be positive, and embed (1 << offset ) into right shift bits.
67
+ shift = offset - shift
56
68
57
69
# INT32_MAX, 2^31 - 1
58
- assert shifted_mantissa <= (const_two_to_31 - 1 )
70
+ assert shifted_mantissa <= (const_2_power_15_or_31 - 1 )
59
71
60
72
multiplier = shifted_mantissa
61
73
@@ -65,8 +77,41 @@ def computeMultiplierAndShift(scale):
65
77
return multiplier , shift
66
78
67
79
80
+ def buildRescale (
81
+ tosa_fb ,
82
+ scale ,
83
+ input_node ,
84
+ output_type ,
85
+ output_shape ,
86
+ input_zp ,
87
+ output_zp ,
88
+ is_double_round ,
89
+ ):
90
+ is_scale32 = isScale32 (output_type )
91
+ scale_width = 32 if is_scale32 else 16
92
+ multiplier , shift = computeMultiplierAndShift (scale , scale_width )
93
+
94
+ attr_rescale = ts .TosaSerializerAttribute ()
95
+ attr_rescale .RescaleAttribute (
96
+ input_zp = input_zp ,
97
+ output_zp = output_zp ,
98
+ multiplier = [multiplier ],
99
+ shift = [shift ],
100
+ scale32 = is_scale32 ,
101
+ double_round = is_double_round ,
102
+ per_channel = False ,
103
+ )
104
+
105
+ rescale_out = tosa_fb .addIntermediate (output_shape , output_type )
106
+ tosa_fb .addOperator (
107
+ TosaOp .Op ().RESCALE , [input_node .name ], [rescale_out .name ], attr_rescale
108
+ )
109
+
110
+ return rescale_out
111
+
112
+
68
113
def buildRescaleToInt32 (
69
- tosa_fb , input , input_zp , rescale_scale
114
+ tosa_fb , input , input_zp , rescale_scale , is_scale32 = True , is_double_round = True
70
115
) -> TosaSerializerTensor :
71
116
multiplier , shift = computeMultiplierAndShift (rescale_scale )
72
117
attr_rescale = ts .TosaSerializerAttribute ()
@@ -75,8 +120,8 @@ def buildRescaleToInt32(
75
120
output_zp = 0 ,
76
121
multiplier = [multiplier ],
77
122
shift = [shift ],
78
- scale32 = True ,
79
- double_round = True ,
123
+ scale32 = is_scale32 ,
124
+ double_round = is_double_round ,
80
125
per_channel = False ,
81
126
)
82
127
input_A_rescaled_to_int32 = tosa_fb .addIntermediate (input .shape , ts .DType .INT32 )
@@ -91,7 +136,13 @@ def buildRescaleToInt32(
91
136
92
137
93
138
def buildRescaleFromInt32 (
94
- tosa_fb , input_name , output_name , output_zp , rescale_scale
139
+ tosa_fb ,
140
+ input_name ,
141
+ output_name ,
142
+ output_zp ,
143
+ rescale_scale ,
144
+ is_scale32 = True ,
145
+ is_double_round = True ,
95
146
) -> TosaSerializerTensor :
96
147
multiplier , shift = computeMultiplierAndShift (rescale_scale )
97
148
attr_rescale_output = ts .TosaSerializerAttribute ()
@@ -100,8 +151,8 @@ def buildRescaleFromInt32(
100
151
output_zp = output_zp ,
101
152
multiplier = [multiplier ],
102
153
shift = [shift ],
103
- scale32 = True ,
104
- double_round = True ,
154
+ scale32 = is_scale32 ,
155
+ double_round = is_double_round ,
105
156
per_channel = False ,
106
157
)
107
158
0 commit comments