@@ -23,13 +23,24 @@
 
 
 class QuantizationParams:
-    __slots__ = ["node_name", "zp", "scale"]
+    __slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]
 
     # todo: zps and scales can be per tensors or per channel => a list??
-    def __init__(self, node_name: str, zp: int, scale: float):
+    def __init__(
+        self,
+        node_name: str,
+        zp: int,
+        scale: float,
+        qmin: int,
+        qmax: int,
+        dtype: torch.dtype,
+    ):
         self.node_name = node_name  # not need I think, but good for error check
         self.zp = zp
         self.scale = scale
+        self.qmin = qmin
+        self.qmax = qmax
+        self.dtype = dtype
 
 
 def _get_input_names(program: ExportedProgram) -> list[str]:

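Note: the three new slots capture the integer range and storage type of the quantized tensor, alongside the existing scale and zero point. A minimal sketch of how they fit together, with hypothetical example values (an int8-quantized input named "x"; `QuantizationParams` is the class from the hunk above):

```python
import torch

# Hypothetical int8 example; the integer range can be derived from the
# torch dtype itself rather than hard-coded.
dtype = torch.int8
qmin = torch.iinfo(dtype).min  # -128
qmax = torch.iinfo(dtype).max  # 127

qp = QuantizationParams(
    node_name="x", zp=0, scale=0.1, qmin=qmin, qmax=qmax, dtype=dtype
)
```
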
@@ -74,7 +85,12 @@ def _get_input_quantization_params(
             and node.args[0].name in input_names
         ):
             qp = QuantizationParams(
-                node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
+                node_name=node.args[0].name,
+                scale=node.args[1],
+                zp=node.args[2],
+                qmin=node.args[3],
+                qmax=node.args[4],
+                dtype=node.args[5],
             )
             quant_params.append(qp)
             if (

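The three extra keyword arguments are read positionally from the matched quantize node. This matches the schema of `quantized_decomposed.quantize_per_tensor` (defined in `torch/ao/quantization/fx/_decomposed.py`), whose arguments are `(input, scale, zero_point, quant_min, quant_max, dtype)`; the mapping below is an annotation, not part of the commit:

```python
# Positional args of torch.ops.quantized_decomposed.quantize_per_tensor.default:
#   node.args[0] -> input tensor (its .name is stored as node_name)
#   node.args[1] -> scale
#   node.args[2] -> zero_point (zp)
#   node.args[3] -> quant_min  (qmin)
#   node.args[4] -> quant_max  (qmax)
#   node.args[5] -> dtype      (e.g. torch.int8)
```
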
@@ -122,7 +138,12 @@ def _get_output_quantization_params(
             and node == output_node.args[0][0]
         ):
             quant_params = QuantizationParams(
-                node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
+                node_name=node.args[0].name,
+                scale=node.args[1],
+                zp=node.args[2],
+                qmin=node.args[3],
+                qmax=node.args[4],
+                dtype=node.args[5],
             )
             break  # break early, there's only one output node
     if quant_params is None:

@@ -376,13 +397,13 @@ def prep_data_for_save(
         assert (
             quant_param.node_name == input_name
         ), "These quantization params do not match the input tensor name"
-        int8_max = np.iinfo(np.int8).max
-        int8_min = np.iinfo(np.int8).min
         data_np = (
             ((data_np / np.float32(quant_param.scale)) + quant_param.zp)
             .round()
-            .clip(int8_min, int8_max)
-            .astype(np.int8)
+            .clip(quant_param.qmin, quant_param.qmax)
+            .astype(
+                f"{quant_param.dtype}".replace("torch.", "")
+            )  # Use string format of dtype to convert to numpy dtype
         )
     return data_np
 
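The old code hard-coded the `np.int8` bounds, so `prep_data_for_save` only handled int8 data; clipping to the recorded `qmin`/`qmax` and deriving the numpy dtype from the torch dtype generalizes it to other integer types. A self-contained sketch of the conversion trick with hypothetical values (`str(torch.int8)` is `"torch.int8"`, and numpy accepts the `"int8"` remainder as a dtype name):

```python
import numpy as np
import torch

dtype = torch.int8
np_dtype = f"{dtype}".replace("torch.", "")  # "torch.int8" -> "int8"
assert np.dtype(np_dtype) == np.int8

# Quantize a float array the same way prep_data_for_save does.
scale, zp, qmin, qmax = 0.1, 0, -128, 127  # hypothetical parameters
data = np.array([1.26, -3.7], dtype=np.float32)
quantized = ((data / np.float32(scale)) + zp).round().clip(qmin, qmax).astype(np_dtype)
print(quantized, quantized.dtype)  # [ 13 -37] int8
```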