Skip to content

Add buildRescale helper function with scale width and double rounding option #567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 65 additions & 14 deletions backends/arm/tosa_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,42 @@ def isQuantArg(arg):
)


def isScale32(type):
    """Check if scale32 mode is used for the given output element type.

    Returns the result of comparing ``type`` against ``ts.DType.INT8``;
    any other element type yields a falsy result.
    """
    uses_scale32 = type == ts.DType.INT8
    return uses_scale32


# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multiplier and shift given a scale.
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
def computeMultiplierAndShift(scale):
def computeMultiplierAndShift(scale, scaleWidth=32):
if scaleWidth == 16:
offset = 15
elif scaleWidth == 32:
offset = 31
else:
raise AssertionError("unsupported scale width")

assert isinstance(scale, float)

mantissa, exponent = math.frexp(scale)
shift = exponent

const_two_to_31 = 1 << 31
shifted_mantissa = round(mantissa * const_two_to_31)
const_2_power_15_or_31 = 1 << offset
shifted_mantissa = round(mantissa * const_2_power_15_or_31)

assert shifted_mantissa <= const_two_to_31
assert shifted_mantissa <= const_2_power_15_or_31

if shifted_mantissa == const_two_to_31:
if shifted_mantissa == const_2_power_15_or_31:
shifted_mantissa = shifted_mantissa / 2
shift += 1

# TOSA expects right shift to be positive, and embed (1 << 31) into right shift bits.
shift = 31 - shift
# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
shift = offset - shift

# INT32_MAX, 2^31 - 1
assert shifted_mantissa <= (const_two_to_31 - 1)
assert shifted_mantissa <= (const_2_power_15_or_31 - 1)

multiplier = shifted_mantissa

Expand All @@ -66,8 +78,41 @@ def computeMultiplierAndShift(scale):
return multiplier, shift


def buildRescale(
    tosa_fb,
    scale,
    input_node,
    output_type,
    output_shape,
    input_zp,
    output_zp,
    is_double_round,
):
    """Add one RESCALE operator to ``tosa_fb`` and return its output tensor.

    The scale width is chosen from ``output_type``: 32 bits when
    ``isScale32(output_type)`` is truthy, otherwise 16 bits. The multiplier
    and shift are then derived from ``scale`` with that width via
    ``computeMultiplierAndShift``.

    Args:
        tosa_fb: Graph builder; must provide ``addIntermediate`` and
            ``addOperator``.
        scale: Scale passed on to ``computeMultiplierAndShift``.
        input_node: Object whose ``name`` is used as the operator input.
        output_type: Element type of the intermediate output tensor.
        output_shape: Shape of the intermediate output tensor.
        input_zp: Input zero point stored in the rescale attribute.
        output_zp: Output zero point stored in the rescale attribute.
        is_double_round: Value forwarded as the ``double_round`` attribute.

    Returns:
        The intermediate tensor created by ``tosa_fb.addIntermediate`` that
        receives the RESCALE result.
    """
    use_scale32 = isScale32(output_type)
    if use_scale32:
        width = 32
    else:
        width = 16
    mult, right_shift = computeMultiplierAndShift(scale, width)

    # Rescaling here is always per-tensor: a single multiplier/shift pair.
    rescale_attr = ts.TosaSerializerAttribute()
    rescale_attr.RescaleAttribute(
        input_zp=input_zp,
        output_zp=output_zp,
        multiplier=[mult],
        shift=[right_shift],
        scale32=use_scale32,
        double_round=is_double_round,
        per_channel=False,
    )

    result = tosa_fb.addIntermediate(output_shape, output_type)
    tosa_fb.addOperator(
        TosaOp.Op().RESCALE,
        [input_node.name],
        [result.name],
        rescale_attr,
    )
    return result


def buildRescaleToInt32(
tosa_fb, input, input_zp, rescale_scale
tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=True
) -> TosaSerializerTensor:
multiplier, shift = computeMultiplierAndShift(rescale_scale)
attr_rescale = ts.TosaSerializerAttribute()
Expand All @@ -76,8 +121,8 @@ def buildRescaleToInt32(
output_zp=0,
multiplier=[multiplier],
shift=[shift],
scale32=True,
double_round=True,
scale32=is_scale32,
double_round=is_double_round,
per_channel=False,
)
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
Expand All @@ -92,7 +137,13 @@ def buildRescaleToInt32(


def buildRescaleFromInt32(
tosa_fb, input_name, output_name, output_zp, rescale_scale
tosa_fb,
input_name,
output_name,
output_zp,
rescale_scale,
is_scale32=True,
is_double_round=True,
) -> TosaSerializerTensor:
multiplier, shift = computeMultiplierAndShift(rescale_scale)
attr_rescale_output = ts.TosaSerializerAttribute()
Expand All @@ -101,8 +152,8 @@ def buildRescaleFromInt32(
output_zp=output_zp,
multiplier=[multiplier],
shift=[shift],
scale32=True,
double_round=True,
scale32=is_scale32,
double_round=is_double_round,
per_channel=False,
)

Expand Down