Skip to content

Commit 17fee78

Browse files
tatwaichong authored and facebook-github-bot committed
Add buildRescale helper function with scale width and double rounding option (#567)
Summary: These scaling options can be used by lowering of convolution. Pull Request resolved: #567 Reviewed By: larryliu0820 Differential Revision: D49897715 Pulled By: digantdesai fbshipit-source-id: f833aab567bdc4d227231f5845c4f0ff42bf2be1
1 parent bebff52 commit 17fee78

File tree

1 file changed

+65
-14
lines changed

1 file changed

+65
-14
lines changed

backends/arm/tosa_quant_utils.py

Lines changed: 65 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -32,30 +32,42 @@ def isQuantArg(arg):
3232
return consumer_node.target == q_op
3333

3434

35+
# Check if scale32 mode is used for given output element type
36+
def isScale32(type):
37+
return type == ts.DType.INT8
38+
39+
3540
# TOSA uses the RESCALE operation to scale between values with differing precision.
3641
# The RESCALE operator is defined using an integer multiply, add, and shift.
3742
# This utility function is for calculating the multiplier and shift given a scale.
3843
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
39-
def computeMultiplierAndShift(scale):
44+
def computeMultiplierAndShift(scale, scaleWidth=32):
45+
if scaleWidth == 16:
46+
offset = 15
47+
elif scaleWidth == 32:
48+
offset = 31
49+
else:
50+
raise AssertionError("unsupported scale width")
51+
4052
assert isinstance(scale, float)
4153

4254
mantissa, exponent = math.frexp(scale)
4355
shift = exponent
4456

45-
const_two_to_31 = 1 << 31
46-
shifted_mantissa = round(mantissa * const_two_to_31)
57+
const_2_power_15_or_31 = 1 << offset
58+
shifted_mantissa = round(mantissa * const_2_power_15_or_31)
4759

48-
assert shifted_mantissa <= const_two_to_31
60+
assert shifted_mantissa <= const_2_power_15_or_31
4961

50-
if shifted_mantissa == const_two_to_31:
62+
if shifted_mantissa == const_2_power_15_or_31:
5163
shifted_mantissa = shifted_mantissa / 2
5264
shift += 1
5365

54-
# TOSA expects right shift to be positive, and embed (1 << 31) into right shift bits.
55-
shift = 31 - shift
66+
# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
67+
shift = offset - shift
5668

5769
# INT32_MAX, 2^31 - 1
58-
assert shifted_mantissa <= (const_two_to_31 - 1)
70+
assert shifted_mantissa <= (const_2_power_15_or_31 - 1)
5971

6072
multiplier = shifted_mantissa
6173

@@ -65,8 +77,41 @@ def computeMultiplierAndShift(scale):
6577
return multiplier, shift
6678

6779

80+
def buildRescale(
81+
tosa_fb,
82+
scale,
83+
input_node,
84+
output_type,
85+
output_shape,
86+
input_zp,
87+
output_zp,
88+
is_double_round,
89+
):
90+
is_scale32 = isScale32(output_type)
91+
scale_width = 32 if is_scale32 else 16
92+
multiplier, shift = computeMultiplierAndShift(scale, scale_width)
93+
94+
attr_rescale = ts.TosaSerializerAttribute()
95+
attr_rescale.RescaleAttribute(
96+
input_zp=input_zp,
97+
output_zp=output_zp,
98+
multiplier=[multiplier],
99+
shift=[shift],
100+
scale32=is_scale32,
101+
double_round=is_double_round,
102+
per_channel=False,
103+
)
104+
105+
rescale_out = tosa_fb.addIntermediate(output_shape, output_type)
106+
tosa_fb.addOperator(
107+
TosaOp.Op().RESCALE, [input_node.name], [rescale_out.name], attr_rescale
108+
)
109+
110+
return rescale_out
111+
112+
68113
def buildRescaleToInt32(
69-
tosa_fb, input, input_zp, rescale_scale
114+
tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=True
70115
) -> TosaSerializerTensor:
71116
multiplier, shift = computeMultiplierAndShift(rescale_scale)
72117
attr_rescale = ts.TosaSerializerAttribute()
@@ -75,8 +120,8 @@ def buildRescaleToInt32(
75120
output_zp=0,
76121
multiplier=[multiplier],
77122
shift=[shift],
78-
scale32=True,
79-
double_round=True,
123+
scale32=is_scale32,
124+
double_round=is_double_round,
80125
per_channel=False,
81126
)
82127
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
@@ -91,7 +136,13 @@ def buildRescaleToInt32(
91136

92137

93138
def buildRescaleFromInt32(
94-
tosa_fb, input_name, output_name, output_zp, rescale_scale
139+
tosa_fb,
140+
input_name,
141+
output_name,
142+
output_zp,
143+
rescale_scale,
144+
is_scale32=True,
145+
is_double_round=True,
95146
) -> TosaSerializerTensor:
96147
multiplier, shift = computeMultiplierAndShift(rescale_scale)
97148
attr_rescale_output = ts.TosaSerializerAttribute()
@@ -100,8 +151,8 @@ def buildRescaleFromInt32(
100151
output_zp=output_zp,
101152
multiplier=[multiplier],
102153
shift=[shift],
103-
scale32=True,
104-
double_round=True,
154+
scale32=is_scale32,
155+
double_round=is_double_round,
105156
per_channel=False,
106157
)
107158

0 commit comments

Comments
 (0)