|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 | 7 | from __future__ import annotations
|
8 |
| - |
9 | 8 | import json
|
10 | 9 | from functools import reduce
|
11 | 10 | from math import gcd
|
@@ -140,10 +139,72 @@ def convert_for_runtime(self) -> nn.Module:
|
140 | 139 | def quantized_model(self) -> nn.Module:
|
141 | 140 | return self.model.to(device=device, **kwargs)
|
142 | 141 |
|
| 142 | + |
| 143 | +######################################################################### |
| 144 | +### QuantHandler API definition ### |
| 145 | +### (unify with torchao in future) ### |
| 146 | + |
class QuantHandler:
    """Abstract interface for model quantization handlers.

    Subclasses implement `create_quantized_state_dict` and
    `convert_for_runtime`; the base `quantized_model` then drives the
    standard quantize -> convert -> reload sequence.
    (To be unified with torchao in the future.)
    """

    def __init__(self, mod, device="cpu", tokenizer=None):
        # Model to be quantized, target device string, and an optional
        # tokenizer (some quantization schemes need token statistics).
        self.mod = mod
        self.device = device
        self.tokenizer = tokenizer

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        """Return a state dict holding quantized weights. Subclass hook.

        Raising here (instead of the original silent `pass`, which
        returned None) surfaces a missing override immediately rather
        than as a confusing TypeError inside load_state_dict().
        """
        raise NotImplementedError("subclass must implement create_quantized_state_dict")

    def convert_for_runtime(self) -> nn.Module:
        """Rewrite self.mod so it can load the quantized state dict. Subclass hook."""
        raise NotImplementedError("subclass must implement convert_for_runtime")

    def quantized_model(self) -> nn.Module:
        """Quantize, convert the module for runtime, reload weights, and return it."""
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod
| 164 | + |
143 | 165 |
|
144 | 166 | #########################################################################
|
145 |
| -##### Quantization Primitives ###### |
| 167 | +### QuantHandler wrapper for a8w4dq from torchao ### |
| 168 | + |
class Int8DynActInt4WeightQuantizer(QuantHandler):
    """QuantHandler wrapper around torchao's a8w4dq quantizer
    (8-bit dynamic activations, 4-bit weights)."""

    # Alias torchao's quantizer as a class attribute so it is reachable
    # through `self` without a module-level torchao import.
    from torchao.quantization.quant_api import (
        Int8DynActInt4WeightQuantizer as aoInt8DynActInt4WeightQuantizer,
    )

    def __init__(self, mod, device="cpu", tokenizer=None, **kwargs):
        self.mod = mod
        self.device = device
        self.tokenizer = tokenizer
        # Fix: the alias is a class attribute, so it must be accessed via
        # `self.` — the bare name was a NameError at construction time.
        self.quantizer = self.aoInt8DynActInt4WeightQuantizer(**kwargs)

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        """Unused: torchao quantizes in a single step in `quantized_model`."""
        pass

    def convert_for_runtime(self) -> nn.Module:
        """Unused: torchao quantizes in a single step in `quantized_model`."""
        pass

    def quantized_model(self) -> nn.Module:
        """Delegate quantization entirely to the torchao quantizer."""
        # Fix: the attribute is `self.mod` — there is no `self.model`.
        return self.quantizer.quantize(self.mod)
| 186 | + |
| 187 | +######################################################################### |
| 188 | +### QuantHandler wrapper for setting model precision/device ### |
| 189 | + |
class PrecisionHandler(QuantHandler):
    """QuantHandler wrapper that only casts the model to a target
    device/precision (no weight quantization)."""

    def __init__(self, mod, device="cpu", tokenizer=None, **kwargs):
        self.mod = mod
        self.device = device
        self.tokenizer = tokenizer
        # Fix: keep the .to() kwargs (e.g. dtype) — the original dropped
        # them, leaving `kwargs` undefined in quantized_model().
        self.kwargs = kwargs

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        """No quantized state dict for a pure precision change."""
        pass

    def convert_for_runtime(self) -> nn.Module:
        """No module rewrite needed for a pure precision change."""
        pass

    def quantized_model(self) -> nn.Module:
        """Move/cast the model and return it."""
        # Fix: the original referenced undefined names `self.model`,
        # bare `device`, and `kwargs`; use the stored instance attributes.
        return self.mod.to(device=self.device, **self.kwargs)
| 204 | + |
| 205 | + |
| 206 | +######################################################################### |
| 207 | +##### Quantization Primitives ###### |
147 | 208 |
|
148 | 209 | def dynamically_quantize_per_channel(
|
149 | 210 | x,
|
@@ -354,7 +415,7 @@ def replace_linear_weight_only_int8_per_channel(
|
354 | 415 | module, device, node_type, groupsize=None
|
355 | 416 | ):
|
356 | 417 | if groupsize is not None and groupsize != 0:
|
357 |
| - pass |
| 418 | + pass |
358 | 419 |
|
359 | 420 | for name, child in module.named_children():
|
360 | 421 | # print(f"name: {name}")
|
@@ -808,14 +869,7 @@ def replace_linear_int4(
|
808 | 869 |
|
809 | 870 | class WeightOnlyInt4QuantHandler(QuantHandler):
|
810 | 871 | def __init__(
|
811 |
| - self, |
812 |
| - mod, |
813 |
| - device, |
814 |
| - tokenizer=None, |
815 |
| - *, |
816 |
| - groupsize=128, |
817 |
| - inner_k_tiles=8, |
818 |
| - padding_allowed=True, |
| 872 | + self, mod, device, tokenizer=None, *, groupsize=128, inner_k_tiles=8, padding_allowed=True |
819 | 873 | ):
|
820 | 874 | self.mod = mod
|
821 | 875 | self.device = device
|
|
0 commit comments