20 | 20 |
21 | 21 |
22 | 22 | class QuantizationParams:
23 |    | -     __slots__ = ["node_name", "zp", "scale"]
   | 23 | +     __slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]
24 | 24 |
25 | 25 |     # todo: zps and scales can be per tensors or per channel => a list??
26 |    | -     def __init__(self, node_name: str, zp: int, scale: float):
   | 26 | +     def __init__(
   | 27 | +         self,
   | 28 | +         node_name: str,
   | 29 | +         zp: int,
   | 30 | +         scale: float,
   | 31 | +         qmin: int,
   | 32 | +         qmax: int,
   | 33 | +         dtype: torch.dtype,
   | 34 | +     ):
27 | 35 |         self.node_name = node_name  # not need I think, but good for error check
28 | 36 |         self.zp = zp
29 | 37 |         self.scale = scale
   | 38 | +         self.qmin = qmin
   | 39 | +         self.qmax = qmax
   | 40 | +         self.dtype = dtype
30 | 41 |
31 | 42 |     def __repr__(self):
32 |    | -         return f"QuantizationParams(node_name={self.node_name}, zp={self.zp}, scale={self.scale})"
   | 43 | +         return f"QuantizationParams(node_name={self.node_name}, zp={self.zp}, scale={self.scale}, [{self.qmin},{self.qmax}], dtype={self.dtype})"
33 | 44 |
34 | 45 |
35 | 46 | """
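
With the patch applied, `QuantizationParams` carries the clamping range and target dtype alongside the zero point and scale, so the int8 limits no longer need to be hardcoded at the call site. Below is a minimal, self-contained sketch of the extended class in use; the class body is assembled from the `+` lines above, while the node name, zero point, and scale values are made up purely for illustration:

```python
import numpy as np
import torch


# QuantizationParams as it reads with this patch applied (assembled from the "+" lines above).
class QuantizationParams:
    __slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]

    def __init__(
        self,
        node_name: str,
        zp: int,
        scale: float,
        qmin: int,
        qmax: int,
        dtype: torch.dtype,
    ):
        self.node_name = node_name
        self.zp = zp
        self.scale = scale
        self.qmin = qmin
        self.qmax = qmax
        self.dtype = dtype

    def __repr__(self):
        return f"QuantizationParams(node_name={self.node_name}, zp={self.zp}, scale={self.scale}, [{self.qmin},{self.qmax}], dtype={self.dtype})"


# Illustrative values only; real zero points and scales come from the quantizer.
qp_int8 = QuantizationParams(
    "input_0", zp=0, scale=0.02,
    qmin=np.iinfo(np.int8).min, qmax=np.iinfo(np.int8).max, dtype=torch.int8,
)
qp_int16 = QuantizationParams(
    "input_0", zp=0, scale=0.0005,
    qmin=np.iinfo(np.int16).min, qmax=np.iinfo(np.int16).max, dtype=torch.int16,
)
print(qp_int8)   # ... [-128,127], dtype=torch.int8
print(qp_int16)  # ... [-32768,32767], dtype=torch.int16
```

Keeping `qmin`/`qmax` explicit, rather than deriving them from `dtype`, also leaves room for restricted ranges such as a symmetric [-127, 127].
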
@@ -160,13 +171,13 @@ def run_tosa_ref_model(
160 | 171 |             assert (
161 | 172 |                 quant_param.node_name == input_name
162 | 173 |             ), "These quantization params do not match the input tensor name"
163 |     | -             int8_max = np.iinfo(np.int8).max
164 |     | -             int8_min = np.iinfo(np.int8).min
165 | 174 |             data_np = (
166 | 175 |                 ((data_np / np.float32(quant_param.scale)) + quant_param.zp)
167 | 176 |                 .round()
168 |     | -                 .clip(int8_min, int8_max)
169 |     | -                 .astype(np.int8)
    | 177 | +                 .clip(quant_param.qmin, quant_param.qmax)
    | 178 | +                 .astype(
    | 179 | +                     f"{quant_param.dtype}".replace("torch.", "")
    | 180 | +                 )  # Use string format of dtype to convert to numpy dtype
170 | 181 |             )
171 | 182 |             file_path = os.path.join(self.intermediate_path, input_name + ".npy")
172 | 183 |             np.save(file_path, data_np, allow_pickle=False)
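
The new `.astype()` call relies on the string form of a `torch.dtype` (e.g. `"torch.int8"`) matching a NumPy dtype name once the `torch.` prefix is stripped. A small standalone check of that assumption, using illustrative values:

```python
import numpy as np
import torch

for torch_dtype in (torch.int8, torch.int16, torch.int32):
    # str(torch.int8) == "torch.int8"; stripping the prefix yields "int8",
    # which NumPy accepts as a dtype string.
    numpy_name = f"{torch_dtype}".replace("torch.", "")
    quantized = np.array([3.7, -1.2]).round().astype(numpy_name)
    print(torch_dtype, "->", quantized.dtype)  # e.g. torch.int8 -> int8
```

This round-trip only works for dtypes NumPy knows by the same name; a dtype without a NumPy counterpart (e.g. `torch.bfloat16`) would raise, so an explicit mapping would be needed if such types ever reached this path.
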