
Commit 1a24b06

add 16a4w_hqq quant mode
Prerequisite: install hqq following https://github.com/mobiusml/hqq

Step 1: use hqq to quantize weights to 4 bits.
Step 2: use static quantization to quantize activations to 16 bits.

Graph calibration is currently too slow, so the quant observers are added to the eager model for faster iteration.

Command:
```
python -m examples.models.llama2.eval_llama -t /data/users/chenlai/models/llama2/tokenizer.model -p /data/users/chenlai/models/llama2/params.json -c /data/users/chenlai/models/llama2/consolidated.00.pth --max_seq_len 129 -qmode 16a4w-hqq --limit 5 2>&1 | tee hqq_16a4w.log
```

Differential Revision: [D57849772](https://our.internmc.facebook.com/intern/diff/D57849772/)

ghstack-source-id: 227884756
Pull Request resolved: #3752
1 parent 79e9b79 commit 1a24b06
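For context, here is a minimal sketch of the flow this commit implements, not the exact code from the diff: Step 1 swaps each `nn.Linear` for hqq's `HQQLinear` (4-bit weights), Step 2 inserts 16-bit activation observers, calibrates the eager model, and converts. The `prepare`/`convert` helpers are the ones added to `quantize.py` below; `quantize_16a4w_hqq`, the `_swap` helper, and the `calibrate` callback are illustrative names, and a CUDA device is assumed as in the commit.

```python
import torch
import torch.nn as nn
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

# prepare()/convert() are the activation-observer helpers added in this diff.
from executorch.examples.models.llama2.source_transformation.quantize import convert, prepare


def quantize_16a4w_hqq(model: nn.Module, calibrate) -> nn.Module:
    # Step 1: 4-bit weight quantization with hqq -- replace every nn.Linear
    # with an HQQLinear built from the same config used in the diff.
    quant_config = BaseQuantizeConfig(
        quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
    )

    def _swap(module: nn.Module):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                setattr(
                    module,
                    name,
                    HQQLinear(child, quant_config, compute_dtype=torch.float32, del_orig=False),
                )
            else:
                _swap(child)

    _swap(model)

    # Step 2: 16-bit static activation quantization -- insert observers,
    # run calibration data through the eager model, then freeze the observed
    # ranges into static quant/dequant around each HQQLinear.
    prepare(model)    # wraps each HQQLinear in a LinearActFakeQuant (uint16 observer)
    calibrate(model)  # e.g. the wikitext eval pass used in this commit
    convert(model)    # replaces the wrappers with LinearActQuant
    return model
```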

File tree: 2 files changed (+263, −1 lines)


examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantization_mode",
         type=str,
         default=None,
-        choices=["int8", "8da4w", "8da4w-gptq"],
+        choices=["int8", "8da4w", "8da4w-gptq", "16a4w-hqq"],
         help="type of quantization",
     )

examples/models/llama2/source_transformation/quantize.py

Lines changed: 262 additions & 0 deletions
@@ -11,9 +11,23 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
+from executorch.examples.models.llama2.tokenizer.tokenizer import (
+    Tokenizer,
+    Tokenizer as SentencePieceTokenizer,
+)
+from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
+from lm_eval.api.model import LM
+from lm_eval.evaluator import evaluate
+from lm_eval.models.huggingface import HFLM as eval_wrapper
+from lm_eval.tasks import get_task_dict

 from sentencepiece import SentencePieceProcessor

+from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
+from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4
+from triton.testing import do_bench
+
 from ..builder import DType

 try:
@@ -33,6 +47,198 @@
 fsLinear = nn.Linear


+class EagerEvalWrapper(eval_wrapper):
+    """
+    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+    ):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(device=device)
+        self._model = model
+        self._tokenizer = tokenizer
+        self._device = torch.device(device)
+        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache
+
+    @property
+    def eot_token_id(self):
+        return self._tokenizer.eos_id
+
+    @property
+    def max_length(self):
+        return self._max_seq_length
+
+    @property
+    def max_gen_toks(self):
+        return 50
+
+    @property
+    def batch_size(self):
+        return 1
+
+    @property
+    def device(self):
+        return self._device
+
+    def tok_encode(self, string: str, **kwargs):
+        tokens = self._tokenizer.encode(string, bos=True, eos=False)
+        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
+        # encoded is a pytorch tensor, but some internal logic in the
+        # eval harness expects it to be a list instead
+        # TODO: verify this for multi-batch as well
+        encoded = encoded.tolist()
+        return encoded
+
+    def tok_decode(self, tokens):
+        decoded = self._tokenizer.decode(tokens)
+        return decoded
+
+    def _model_call(self, inps):
+        bsz, seq_len = inps.shape
+        if self._use_kv_cache:
+            pos_tensor = torch.arange(
+                self._max_seq_length, dtype=torch.int64, device=self.device
+            )
+
+            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            return logits
+        else:
+            logits = self._model(inps)
+            return logits
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise Exception("unimplemented")
+
+
+@torch.no_grad()
+def eval(
+    eval_wrapper: LM,
+    tasks: Optional[list] = None,
+    limit: Optional[int] = None,
+) -> dict:
+    """
+    Evaluates a language model on a specified task using the lm-evaluation-harness library.
+
+    Args:
+        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
+        tasks (Optional[list]): The names of the evaluation tasks to perform.
+        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
+
+    Returns:
+        eval_results (dict): A dictionary of evaluation results for the specified task(s).
+    """
+    if tasks is None:
+        tasks = ["wikitext"]
+    if "hendrycks_test" in tasks:
+        tasks.remove("hendrycks_test")
+        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
+    task_dict = get_task_dict(tasks)
+    eval_results = evaluate(
+        eval_wrapper,
+        task_dict,
+        limit=limit,
+    )
+    return eval_results
+
+
+def run_wikitext_eval(m, tokenizer_path, seq_len):
+    print("run_wikitext_eval calibration...")
+    print("tokenizer_path: ", tokenizer_path)
+    tokenizer = Tokenizer(str(tokenizer_path))
+    eval_wrapper = EagerEvalWrapper(
+        model=m.to(device="cuda"),
+        tokenizer=tokenizer,
+        max_seq_length=seq_len,
+        use_kv_cache=False,
+    )
+    eval_results = eval(
+        eval_wrapper,
+        tasks=["wikitext"],
+        limit=128,
+        # limit=5,
+        # limit=1,
+    )
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+# Calibration stage: wraps an HQQLinear and observes its input activations
+# with a fake-quant observer spanning the uint16 range.
+class LinearActFakeQuant(torch.nn.Module):
+    def __init__(self, linear):
+        super().__init__()
+        self.linear = linear
+        self.activation_fake_quant = torch.quantization.FakeQuantize(
+            observer=torch.quantization.MovingAverageMinMaxObserver,
+            dtype=torch.int32,
+            quant_min=torch.iinfo(torch.uint16).min,
+            quant_max=torch.iinfo(torch.uint16).max,
+        )
+
+    def forward(self, x):
+        x = self.activation_fake_quant(x)
+        return self.linear(x)
+
+
+# Conversion stage: replaces the observer wrapper after calibration and applies
+# static 16-bit quantize/dequantize around the wrapped linear using the
+# calibrated scale and zero-point.
+class LinearActQuant(torch.nn.Module):
+    def __init__(self, linear_fake_quant):
+        super().__init__()
+        self.linear_fake_quant = linear_fake_quant
+        self.quant_min = self.linear_fake_quant.activation_fake_quant.quant_min
+        self.quant_max = self.linear_fake_quant.activation_fake_quant.quant_max
+        qparams = self.linear_fake_quant.activation_fake_quant.calculate_qparams()
+        self.scale = qparams[0]
+        self.zero_point = qparams[1]
+
+    def forward(self, x):
+        q_tensor = torch.round(x / self.scale + self.zero_point)
+        # Clip to ensure within the range [quant_min, quant_max]
+        q_tensor = torch.clamp(q_tensor, self.quant_min, self.quant_max)
+        # Dequantize to the original scale
+        dq_tensor = (q_tensor - self.zero_point) * self.scale
+        linear_output = self.linear_fake_quant.linear(dq_tensor)
+
+        # Quantize the linear output tensor
+        q_linear_output = torch.round(linear_output / self.scale + self.zero_point)
+        q_linear_output = torch.clamp(q_linear_output, self.quant_min, self.quant_max)
+        # Dequantize the linear output tensor
+        dq_linear_output = (q_linear_output - self.zero_point) * self.scale
+
+        return dq_linear_output
+
+
+def _replace_linear_q_act(module: torch.nn.Module, stage: str):
+    for name, child in module.named_children():
+        if stage == "convert":
+            if isinstance(child, LinearActFakeQuant):
+                new_linear = LinearActQuant(child)
+                setattr(module, name, new_linear)
+            else:
+                _replace_linear_q_act(child, stage)
+        elif stage == "prepare":
+            if isinstance(child, HQQLinear):
+                new_linear = LinearActFakeQuant(child)
+                setattr(module, name, new_linear)
+            else:
+                _replace_linear_q_act(child, stage)
+
+
+def replace_linear_q_act(module: torch.nn.Module, stage: str):
+    _replace_linear_q_act(
+        module,
+        stage,
+    )
+
+
+def prepare(model):
+    replace_linear_q_act(model, "prepare")
+
+
+def convert(model):
+    replace_linear_q_act(model, "convert")
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
@@ -127,6 +333,62 @@ def quantize(
             group_size,
         )
         model = gptq_quantizer.quantize(model, inputs)
+        return model
+    elif qmode == "16a4w-hqq":
+        print("running 16a4w-hqq")
+        from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
+
+        def _replace_linear_16a4w_hqq(
+            module: torch.nn.Module,
+            quant_config,
+            compute_dtype,
+            del_orig=False,
+        ):
+            for name, child in module.named_children():
+                if isinstance(child, nn.Linear):
+                    new_linear = HQQLinear(
+                        child, quant_config, compute_dtype=compute_dtype, del_orig=False
+                    )  # , device="cpu")
+                    setattr(module, name, new_linear)
+                else:
+                    _replace_linear_16a4w_hqq(
+                        child,
+                        quant_config,
+                        compute_dtype,
+                        del_orig=False,
+                    )
+
+        def replace_linear_16a4w_hqq(
+            module: torch.nn.Module,
+            quant_config,
+            compute_dtype,
+            del_orig=False,
+        ):
+            _replace_linear_16a4w_hqq(
+                module,
+                quant_config,
+                compute_dtype,
+                del_orig=False,
+            )
+
+        compute_dtype = torch.float32  # torch.bfloat16 #[torch.float16, torch.bfloat16]
+        quant_config = BaseQuantizeConfig(
+            quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
+        )
+        print("before replace_linear_16a4w_hqq model: ", model)
+        replace_linear_16a4w_hqq(model, quant_config, compute_dtype)
+        print("after replace_linear_16a4w_hqq model: ", model)
+
+        print("model before prepare: ", model)
+        prepare(model)
+        print("model after prepare: ", model)
+
+        # x = torch.tensor([[1]], device="cuda")
+        # _ = model(x)
+        run_wikitext_eval(model, tokenizer_path, 128)
+        print("model after calibrate: ", model)
+        convert(model)
+
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
