Skip to content

Commit 3bc5fe3

Browse files
mikekgfbmalfet
authored and committed
refactor quantizer entry point quantize_model to be table driven (#324)
* refactor quantizer entry point quantize_model to be table driven, and scalable
* add tokenizer arg consistently
* code beautification
1 parent 50781ac commit 3bc5fe3

File tree

1 file changed

+65
-11
lines changed

1 file changed

+65
-11
lines changed

quantize.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from __future__ import annotations
8-
98
import json
109
from functools import reduce
1110
from math import gcd
@@ -140,10 +139,72 @@ def convert_for_runtime(self) -> nn.Module:
140139
def quantized_model(self) -> nn.Module:
141140
return self.model.to(device=device, **kwargs)
142141

142+
143+
#########################################################################
144+
### QuantHandler API definition ###
145+
### (unify with torchao in future) ###
146+
147+
class QuantHandler:
    """Base interface for model quantizers (to be unified with torchao).

    A concrete handler produces a quantized state dict and rewrites the
    module graph to match; ``quantized_model`` ties both steps together.
    """

    def __init__(self, mod, device="cpu", tokenizer=None):
        self.mod = mod
        self.device = device
        self.tokenizer = tokenizer

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        """Subclass hook: return a state dict holding quantized tensors."""
        pass

    def convert_for_runtime(self) -> nn.Module:
        """Subclass hook: swap modules so they accept the quantized state dict."""
        pass

    def quantized_model(self) -> nn.Module:
        """Quantize ``self.mod`` in place and return it."""
        quantized_state = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(quantized_state)
        return self.mod
164+
143165

144166
#########################################################################
145-
##### Quantization Primitives ######
167+
### QuantHandler wrapper for a8w4dq from torchao ###
168+
169+
class Int8DynActInt4WeightQuantizer(QuantHandler):
    """QuantHandler wrapper for torchao's a8w4dq (int8 dynamic activation,
    int4 weight) quantizer.

    torchao performs the module rewrite itself, so the state-dict /
    convert hooks are no-ops and ``quantized_model`` delegates entirely
    to the wrapped quantizer.
    """

    from torchao.quantization.quant_api import (
        Int8DynActInt4WeightQuantizer as aoInt8DynActInt4WeightQuantizer,
    )

    def __init__(self, mod, device="cpu", tokenizer=None, **kwargs):
        self.mod = mod
        self.device = device
        self.tokenizer = tokenizer
        # Fixed: a name bound in the class body is NOT in scope inside
        # methods, so the imported alias must be reached through the
        # class attribute (previously a NameError at construction).
        self.quantizer = self.aoInt8DynActInt4WeightQuantizer(**kwargs)

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        # No-op: torchao rewrites the module directly in quantized_model.
        pass

    def convert_for_runtime(self) -> nn.Module:
        # No-op: torchao rewrites the module directly in quantized_model.
        pass

    def quantized_model(self) -> nn.Module:
        # Fixed: was `self.model`, an attribute that is never set
        # (AttributeError); the module is stored as `self.mod`.
        return self.quantizer.quantize(self.mod)
186+
187+
#########################################################################
188+
### QuantHandler that only casts model precision/device (no quantization) ###

class PrecisionHandler(QuantHandler):
    """QuantHandler that casts the model to a target device/dtype.

    Extra keyword arguments (e.g. ``dtype=torch.float16``) are forwarded
    verbatim to ``nn.Module.to``. No quantization is performed.
    """

    def __init__(self, mod, device="cpu", tokenizer=None, **kwargs):
        self.mod = mod
        self.device = device
        self.tokenizer = tokenizer
        # Fixed: kwargs (e.g. dtype=...) were previously accepted but
        # silently dropped, making quantized_model reference unbound names.
        self.kwargs = kwargs

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        # No-op: precision casting needs no separate state dict.
        pass

    def convert_for_runtime(self) -> nn.Module:
        # No-op: precision casting needs no module rewrite.
        pass

    def quantized_model(self) -> nn.Module:
        # Fixed: was `self.model.to(device=device, **kwargs)` — `self.model`
        # does not exist and `device`/`kwargs` were unbound in this scope.
        return self.mod.to(device=self.device, **self.kwargs)
204+
205+
206+
#########################################################################
207+
##### Quantization Primitives ######
147208

148209
def dynamically_quantize_per_channel(
149210
x,
@@ -354,7 +415,7 @@ def replace_linear_weight_only_int8_per_channel(
354415
module, device, node_type, groupsize=None
355416
):
356417
if groupsize is not None and groupsize != 0:
357-
pass
418+
pass
358419

359420
for name, child in module.named_children():
360421
# print(f"name: {name}")
@@ -808,14 +869,7 @@ def replace_linear_int4(
808869

809870
class WeightOnlyInt4QuantHandler(QuantHandler):
810871
def __init__(
811-
self,
812-
mod,
813-
device,
814-
tokenizer=None,
815-
*,
816-
groupsize=128,
817-
inner_k_tiles=8,
818-
padding_allowed=True,
872+
self, mod, device, tokenizer=None, *, groupsize=128, inner_k_tiles=8, padding_allowed=True
819873
):
820874
self.mod = mod
821875
self.device = device

0 commit comments

Comments
 (0)