
Commit 8a569cb

add 16a4w_hqq quant mode
Pull Request resolved: #3752

Prerequisite: install hqq following https://github.com/mobiusml/hqq

Step 1: use hqq to quantize the weights to 4 bits.
Step 2: use static quantization to quantize the activations to 16 bits.

Graph-mode calibration is currently too slow, so the quant observers are added to the eager model for faster iteration.

Command:
```
python -m examples.models.llama2.eval_llama -t /data/users/chenlai/models/llama2/tokenizer.model -p /data/users/chenlai/models/llama2/params.json -c /data/users/chenlai/models/llama2/consolidated.00.pth --max_seq_len 129 -qmode 16a4w-hqq --limit 5 2>&1 | tee hqq_16a4w.log
```

Differential Revision: [D57849772](https://our.internmc.facebook.com/intern/diff/D57849772/)

ghstack-source-id: 227950317
1 parent 79e9b79 commit 8a569cb
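The two steps above map directly onto the helpers this commit adds to examples/models/llama2/source_transformation/quantize.py (last diff below). As orientation, here is a minimal sketch of the same flow on a single toy layer, reusing the `BaseQuantizeConfig` and `FakeQuantize` settings from that diff; the layer shape, the random calibration tensor, and the standalone usage are illustrative assumptions, not part of the commit:

```python
import torch
import torch.nn as nn
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear  # requires hqq (see prerequisite above)

linear = nn.Linear(64, 64)  # toy layer; sizes are arbitrary

# Step 1: 4-bit weight quantization with hqq (same config as the diff; hqq's default nbits is 4).
quant_config = BaseQuantizeConfig(
    quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)
q_linear = HQQLinear(
    linear, quant_config, compute_dtype=torch.float32, del_orig=True, device="cpu"
)

# Step 2: observe activations in eager mode with a 16-bit fake-quant module,
# the same construction LinearActFakeQuant in the diff uses for input and output.
act_fake_quant = torch.quantization.FakeQuantize(
    observer=torch.quantization.MovingAverageMinMaxObserver,
    dtype=torch.int32,
    quant_min=torch.iinfo(torch.uint16).min,
    quant_max=torch.iinfo(torch.uint16).max,
)

# Calibration: run representative activations through the observer. In the commit this
# happens by running the wikitext eval over the whole wrapped model (run_wikitext_eval).
_ = act_fake_quant(torch.randn(2, 64))
scale, zero_point = act_fake_quant.calculate_qparams()
```

In the commit itself, `prepare()` wraps every `HQQLinear` in a `LinearActFakeQuant`, the wikitext eval provides the calibration data, and `convert()` then replaces each wrapper with a `LinearActQuant` that applies the frozen quantization parameters.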

File tree: 3 files changed (+300 −4 lines)

examples/models/llama2/builder.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.nn.attention import SDPBackend
 
-from ...portable.utils import export_to_edge, save_pte_program
+from examples.portable.utils import export_to_edge, save_pte_program
 from ..model_factory import EagerModelFactory
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
```

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -119,7 +119,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantization_mode",
         type=str,
         default=None,
-        choices=["int8", "8da4w", "8da4w-gptq"],
+        choices=["int8", "8da4w", "8da4w-gptq", "16a4w-hqq"],
         help="type of quantization",
     )
 
@@ -366,8 +366,8 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         )
         .set_output_dir(output_dir_path)
         .set_metadata(args.metadata)
-        .source_transform(transforms)
         .to_dtype(dtype_override)
+        .source_transform(transforms)
     )
 
 
```
examples/models/llama2/source_transformation/quantize.py

Lines changed: 297 additions & 1 deletion
```diff
@@ -6,11 +6,21 @@
 
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
+from executorch.examples.models.llama2.tokenizer.tokenizer import (
+    Tokenizer,
+    Tokenizer as SentencePieceTokenizer,
+)
+from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
+from lm_eval.api.model import LM
+from lm_eval.evaluator import evaluate
+from lm_eval.models.huggingface import HFLM as eval_wrapper
+from lm_eval.tasks import get_task_dict
 
 from sentencepiece import SentencePieceProcessor
 
@@ -33,6 +43,233 @@
 fsLinear = nn.Linear
 
 
+class EagerEvalWrapper(eval_wrapper):
+    """
+    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
+        max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
+    ):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(device=device)
+        self._model = model
+        self._tokenizer = tokenizer
+        self._device = torch.device(device)
+        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache
+
+    @property
+    def eot_token_id(self):
+        return self._tokenizer.eos_id
+
+    @property
+    def max_length(self):
+        return self._max_seq_length
+
+    @property
+    def max_gen_toks(self):
+        return 50
+
+    @property
+    def batch_size(self):
+        return 1
+
+    @property
+    def device(self):
+        return self._device
+
+    def tok_encode(self, string: str, **kwargs):
+        tokens = self._tokenizer.encode(string, bos=True, eos=False)
+        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
+        # encoded is a pytorch tensor, but some internal logic in the
+        # eval harness expects it to be a list instead
+        # TODO: verify this for multi-batch as well
+        encoded = encoded.tolist()
+        return encoded
+
+    def tok_decode(self, tokens):
+        decoded = self._tokenizer.decode(tokens)
+        return decoded
+
+    def _model_call(self, inps):
+        bsz, seq_len = inps.shape
+        if self._use_kv_cache:
+            pos_tensor = torch.arange(
+                self._max_seq_length, dtype=torch.int64, device=self.device
+            )
+
+            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
+            return logits
+        else:
+            logits = self._model(inps)
+            return logits
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        raise Exception("unimplemented")
+
+
+@torch.no_grad()
+def eval(
+    eval_wrapper: LM,
+    tasks: Optional[list] = None,
+    limit: Optional[int] = None,
+) -> dict:
+    """
+    Evaluates a language model on a specified task using the lm-evaluation-harness library.
+    Args:
+        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
+        task (str): The name of the evaluation task to perform.
+        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
+    Returns:
+        eval_results (dict): A dictionary of evaluation results for the specified task(s).
+    """
+    if tasks is None:
+        tasks = ["wikitext"]
+    if "hendrycks_test" in tasks:
+        tasks.remove("hendrycks_test")
+        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
+    task_dict = get_task_dict(tasks)
+    eval_results = evaluate(
+        eval_wrapper,
+        task_dict,
+        limit=limit,
+    )
+    return eval_results
+
+
+def run_wikitext_eval(m, tokenizer_path, seq_len):
+    print("run_wikitext_eval calibration...")
+    print("tokenizer_path: ", tokenizer_path)
+    tokenizer = Tokenizer(str(tokenizer_path))
+    eval_wrapper = EagerEvalWrapper(
+        model=m,
+        tokenizer=tokenizer,
+        max_seq_length=seq_len,
+        use_kv_cache=False,
+    )
+    eval_results = eval(
+        eval_wrapper,
+        tasks=["wikitext"],
+        # limit=128,
+        limit=5,
+        # limit=1,
+    )
+    for task, res in eval_results["results"].items():
+        print(f"{task}: {res}")
+
+
+class LinearActFakeQuant(torch.nn.Module):
+    def __init__(self, linear):
+        super().__init__()
+        self.linear = linear
+        self.input_activation_fake_quant = torch.quantization.FakeQuantize(
+            observer=torch.quantization.MovingAverageMinMaxObserver,
+            dtype=torch.int32,
+            quant_min=torch.iinfo(torch.uint16).min,
+            quant_max=torch.iinfo(torch.uint16).max,
+        )
+        self.output_activation_fake_quant = torch.quantization.FakeQuantize(
+            observer=torch.quantization.MovingAverageMinMaxObserver,
+            dtype=torch.int32,
+            quant_min=torch.iinfo(torch.uint16).min,
+            quant_max=torch.iinfo(torch.uint16).max,
+        )
+
+    def forward(self, x):
+        x = self.input_activation_fake_quant(x)
+        return self.output_activation_fake_quant(self.linear(x))
+
+
+def get_quant_params(activation_fake_quant):
+    quant_min = activation_fake_quant.quant_min
+    quant_max = activation_fake_quant.quant_max
+    qparams = activation_fake_quant.calculate_qparams()
+    scale = qparams[0]
+    zero_point = qparams[1]
+    return (quant_min, quant_max, scale, zero_point)
+
+
+class LinearActQuant(torch.nn.Module):
+
+    def __init__(self, linear_fake_quant):
+        super().__init__()
+        self.linear_fake_quant = linear_fake_quant
+        self.input_quant_min, self.input_quant_max, input_scale, input_zero_point = (
+            get_quant_params(linear_fake_quant.input_activation_fake_quant)
+        )
+        self.input_scale = input_scale.to(device="cuda")
+        self.input_zero_point = input_zero_point.to(device="cuda")
+
+        (
+            self.output_quant_min,
+            self.output_quant_max,
+            output_scale,
+            output_zero_point,
+        ) = get_quant_params(linear_fake_quant.output_activation_fake_quant)
+        self.output_scale = output_scale.to(device="cuda")
+        self.output_zero_point = output_zero_point.to(device="cuda")
+
+    def forward(self, x):
+        # Manually quantize the input tensor using observed min and max values
+        q_tensor = torch.round(x / self.input_scale + self.input_zero_point)
+        # Clip to ensure within the range [0, 255]
+        q_tensor = torch.clamp(q_tensor, self.input_quant_min, self.input_quant_max)
+        # Dequantize to the original scale
+        dequantized_tensor = (q_tensor - self.input_zero_point) * self.input_scale
+
+        linear_output = self.linear_fake_quant.linear(dequantized_tensor)
+
+        # # Quantize the linear output tensor
+        q_linear_output = torch.round(
+            linear_output / self.output_scale + self.output_zero_point
+        )
+        q_linear_output = torch.clamp(
+            q_linear_output, self.output_quant_min, self.output_quant_max
+        )
+        # Dequantize the linear output tensor
+        dq_linear_output = (
+            q_linear_output - self.output_zero_point
+        ) * self.output_scale
+
+        return dq_linear_output
+
+
+def _replace_linear_q_act(module: torch.nn.Module, stage: str):
+    for name, child in module.named_children():
+        if stage == "convert":
+            if isinstance(child, LinearActFakeQuant):
+                new_linear = LinearActQuant(child)
+                setattr(module, name, new_linear)
+            else:
+                _replace_linear_q_act(child, stage)
+        elif stage == "prepare":
+            if isinstance(child, HQQLinear):
+                new_linear = LinearActFakeQuant(child)
+                setattr(module, name, new_linear)
+            else:
+                _replace_linear_q_act(child, stage)
+
+
+def replace_linear_q_act(module: torch.nn.Module, stage: str):
+    _replace_linear_q_act(
+        module,
+        stage,
+    )
+
+
+def prepare(model):
+    replace_linear_q_act(model, "prepare")
+
+
+def convert(model):
+    replace_linear_q_act(model, "convert")
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
@@ -127,6 +364,65 @@ def quantize(
             group_size,
         )
         model = gptq_quantizer.quantize(model, inputs)
+        return model
+    elif qmode == "16a4w-hqq":
+        print("running 16a4w-hqq")
+        from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
+
+        def _replace_linear_16a4w_hqq(
+            module: torch.nn.Module,
+            quant_config,
+            compute_dtype,
+            del_orig=False,
+        ):
+            for name, child in module.named_children():
+                if isinstance(child, nn.Linear):
+                    new_linear = HQQLinear(
+                        child,
+                        quant_config,
+                        compute_dtype=compute_dtype,
+                        del_orig=True,
+                        device="cpu",
+                    )
+                    setattr(module, name, new_linear)
+                else:
+                    _replace_linear_16a4w_hqq(
+                        child,
+                        quant_config,
+                        compute_dtype,
+                        del_orig=False,
+                    )
+
+        def replace_linear_16a4w_hqq(
+            module: torch.nn.Module,
+            quant_config,
+            compute_dtype,
+            del_orig=False,
+        ):
+            _replace_linear_16a4w_hqq(
+                module,
+                quant_config,
+                compute_dtype,
+                del_orig=False,
+            )
+
+        compute_dtype = torch.float32  # torch.bfloat16 #[torch.float16, torch.bfloat16]
+        quant_config = BaseQuantizeConfig(
+            quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
+        )
+        print("before replace_linear_16a4w_hqq model: ", model)
+        replace_linear_16a4w_hqq(model, quant_config, compute_dtype)
+        print("after replace_linear_16a4w_hqq model: ", model)
+
+        print("model before prepare: ", model)
+        prepare(model)
+        print("model after prepare: ", model)
+
+        # Calibration with wikitext, currently only use 5 samples and can be fine tuned
+        run_wikitext_eval(model, tokenizer_path, 128)
+        print("model after calibrate: ", model)
+        convert(model)
+
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
```
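After calibration, `convert()` swaps each `LinearActFakeQuant` for a `LinearActQuant`, which freezes the observed scale and zero-point and simulates 16-bit activation quantization numerically in its `forward`. A small worked example of that quantize-clamp-dequantize round trip, with made-up scale/zero-point values standing in for calibration results:

```python
import torch

# Made-up qparams standing in for what the calibration observers would report.
scale, zero_point = 0.002, 32768.0
quant_min = torch.iinfo(torch.uint16).min  # 0
quant_max = torch.iinfo(torch.uint16).max  # 65535

x = torch.tensor([-0.05, 0.0, 0.0421])

# Same arithmetic as LinearActQuant.forward, applied to one activation tensor:
q = torch.clamp(torch.round(x / scale + zero_point), quant_min, quant_max)
dq = (q - zero_point) * scale

print(q)   # tensor([32743., 32768., 32789.])
print(dq)  # ~tensor([-0.0500, 0.0000, 0.0420]); the 0.0421 -> 0.0420 gap is the 16-bit rounding error
```

The difference between `x` and `dq` is the activation quantization error the exported model will see; the 4-bit weights are handled separately by hqq in step 1.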
