Arm backend: support int16 and int32 output tables. #9359


Merged · 1 commit · Mar 21, 2025
12 changes: 6 additions & 6 deletions backends/arm/_passes/insert_rescales_pass.py

@@ -38,17 +38,17 @@ def rescale_fake(
     """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
     Additionally validates TOSA constraints of a RESCALE op.
     """
-    if not (dtype == torch.int32 or dtype == torch.int8):
+    if dtype not in (torch.int32, torch.int8, torch.int16):
         raise NotImplementedError(
-            "tosa::rescale currently only supports int32 and int8."
+            f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}"
         )
-    if dtype == torch.int32 and out_zp != 0:
+    if dtype in (torch.int32, torch.int16) and out_zp != 0:
         raise ValueError(
-            "TOSA requires output_zp to be zero when the output dtype is int32."
+            f"TOSA requires output_zp to be zero when the output dtype is {dtype}."
         )
-    if x.dtype == torch.int32 and in_zp != 0:
+    if x.dtype in (torch.int32, torch.int16) and in_zp != 0:
         raise ValueError(
-            "TOSA requires input_zp to be zero when the input dtype is int32."
+            f"TOSA requires input_zp to be zero when the input dtype is {x.dtype}."
         )
     if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
         raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
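For intuition about what these constraints protect: a TOSA RESCALE computes roughly out = round((in - input_zp) * scale) + output_zp, and the spec only permits a nonzero zero-point on int8 tensors. A minimal float-reference sketch of that behavior (illustrative only, not the backend's kernel; the real op uses the integer multiplier-and-shift form serialized in op_rescale.py below):

import torch

def rescale_ref(x: torch.Tensor, out_dtype: torch.dtype, scale: float,
                in_zp: int, out_zp: int) -> torch.Tensor:
    # Float reference of RESCALE: remove the input zero-point, scale,
    # re-apply the output zero-point, then saturate to the output dtype.
    info = torch.iinfo(out_dtype)
    y = torch.round((x.to(torch.int64) - in_zp) * scale) + out_zp
    return y.clamp(info.min, info.max).to(out_dtype)

# int8 -> int16 widening: only the int8 side may carry a zero-point.
x = torch.tensor([-128, 0, 127], dtype=torch.int8)
print(rescale_ref(x, torch.int16, scale=256.0, in_zp=0, out_zp=0))
# tensor([-32768, 0, 32512], dtype=torch.int16)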
121 changes: 106 additions & 15 deletions backends/arm/_passes/insert_table_ops.py
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -18,6 +17,7 @@

 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule
+
 from torch.library import impl, Library

 lib = Library("tosa", "DEF")
@@ -26,7 +26,10 @@

 @impl(lib, "_table")
 def _table_impl(*args, **kwargs):  # pyre-ignore
-    return args[0]
+    in_dtype = args[0].dtype
+    if in_dtype == torch.int8:
+        return args[0]
+    return args[0].to(dtype=torch.int32)


 class InsertTableOpsPass(ExportPass):
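The updated fake kernel only has to produce the right tensor meta for tracing: int8 tables stay int8, while int16 tables produce a 32-bit result, mirroring TOSA.TABLE. A self-contained sketch of the same torch.library registration pattern (the demo namespace and op name are invented for illustration):

import torch
from torch.library import impl, Library

demo_lib = Library("demo", "DEF")
demo_lib.define("widen_table(Tensor x) -> Tensor")

@impl(demo_lib, "widen_table")
def widen_table_impl(x):
    # Match the real kernel's dtype promotion so FX tracing records
    # the correct tensor meta downstream of the op.
    return x if x.dtype == torch.int8 else x.to(dtype=torch.int32)

print(torch.ops.demo.widen_table(torch.ones(4, dtype=torch.int16)).dtype)
# torch.int32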
@@ -59,29 +62,105 @@ def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
         """
         self.exported_program.state_dict[buffer_name] = buffer

-    def generate_table_values(
+    def generate_8bit_table_values(
         self,
         torch_op: Callable[[torch.Tensor], torch.Tensor],
         in_quantargs: QuantArgs,
         out_quantargs: QuantArgs,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT8 TOSA.TABLE. Also returns 0 since no shifting is required after an 8-bit table.
+        The INT8 table is a simple 256-value one-to-one LUT.
+        """

         def f(x: torch.Tensor) -> torch.Tensor:
             x = in_quantargs.dequantize_value(x)
             x = torch_op(x)
             return out_quantargs.quantize_value(x)

-        input_dtype = in_quantargs.dtype
-        steps = in_quantargs.qmax - in_quantargs.qmin + 1
-        return f(
+        return (
+            f(
+                torch.linspace(
+                    start=in_quantargs.qmin,
+                    end=in_quantargs.qmax,
+                    steps=256,
+                    # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
+                    # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
+                    dtype=torch.int64,
+                )
+            ).to(dtype=torch.int8),
+            0,
+        )
+
+    def generate_16_bit_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT16 TOSA.TABLE with 32-bit output.
+        In practice the output is 23 bits that should be interpreted as 16 'whole' bits and 7 fractional bits; see
+        the specification: https://www.mlplatform.org/tosa/tosa_spec.html#_table. This means that the output
+        will be interpreted as 2**7 = 128 times too large unless accounted for by rescaling down the table output.
+
+        Quantization can be either int16 or int32, which means that the op output could be larger than the 23 bits from
+        the TOSA.TABLE output. In that case, we need to rescale up the output.
+
+        To handle this we need to:
+        1) Make sure that our table values fit within 16 bits.
+        2) Insert a rescale after the table to handle the x128 from the fractional bits and match the quantization.
+
+        The function returns rescale_lshift, which says how much to rescale after the table. This value can be negative.
+        """
+
+        def f(x: torch.Tensor) -> torch.Tensor:
+            # Don't use the 7 LSBs.
+            x = in_quantargs.dequantize_value((x & ~0x7F))
+            x = torch_op(x)
+            return out_quantargs.quantize_value(x)
+
+        lut_values = f(
             torch.linspace(
                 start=in_quantargs.qmin,
-                end=in_quantargs.qmax,
-                steps=steps,
+                end=in_quantargs.qmax + 1,
+                steps=513,
                 # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
                 # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
                 dtype=torch.int64,
             )
-        ).to(dtype=input_dtype)
+        )
+        # Calculate how much we need to shift table values to fit in 16 signed bits:
+        # ceil(log2(max absolute table value)) + 1 bit for signedness - 16
+        # Example:
+        # The max value in the table is 70 000 and we want to fit it in 16 signed bits.
+        # 70 000 = 0b10001000101110000 (17 digits) has ceil(log2(70 000)) = ceil(16.095) = 17 bits.
+        # If we shift it 17 - 16 = 1 bit, we do get 16 bits (0b1000100010111000),
+        # but due to signedness this is a negative number! So we need to shift it one more bit.
+        # Note: for out_quantargs.dtype == torch.int16, rshift == 0 and rescale_lshift == -7.
+        rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
+        # The 7 fractional bits are equivalent to an lshift of 7, so subtract 7 from the lshift we do.
+        rescale_lshift = rshift - 7
+        lut_values = lut_values >> rshift
+        return lut_values.to(dtype=torch.int16), rescale_lshift
+
+    def generate_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        match out_quantargs.dtype:
+            case torch.int8:
+                return self.generate_8bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case torch.int16 | torch.int32:
+                return self.generate_16_bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case _:
+                raise ValueError(
+                    f"Unsupported output dtype for table: {out_quantargs.dtype}"
+                )

     def call(self, graph_module: GraphModule) -> PassResult:
         modified = False
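To make the shift arithmetic concrete, here is the 70 000 example from the comments above worked end to end (a standalone sketch with illustrative values):

import torch

lut_values = torch.tensor([70_000, -3_000, 12_345])        # int32-domain quantized outputs
rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
# ceil(log2(70 000)) = 17, so rshift = 17 + 1 - 16 = 2: one extra bit for the sign.
rescale_lshift = rshift - 7                                 # 2 - 7 = -5
lut_values = lut_values >> rshift                           # max is now 17 500, fits int16
# After TOSA.TABLE multiplies by 2**7, a rescale of 2**-5 restores the
# original magnitude: 17 500 * 128 / 32 = 70 000.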
@@ -100,10 +179,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 op_target=torch.ops.tosa._table.default,
                 args=(node.args[0],),
             )
+            output_node = table_node
             assert len(input_qparams) == 1
             assert len(output_qparams) == 1
-            # Generate table buffer
-            buffer = self.generate_table_values(
+
+            # Generate table buffer and how much to lshift the table output.
+            buffer, lshift = self.generate_table_values(
                 torch_op=self.table_ops[node.target],
                 in_quantargs=input_qparams[0],
                 out_quantargs=output_qparams[0],
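For the int8 path, this buffer generation amounts to evaluating the float op at all 256 quantized input points. A standalone sketch for a sigmoid table (the quantization parameters are made-up examples, and QuantArgs is replaced by inline math):

import torch

in_scale, in_zp = 0.05, 0            # assumed example qparams
out_scale, out_zp = 1.0 / 256, -128

def f(q: torch.Tensor) -> torch.Tensor:
    x = (q - in_zp) * in_scale                     # dequantize
    y = torch.sigmoid(x)                           # the float op
    q_out = torch.round(y / out_scale) + out_zp    # quantize
    return q_out.clamp(-128, 127)

lut = f(torch.linspace(-128, 127, steps=256, dtype=torch.int64)).to(torch.int8)
print(lut.shape)  # torch.Size([256]); one entry per int8 input value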
@@ -114,10 +195,20 @@ def call(self, graph_module: GraphModule) -> PassResult:
             self.register_buffer(
                 buffer_name=table_node.name.replace("_default", ""), buffer=buffer
             )
-            node.replace_all_uses_with(table_node)
+
+            if lshift != 0:
+                scale = 2.0**lshift
+                rescale_node = create_node(
+                    graph=graph_module.graph,
+                    op_target=torch.ops.tosa._rescale.default,
+                    args=(table_node, output_qparams[0].dtype, scale, 0, 0),
+                )
+                output_node = rescale_node
+
+            node.replace_all_uses_with(output_node)
             graph_module.graph.erase_node(node)
-            table_node.meta["input_qparams"] = input_qparams
-            table_node.meta["output_qparams"] = output_qparams
+            output_node.meta["input_qparams"] = input_qparams
+            output_node.meta["output_qparams"] = output_qparams
             modified = True

         if modified:
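Numerically, that inserted _rescale is what reconciles the table's fixed-point output with the requested quantization. A quick sanity check using the rshift = 2 example from above (illustrative numbers):

rshift = 2                    # table values were pre-shifted right by 2 bits
lshift = rshift - 7           # -5, as returned by generate_table_values
scale = 2.0**lshift           # 0.03125, handed to tosa._rescale
table_gain = 2**7             # TOSA int16 TABLE output carries 7 fractional bits

v = 70_000                    # a value the int32 quantization wants to represent
stored = v >> rshift          # 17 500, what actually sits in the LUT
assert stored * table_gain * scale == v   # 17 500 * 128 * 0.03125 == 70 000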
2 changes: 1 addition & 1 deletion backends/arm/operators/node_visitor.py
@@ -30,7 +30,7 @@ class NodeVisitor:
     ]

     def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
-        self._exported_program = exported_program or None
+        self._exported_program = exported_program
         self.tosa_spec = tosa_spec

     def define_node(
8 changes: 5 additions & 3 deletions backends/arm/operators/op_rescale.py
@@ -38,7 +38,6 @@ def define_node(
         input_zp = cast(int, node.args[3])
         output_zp = cast(int, node.args[4])

-        # Skip int16 cases for now.
         if input_dtype != map_dtype(torch.int8) and input_zp != 0:
             raise ValueError(
                 f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
@@ -48,7 +47,10 @@ def define_node(
                 f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
             )

-        scale_width = 32 if output_dtype == torch.int32 else 16
+        # scale32 gives higher accuracy but at a higher HW cost.
+        # For now, always go for scale32.
+        scale_32 = True
+        scale_width = 32 if scale_32 else 16
         multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
             [scale], scale_width
         )
@@ -58,7 +60,7 @@ def define_node(
             output_zp=output_zp,
             multiplier=multiplier,
             shift=shift,
-            scale32=output_dtype == torch.int32,
+            scale32=scale_32,
             double_round=False,
             per_channel=False,
             input_unsigned=False,
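compute_multiplier_and_shift converts the float scale into the integer multiplier/shift pair RESCALE executes; scale32 just makes the multiplier 32 bits wide rather than 16, buying precision at hardware cost. The standard construction looks roughly like the following (a sketch of the idea under one common shift convention, not the backend's implementation):

def to_fixed_point(scale: float, width: int = 32) -> tuple[int, int]:
    assert scale > 0.0
    shift = 0
    # Normalize scale into [0.5, 1.0) so the multiplier uses the full width.
    while scale < 0.5:
        scale *= 2.0
        shift += 1
    while scale >= 1.0:
        scale /= 2.0
        shift -= 1
    multiplier = round(scale * (1 << (width - 1)))
    # The op then computes roughly: out ~= (x * multiplier) >> (width - 1 + shift)
    return multiplier, shift

m, s = to_fixed_point(2.0**-5)   # m = 2**30, s = 4: (x * 2**30) >> 35 == x >> 5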
17 changes: 15 additions & 2 deletions backends/arm/operators/op_table.py
@@ -30,11 +30,24 @@ def define_node(
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
-        assert node.name in self._exported_program.state_dict.keys()  # type: ignore[union-attr]
-        assert inputs[0].dtype == output.dtype == ts.DType.INT8
+        if node.name not in self._exported_program.state_dict.keys():  # type: ignore[union-attr]
+            raise RuntimeError(
+                f"Did not find key {node.name} in state_dict {self._exported_program.state_dict.keys()}."
+            )
+        if inputs[0].dtype == ts.DType.INT8 and output.dtype != ts.DType.INT8:
+            raise ValueError(f"Int8 tables need int8 output, got {output.dtype=}.")
+        if inputs[0].dtype == ts.DType.INT16 and output.dtype != ts.DType.INT32:
+            raise ValueError(f"Int16 tables need int32 output, got {output.dtype=}.")
+
+        if inputs[0].dtype not in (ts.DType.INT8, ts.DType.INT16):
+            raise ValueError(
+                f"TOSA.TABLE only supports int8 or int16 inputs, got {ts.DTypeNames[inputs[0].dtype]}"
+            )

         table = self._exported_program.state_dict[node.name]  # type: ignore[union-attr]
         table_attr = ts.TosaSerializerAttribute()
         table_attr.TableAttribute(np.array(table))

         tosa_graph.addOperator(
             TosaOp.Op().TABLE, [inputs[0].name], [output.name], table_attr
         )
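For reference, the int16 TABLE semantics that motivate the 513-entry LUT and int32 output: the top 9 bits of the input select a table entry and the low 7 bits interpolate toward the next one, which is where the 23-bit, x128 output discussed in insert_table_ops.py comes from. A paraphrase of the spec's pseudocode (a sketch; the TOSA specification is normative):

def table_int16_ref(table: list[int], x: int) -> int:
    # table holds 513 int16 entries; x is an int16 value in [-32768, 32767].
    index = (x >> 7) + 256            # top 9 bits select the base entry
    frac = x & 0x7F                   # low 7 bits interpolate to the next entry
    base, nxt = table[index], table[index + 1]
    return (base << 7) + (nxt - base) * frac   # 16 'whole' + 7 fractional bits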