Skip to content

Commit fbe049b

Browse files
committed
Add scale width and double rounding option in buildRescale
1 parent 578fdf3 commit fbe049b

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

backends/arm/tosa_quant_utils.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,30 +33,42 @@ def isQuantArg(arg):
3333
)
3434

3535

36+
# Check if scale32 mode is used for given output element type
37+
def isScale32(type):
    """Report whether scale32 mode applies to the given output element type.

    Args:
        type: The output element dtype to test (compared against
            ``ts.DType.INT8``).

    Returns:
        The result of comparing ``type`` for equality with
        ``ts.DType.INT8``.
    """
    uses_scale32 = type == ts.DType.INT8
    return uses_scale32
39+
40+
3641
# TOSA uses the RESCALE operation to scale between values with differing precision.
3742
# The RESCALE operator is defined using an integer multiply, add, and shift.
3843
# This utility function is for calculating the multiplier and shift given a scale.
3944
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
40-
def computeMultiplierAndShift(scale):
45+
def computeMultiplierAndShift(scale, scaleWidth=32):
46+
if scaleWidth == 16:
47+
offset = 15
48+
elif scaleWidth == 32:
49+
offset = 31
50+
else:
51+
raise AssertionError("unsupported scale width")
52+
4153
assert isinstance(scale, float)
4254

4355
mantissa, exponent = math.frexp(scale)
4456
shift = exponent
4557

46-
const_two_to_31 = 1 << 31
47-
shifted_mantissa = round(mantissa * const_two_to_31)
58+
const_2_power_15_or_31 = 1 << offset
59+
shifted_mantissa = round(mantissa * const_2_power_15_or_31)
4860

49-
assert shifted_mantissa <= const_two_to_31
61+
assert shifted_mantissa <= const_2_power_15_or_31
5062

51-
if shifted_mantissa == const_two_to_31:
63+
if shifted_mantissa == const_2_power_15_or_31:
5264
shifted_mantissa = shifted_mantissa / 2
5365
shift += 1
5466

55-
# TOSA expects right shift to be positive, and embed (1 << 31) into right shift bits.
56-
shift = 31 - shift
67+
# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
68+
shift = offset - shift
5769

5870
# The multiplier must not exceed 2^offset - 1 (INT32_MAX for offset 31, INT16_MAX for offset 15)
59-
assert shifted_mantissa <= (const_two_to_31 - 1)
71+
assert shifted_mantissa <= (const_2_power_15_or_31 - 1)
6072

6173
multiplier = shifted_mantissa
6274

@@ -67,7 +79,7 @@ def computeMultiplierAndShift(scale):
6779

6880

6981
def buildRescaleToInt32(
70-
tosa_fb, input, input_zp, rescale_scale
82+
tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=True
7183
) -> TosaSerializerTensor:
7284
multiplier, shift = computeMultiplierAndShift(rescale_scale)
7385
attr_rescale = ts.TosaSerializerAttribute()
@@ -76,8 +88,8 @@ def buildRescaleToInt32(
7688
output_zp=0,
7789
multiplier=[multiplier],
7890
shift=[shift],
79-
scale32=True,
80-
double_round=True,
91+
scale32=is_scale32,
92+
double_round=is_double_round,
8193
per_channel=False,
8294
)
8395
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
@@ -92,7 +104,13 @@ def buildRescaleToInt32(
92104

93105

94106
def buildRescaleFromInt32(
95-
tosa_fb, input_name, output_name, output_zp, rescale_scale
107+
tosa_fb,
108+
input_name,
109+
output_name,
110+
output_zp,
111+
rescale_scale,
112+
is_scale32=True,
113+
is_double_round=True,
96114
) -> TosaSerializerTensor:
97115
multiplier, shift = computeMultiplierAndShift(rescale_scale)
98116
attr_rescale_output = ts.TosaSerializerAttribute()
@@ -101,8 +119,8 @@ def buildRescaleFromInt32(
101119
output_zp=output_zp,
102120
multiplier=[multiplier],
103121
shift=[shift],
104-
scale32=True,
105-
double_round=True,
122+
scale32=is_scale32,
123+
double_round=is_double_round,
106124
per_channel=False,
107125
)
108126

0 commit comments

Comments
 (0)