Skip to content

Quantization types #4094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm.operators.op_placeholder import process_placeholder
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import is_quant_node
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.backends.arm.tosa_quant_utils import get_quant_node_dtype, is_quant_node
from executorch.backends.arm.tosa_utils import (
dbg_fail,
dbg_tosa_dump,
Expand Down Expand Up @@ -280,7 +280,11 @@ def preprocess( # noqa: C901
if is_permute_node_before_addmm(node)
else output.shape
),
ts.DType.INT8 if is_quant_node(node) else output.dtype,
(
map_dtype(get_quant_node_dtype(node))
if is_quant_node(node)
else output.dtype
),
)

# Visiting each Node
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/operators/op_placeholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
get_quant_arg_dtype,
get_quant_node_args,
is_quant_arg,
q_op,
Expand Down Expand Up @@ -166,7 +167,7 @@ def process_placeholder(
tensor = ts.TosaSerializerTensor(
inputs[0].name,
input_shape,
ts.DType.INT8 if is_quant_arg(node) else inputs[0].dtype,
get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype,
data=None,
placeholderFilename=inputs[0].name + ".npy",
)
Expand Down
37 changes: 29 additions & 8 deletions backends/arm/test/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,24 @@


class QuantizationParams:
    """Quantization parameters recorded from a quantize/dequantize node.

    The fields mirror the positional args of the q/dq ops for a single
    tensor: (scale, zero-point, qmin, qmax, dtype).
    """

    # NOTE: the diff residue previously left the old 3-field __slots__ and
    # the old __init__ signature interleaved here; this is the post-change
    # definition only.
    __slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]

    # TODO: zps and scales can be per tensor or per channel => a list??
    def __init__(
        self,
        node_name: str,
        zp: int,
        scale: float,
        qmin: int,
        qmax: int,
        dtype: torch.dtype,
    ):
        # Name of the node these params belong to; kept so callers can
        # error-check that params and tensors are matched up correctly.
        self.node_name = node_name
        self.zp = zp          # zero point
        self.scale = scale    # quantization scale
        self.qmin = qmin      # minimum representable quantized value
        self.qmax = qmax      # maximum representable quantized value
        self.dtype = dtype    # torch dtype of the quantized tensor


def _get_input_names(program: ExportedProgram) -> list[str]:
Expand Down Expand Up @@ -74,7 +85,12 @@ def _get_input_quantization_params(
and node.args[0].name in input_names
):
qp = QuantizationParams(
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
node_name=node.args[0].name,
scale=node.args[1],
zp=node.args[2],
qmin=node.args[3],
qmax=node.args[4],
dtype=node.args[5],
)
quant_params.append(qp)
if (
Expand Down Expand Up @@ -122,7 +138,12 @@ def _get_output_quantization_params(
and node == output_node.args[0][0]
):
quant_params = QuantizationParams(
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
node_name=node.args[0].name,
scale=node.args[1],
zp=node.args[2],
qmin=node.args[3],
qmax=node.args[4],
dtype=node.args[5],
)
break # break early, there's only one output node
if quant_params is None:
Expand Down Expand Up @@ -376,13 +397,13 @@ def prep_data_for_save(
assert (
quant_param.node_name == input_name
), "These quantization params do not match the input tensor name"
int8_max = np.iinfo(np.int8).max
int8_min = np.iinfo(np.int8).min
data_np = (
((data_np / np.float32(quant_param.scale)) + quant_param.zp)
.round()
.clip(int8_min, int8_max)
.astype(np.int8)
.clip(quant_param.qmin, quant_param.qmax)
.astype(
f"{quant_param.dtype}".replace("torch.", "")
) # Use string format of dtype to convert to numpy dtype
)
return data_np

Expand Down
32 changes: 31 additions & 1 deletion backends/arm/tosa_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import serializer.tosa_serializer as ts
import torch.fx
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaOp, TosaSerializerTensor

Expand Down Expand Up @@ -45,11 +45,41 @@ def is_quant_node(node: torch.fx.Node):
)


def get_quant_node_dtype(node: torch.fx.Node):
    """Return the torch dtype a node carries after quantization.

    Resolution order:
      1. A TOSA node already has its dtype in ``node.meta["val"]``.
      2. A quantize/dequantize op stores the dtype as positional arg 5.
      3. Otherwise, walk forward through first consumers until a q/dq op
         is found and take its dtype arg.

    Raises:
        RuntimeError: if no q/dq op is reachable from ``node``.
    """
    if "tosa" in node.target.__name__:
        return node.meta["val"].dtype

    if node.target in dq_q_ops:
        return node.args[5]

    # Not a TOSA node nor a q/dq op: walk the graph until we find a q op.
    # NOTE(review): only the FIRST user is followed at each step — this
    # assumes a single-consumer chain down to the quantize node; confirm
    # against the graphs this backend produces.
    consumer_node = node
    while True:
        users = list(consumer_node.users)
        if not users:
            # Previously the very first hop used list(node.users)[0] and
            # would raise a bare IndexError for a node with no users;
            # raise the descriptive error on every dead end instead.
            raise RuntimeError("No quantized node found in graph")
        consumer_node = users[0]
        if consumer_node.target in dq_q_ops:
            return consumer_node.args[5]


def is_quant_arg(arg):
    """Return True if ``arg``'s first consumer is the quantize op."""
    first_user = list(arg.users)[0]
    return first_user.target == q_op


def get_quant_arg_dtype(node: torch.fx.Node):
    """Map the quantized dtype of ``node`` (an argument feeding a quantize
    op) to the corresponding TOSA dtype.

    The dtype is read from positional arg 5 of the consuming quantize op;
    note the args differ between per_tensor and per_channel variants.

    Raises:
        RuntimeError: if the first consumer of ``node`` is not the
            quantize op (i.e. no quantization argument can be found).
    """
    consumer_node = list(node.users)[0]

    # The original nested is_quant_arg(node) re-checked this exact
    # condition, making its else-branch dead code, and the non-q_op path
    # fell through returning None implicitly. Fail fast instead so a
    # missing quantize op surfaces here rather than as a bad dtype later.
    if consumer_node.target != q_op:
        raise RuntimeError("Quantization argument not found")
    return map_dtype(consumer_node.args[5])


def get_quant_node_args(node: torch.fx.Node):
"""
Get the quantization parameters from a quant node.
Expand Down
Loading