Skip to content

Commit 093e735

Browse files
committed
Expand QuantizationParams to include dtype and limits
In order to support quantization to different types then int8 keep track of the limits and type as well. Signed-off-by: Per Åstrand <[email protected]> Change-Id: Ia5861adfeff4d57676ff06ccf5a7a8213c34efe6
1 parent f32d707 commit 093e735

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

backends/arm/test/runner_utils.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,24 @@
2323

2424

2525
class QuantizationParams:
26-
__slots__ = ["node_name", "zp", "scale"]
26+
__slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]
2727

2828
# todo: zps and scales can be per tensors or per channel => a list??
29-
def __init__(self, node_name: str, zp: int, scale: float):
29+
def __init__(
30+
self,
31+
node_name: str,
32+
zp: int,
33+
scale: float,
34+
qmin: int,
35+
qmax: int,
36+
dtype: torch.dtype,
37+
):
3038
self.node_name = node_name # not need I think, but good for error check
3139
self.zp = zp
3240
self.scale = scale
41+
self.qmin = qmin
42+
self.qmax = qmax
43+
self.dtype = dtype
3344

3445

3546
def _get_input_names(program: ExportedProgram) -> list[str]:
@@ -74,7 +85,12 @@ def _get_input_quantization_params(
7485
and node.args[0].name in input_names
7586
):
7687
qp = QuantizationParams(
77-
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
88+
node_name=node.args[0].name,
89+
scale=node.args[1],
90+
zp=node.args[2],
91+
qmin=node.args[3],
92+
qmax=node.args[4],
93+
dtype=node.args[5],
7894
)
7995
quant_params.append(qp)
8096
if (
@@ -122,7 +138,12 @@ def _get_output_quantization_params(
122138
and node == output_node.args[0][0]
123139
):
124140
quant_params = QuantizationParams(
125-
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
141+
node_name=node.args[0].name,
142+
scale=node.args[1],
143+
zp=node.args[2],
144+
qmin=node.args[3],
145+
qmax=node.args[4],
146+
dtype=node.args[5],
126147
)
127148
break # break early, there's only one output node
128149
if quant_params is None:
@@ -376,13 +397,13 @@ def prep_data_for_save(
376397
assert (
377398
quant_param.node_name == input_name
378399
), "These quantization params do not match the input tensor name"
379-
int8_max = np.iinfo(np.int8).max
380-
int8_min = np.iinfo(np.int8).min
381400
data_np = (
382401
((data_np / np.float32(quant_param.scale)) + quant_param.zp)
383402
.round()
384-
.clip(int8_min, int8_max)
385-
.astype(np.int8)
403+
.clip(quant_param.qmin, quant_param.qmax)
404+
.astype(
405+
f"{quant_param.dtype}".replace("torch.", "")
406+
) # Use string format of dtype to convert to numpy dtype
386407
)
387408
return data_np
388409

0 commit comments

Comments
 (0)