Skip to content

Commit 4e9e50d

Browse files
committed
Expand QuantizationParams to include dtype and limits
In order to support quantization to types other than int8, keep track of the limits and type as well. Signed-off-by: Per Åstrand <[email protected]> Change-Id: Ia5861adfeff4d57676ff06ccf5a7a8213c34efe6
1 parent 8f12da1 commit 4e9e50d

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

backends/arm/test/tester/arm_tester.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,12 @@ def _get_input_params(
7474
and node.args[0].name in input_names
7575
):
7676
qp = QuantizationParams(
77-
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
77+
node_name=node.args[0].name,
78+
scale=node.args[1],
79+
zp=node.args[2],
80+
qmin=node.args[3],
81+
qmax=node.args[4],
82+
dtype=node.args[5],
7883
)
7984
quant_params.append(qp)
8085
if len(quant_params) == len(
@@ -115,7 +120,12 @@ def _get_output_param(
115120
and node == output_node.args[0][0]
116121
):
117122
quant_params = QuantizationParams(
118-
node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
123+
node_name=node.args[0].name,
124+
scale=node.args[1],
125+
zp=node.args[2],
126+
qmin=node.args[3],
127+
qmax=node.args[4],
128+
dtype=node.args[5],
119129
)
120130
break # break early, there's only one output node
121131
assert quant_params is not None, "Quantization paramerters not found"

backends/arm/test/tosautil/tosa_test_utils.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,27 @@
2020

2121

2222
class QuantizationParams:
23-
__slots__ = ["node_name", "zp", "scale"]
23+
__slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]
2424

2525
# todo: zps and scales can be per tensors or per channel => a list??
26-
def __init__(self, node_name: str, zp: int, scale: float):
26+
def __init__(
27+
self,
28+
node_name: str,
29+
zp: int,
30+
scale: float,
31+
qmin: int,
32+
qmax: int,
33+
dtype: torch.dtype,
34+
):
2735
self.node_name = node_name # not need I think, but good for error check
2836
self.zp = zp
2937
self.scale = scale
38+
self.qmin = qmin
39+
self.qmax = qmax
40+
self.dtype = dtype
3041

3142
def __repr__(self):
32-
return f"QuantizationParams(node_name={self.node_name}, zp={self.zp}, scale={self.scale})"
43+
return f"QuantizationParams(node_name={self.node_name}, zp={self.zp}, scale={self.scale}, [{self.qmin},{self.qmax}], dtype={self.dtype})"
3344

3445

3546
"""
@@ -160,13 +171,13 @@ def run_tosa_ref_model(
160171
assert (
161172
quant_param.node_name == input_name
162173
), "These quantization params do not match the input tensor name"
163-
int8_max = np.iinfo(np.int8).max
164-
int8_min = np.iinfo(np.int8).min
165174
data_np = (
166175
((data_np / np.float32(quant_param.scale)) + quant_param.zp)
167176
.round()
168-
.clip(int8_min, int8_max)
169-
.astype(np.int8)
177+
.clip(quant_param.qmin, quant_param.qmax)
178+
.astype(
179+
f"{quant_param.dtype}".replace("torch.", "")
180+
) # Use string format of dtype to convert to numpy dtype
170181
)
171182
file_path = os.path.join(self.intermediate_path, input_name + ".npy")
172183
np.save(file_path, data_np, allow_pickle=False)

0 commit comments

Comments
 (0)