Skip to content

Commit 074a81e

Browse files
per authored and facebook-github-bot committed
Quantization types (#4094)
Summary: Pull Request resolved: #4094 Reviewed By: mergennachin Differential Revision: D59259075 Pulled By: digantdesai fbshipit-source-id: ca0c12684b47755796c95be83a5b901c86392ec2
1 parent 19ed018 commit 074a81e

File tree

4 files changed

+69
-13
lines changed

4 files changed

+69
-13
lines changed

backends/arm/arm_backend.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from executorch.backends.arm.arm_vela import vela_compile
1818
from executorch.backends.arm.operators.node_visitor import get_node_visitors
1919
from executorch.backends.arm.operators.op_placeholder import process_placeholder
20-
from executorch.backends.arm.tosa_mapping import TosaArg
21-
from executorch.backends.arm.tosa_quant_utils import is_quant_node
20+
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
21+
from executorch.backends.arm.tosa_quant_utils import get_quant_node_dtype, is_quant_node
2222
from executorch.backends.arm.tosa_utils import (
2323
dbg_fail,
2424
dbg_tosa_dump,
@@ -280,7 +280,11 @@ def preprocess( # noqa: C901
280280
if is_permute_node_before_addmm(node)
281281
else output.shape
282282
),
283-
ts.DType.INT8 if is_quant_node(node) else output.dtype,
283+
(
284+
map_dtype(get_quant_node_dtype(node))
285+
if is_quant_node(node)
286+
else output.dtype
287+
),
284288
)
285289

286290
# Visiting each Node

backends/arm/operators/op_placeholder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from executorch.backends.arm.tosa_mapping import TosaArg
1010
from executorch.backends.arm.tosa_quant_utils import (
11+
get_quant_arg_dtype,
1112
get_quant_node_args,
1213
is_quant_arg,
1314
q_op,
@@ -166,7 +167,7 @@ def process_placeholder(
166167
tensor = ts.TosaSerializerTensor(
167168
inputs[0].name,
168169
input_shape,
169-
ts.DType.INT8 if is_quant_arg(node) else inputs[0].dtype,
170+
get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype,
170171
data=None,
171172
placeholderFilename=inputs[0].name + ".npy",
172173
)

backends/arm/test/runner_utils.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,24 @@
2323

2424

2525
class QuantizationParams:
26-
__slots__ = ["node_name", "zp", "scale"]
26+
__slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]
2727

2828
# todo: zps and scales can be per tensors or per channel => a list??
29-
def __init__(self, node_name: str, zp: int, scale: float):
29+
def __init__(
30+
self,
31+
node_name: str,
32+
zp: int,
33+
scale: float,
34+
qmin: int,
35+
qmax: int,
36+
dtype: torch.dtype,
37+
):
3038
self.node_name = node_name # not need I think, but good for error check
3139
self.zp = zp
3240
self.scale = scale
41+
self.qmin = qmin
42+
self.qmax = qmax
43+
self.dtype = dtype
3344

3445

3546
def _get_input_names(program: ExportedProgram) -> list[str]:
@@ -74,7 +85,12 @@ def _get_input_quantization_params(
7485
and node.args[0].name in input_names
7586
):
7687
qp = QuantizationParams(
77-
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
88+
node_name=node.args[0].name,
89+
scale=node.args[1],
90+
zp=node.args[2],
91+
qmin=node.args[3],
92+
qmax=node.args[4],
93+
dtype=node.args[5],
7894
)
7995
quant_params.append(qp)
8096
if (
@@ -122,7 +138,12 @@ def _get_output_quantization_params(
122138
and node == output_node.args[0][0]
123139
):
124140
quant_params = QuantizationParams(
125-
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
141+
node_name=node.args[0].name,
142+
scale=node.args[1],
143+
zp=node.args[2],
144+
qmin=node.args[3],
145+
qmax=node.args[4],
146+
dtype=node.args[5],
126147
)
127148
break # break early, there's only one output node
128149
if quant_params is None:
@@ -376,13 +397,13 @@ def prep_data_for_save(
376397
assert (
377398
quant_param.node_name == input_name
378399
), "These quantization params do not match the input tensor name"
379-
int8_max = np.iinfo(np.int8).max
380-
int8_min = np.iinfo(np.int8).min
381400
data_np = (
382401
((data_np / np.float32(quant_param.scale)) + quant_param.zp)
383402
.round()
384-
.clip(int8_min, int8_max)
385-
.astype(np.int8)
403+
.clip(quant_param.qmin, quant_param.qmax)
404+
.astype(
405+
f"{quant_param.dtype}".replace("torch.", "")
406+
) # Use string format of dtype to convert to numpy dtype
386407
)
387408
return data_np
388409

backends/arm/tosa_quant_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import serializer.tosa_serializer as ts
1212
import torch.fx
13-
from executorch.backends.arm.tosa_mapping import TosaArg
13+
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
1414
from executorch.exir.dialects._ops import ops as exir_ops
1515
from serializer.tosa_serializer import TosaOp, TosaSerializerTensor
1616

@@ -45,11 +45,41 @@ def is_quant_node(node: torch.fx.Node):
4545
)
4646

4747

48+
def get_quant_node_dtype(node: torch.fx.Node):
    """Return the dtype associated with a quantized node.

    Resolution order:
      1. If the target's ``__name__`` contains "tosa", the dtype recorded in
         ``node.meta["val"]`` is returned.
      2. If the node itself is one of ``dq_q_ops``, its sixth positional
         argument (``args[5]``) is returned.
      3. Otherwise the graph is walked forward, always following the first
         user, until a node in ``dq_q_ops`` is found; that node's
         ``args[5]`` is returned.

    Raises:
        RuntimeError: if the walk runs out of users before reaching a
            q/dq op. This includes the case where ``node`` has no users at
            all, which previously surfaced as a bare ``IndexError``.
    """
    if "tosa" in node.target.__name__:
        return node.meta["val"].dtype

    if node.target in dq_q_ops:
        return node.args[5]

    # Not a tosa node, nor a q/dq op: walk the graph until we find a q op.
    # Only the first user of each node is followed.
    current = node
    while current.users:
        current = next(iter(current.users))
        if current.target in dq_q_ops:
            return current.args[5]

    raise RuntimeError("No quantized node found in graph")
65+
66+
4867
def is_quant_arg(arg):
    """Return True when the first user of ``arg`` is the quantize op ``q_op``.

    An argument with no users is reported as not quantized (False) rather
    than raising ``IndexError``.
    """
    consumer_node = next(iter(arg.users), None)
    return consumer_node is not None and consumer_node.target == q_op
5170

5271

72+
def get_quant_arg_dtype(node: torch.fx.Node):
    """Return the mapped dtype of the quantize op that consumes ``node``.

    The first user of ``node`` must be ``q_op``; its sixth positional
    argument (``args[5]``) is passed through ``map_dtype`` and returned.

    Raises:
        RuntimeError: if ``node`` has no users or its first user is not
            ``q_op``. Previously these cases raised ``IndexError`` or
            silently returned ``None``.
    """
    consumer_node = next(iter(node.users), None)

    # NOTE(review): the original comment says args differ between per_tensor
    # and per_channel quant ops; only q_op is handled here - confirm whether
    # per_channel needs its own index.
    if consumer_node is None or consumer_node.target != q_op:
        raise RuntimeError("Quantization argument not found")

    return map_dtype(consumer_node.args[5])
81+
82+
5383
def get_quant_node_args(node: torch.fx.Node):
5484
"""
5585
Get the quantization parameters from a quant node.

0 commit comments

Comments
 (0)