Skip to content

Commit 7e1a1da

Browse files
committed
Add buildRescale helper function with scale width and double rounding option
1 parent 578fdf3 commit 7e1a1da

File tree

1 file changed

+65
-14
lines changed

1 file changed

+65
-14
lines changed

backends/arm/tosa_quant_utils.py

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,30 +33,42 @@ def isQuantArg(arg):
3333
)
3434

3535

36+
def isScale32(type):
    """Report whether scale32 mode is used for the given output element type.

    Args:
        type: Output element dtype; it is compared against ``ts.DType.INT8``.

    Returns:
        The result of ``type == ts.DType.INT8``.
    """
    int8_output = type == ts.DType.INT8
    return int8_output
39+
40+
3641
# TOSA uses the RESCALE operation to scale between values with differing precision.
3742
# The RESCALE operator is defined using an integer multiply, add, and shift.
3843
# This utility function is for calculating the multiplier and shift given a scale.
3944
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
40-
def computeMultiplierAndShift(scale):
45+
def computeMultiplierAndShift(scale, scaleWidth=32):
46+
if scaleWidth == 16:
47+
offset = 15
48+
elif scaleWidth == 32:
49+
offset = 31
50+
else:
51+
raise AssertionError("unsupported scale width")
52+
4153
assert isinstance(scale, float)
4254

4355
mantissa, exponent = math.frexp(scale)
4456
shift = exponent
4557

46-
const_two_to_31 = 1 << 31
47-
shifted_mantissa = round(mantissa * const_two_to_31)
58+
const_2_power_15_or_31 = 1 << offset
59+
shifted_mantissa = round(mantissa * const_2_power_15_or_31)
4860

49-
assert shifted_mantissa <= const_two_to_31
61+
assert shifted_mantissa <= const_2_power_15_or_31
5062

51-
if shifted_mantissa == const_two_to_31:
63+
if shifted_mantissa == const_2_power_15_or_31:
5264
shifted_mantissa = shifted_mantissa / 2
5365
shift += 1
5466

55-
# TOSA expects right shift to be positive, and embed (1 << 31) into right shift bits.
56-
shift = 31 - shift
67+
# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
68+
shift = offset - shift
5769

5870
# Upper bound is 2^offset - 1: INT16_MAX for a 16-bit scale, INT32_MAX for a 32-bit scale
59-
assert shifted_mantissa <= (const_two_to_31 - 1)
71+
assert shifted_mantissa <= (const_2_power_15_or_31 - 1)
6072

6173
multiplier = shifted_mantissa
6274

@@ -66,8 +78,41 @@ def computeMultiplierAndShift(scale):
6678
return multiplier, shift
6779

6880

81+
def buildRescale(
    tosa_fb,
    scale,
    input_node,
    output_type,
    output_shape,
    input_zp,
    output_zp,
    is_double_round,
):
    """Add a RESCALE operator to ``tosa_fb`` and return its output tensor.

    The scale width is 32 when ``isScale32(output_type)`` is true and 16
    otherwise; the multiplier and shift are derived from ``scale`` with
    ``computeMultiplierAndShift`` using that width. The rescale is always
    emitted with ``per_channel=False``.

    Args:
        tosa_fb: Graph builder; its ``addIntermediate`` and ``addOperator``
            methods are called.
        scale: Scale forwarded to ``computeMultiplierAndShift``.
        input_node: Operator input; only its ``name`` attribute is read.
        output_type: Element type of the intermediate output tensor.
        output_shape: Shape of the intermediate output tensor.
        input_zp: Input zero point stored in the rescale attribute.
        output_zp: Output zero point stored in the rescale attribute.
        is_double_round: Value stored as the attribute's ``double_round``.

    Returns:
        The intermediate tensor returned by ``tosa_fb.addIntermediate``.
    """
    use_scale32 = isScale32(output_type)
    if use_scale32:
        width = 32
    else:
        width = 16
    mult, right_shift = computeMultiplierAndShift(scale, width)

    rescale_attr = ts.TosaSerializerAttribute()
    rescale_attr.RescaleAttribute(
        input_zp=input_zp,
        output_zp=output_zp,
        multiplier=[mult],
        shift=[right_shift],
        scale32=use_scale32,
        double_round=is_double_round,
        per_channel=False,
    )

    # Allocate the output tensor first so its name can be wired into the op.
    result = tosa_fb.addIntermediate(output_shape, output_type)
    tosa_fb.addOperator(
        TosaOp.Op().RESCALE,
        [input_node.name],
        [result.name],
        rescale_attr,
    )
    return result
112+
113+
69114
def buildRescaleToInt32(
70-
tosa_fb, input, input_zp, rescale_scale
115+
tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=True
71116
) -> TosaSerializerTensor:
72117
multiplier, shift = computeMultiplierAndShift(rescale_scale)
73118
attr_rescale = ts.TosaSerializerAttribute()
@@ -76,8 +121,8 @@ def buildRescaleToInt32(
76121
output_zp=0,
77122
multiplier=[multiplier],
78123
shift=[shift],
79-
scale32=True,
80-
double_round=True,
124+
scale32=is_scale32,
125+
double_round=is_double_round,
81126
per_channel=False,
82127
)
83128
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
@@ -92,7 +137,13 @@ def buildRescaleToInt32(
92137

93138

94139
def buildRescaleFromInt32(
95-
tosa_fb, input_name, output_name, output_zp, rescale_scale
140+
tosa_fb,
141+
input_name,
142+
output_name,
143+
output_zp,
144+
rescale_scale,
145+
is_scale32=True,
146+
is_double_round=True,
96147
) -> TosaSerializerTensor:
97148
multiplier, shift = computeMultiplierAndShift(rescale_scale)
98149
attr_rescale_output = ts.TosaSerializerAttribute()
@@ -101,8 +152,8 @@ def buildRescaleFromInt32(
101152
output_zp=output_zp,
102153
multiplier=[multiplier],
103154
shift=[shift],
104-
scale32=True,
105-
double_round=True,
155+
scale32=is_scale32,
156+
double_round=is_double_round,
106157
per_channel=False,
107158
)
108159

0 commit comments

Comments
 (0)