
Commit 713bc40

add 16a4w_hqq quant mode
Pull Request resolved: #3752

Prerequisite: install hqq following https://github.com/mobiusml/hqq

Step 1: use hqq to quantize the weights to 4 bits.
Step 2: use static quantization to quantize the activations to 16 bits.

Graph calibration is currently too slow, so the quant observers are added to the eager model for faster iteration.

Command:
```
python -m examples.models.llama2.eval_llama -t /data/users/chenlai/models/llama2/tokenizer.model -p /data/users/chenlai/models/llama2/params.json -c /data/users/chenlai/models/llama2/consolidated.00.pth --max_seq_len 129 -qmode 16a4w-hqq --limit 5 2>&1 | tee hqq_16a4w.log
```

ghstack-source-id: 228051126
Differential Revision: [D57849772](https://our.internmc.facebook.com/intern/diff/D57849772/)
1 parent 1f9a1c0 commit 713bc40
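
The commit message describes a two-step flow. For reference, a minimal sketch of how the new helpers are driven end to end, assuming an eager Llama `model` and its `tokenizer` are already loaded (the helper names come from the hqq_16a4w module added below):

```python
from executorch.examples.models.llama2.source_transformation import hqq_16a4w

# Step 1: replace every nn.Linear with an HQQLinear holding 4-bit weights.
hqq_16a4w.run_hqq_quantize(model)

# Step 2: 16-bit static activation quantization, done on the eager model.
hqq_16a4w.prepare(model)  # wrap each HQQLinear with input/output fake-quant observers
hqq_16a4w.calibrate(
    model=model,
    tokenizer=tokenizer,
    calibration_tasks=["wikitext"],
    calibration_limit=5,
    calibration_seq_length=128,
)
hqq_16a4w.convert(model)  # freeze the observed qparams into manual quant/dequant math
```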

File tree

3 files changed: +245 -1 lines changed


examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantization_mode",
         type=str,
         default=None,
-        choices=["int8", "8da4w", "8da4w-gptq"],
+        choices=["int8", "8da4w", "8da4w-gptq", "16a4w-hqq"],
         help="type of quantization",
     )

examples/models/llama2/source_transformation/hqq_16a4w.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

########################## Run HQQ ###############################


def _replace_linear_4w_hqq(
    module: torch.nn.Module,
    quant_config,
    compute_dtype,
    del_orig=False,
):
    """
    Recursively replace all Linear layers with HQQLinear modules holding the 4-bit quantized weights.
    """
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            new_linear = HQQLinear(
                child,
                quant_config,
                compute_dtype=compute_dtype,
                del_orig=True,
                device="cpu",
            )
            setattr(module, name, new_linear)
        else:
            _replace_linear_4w_hqq(
                child,
                quant_config,
                compute_dtype,
                del_orig=False,
            )


def replace_linear_4w_hqq(
    module: torch.nn.Module,
    quant_config: BaseQuantizeConfig,
    compute_dtype,
    del_orig=False,
):
    """
    Replace all Linear layers with HQQLinear modules holding the 4-bit quantized weights.
    """
    _replace_linear_4w_hqq(
        module,
        quant_config,
        compute_dtype,
        del_orig=False,
    )


def run_hqq_quantize(model: torch.nn.Module) -> None:
    """
    Update the model in place with the HQQ quantized weights.
    """

    quant_config = BaseQuantizeConfig(
        quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
    )

    replace_linear_4w_hqq(model, quant_config=quant_config, compute_dtype=torch.float32)


########################## Use static quantization with HQQ Linear ###############################


def calibrate(
    model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length
):
    print("run calibration...")
    eval_wrapper = EagerEvalWrapper(
        model=model,
        tokenizer=tokenizer,
        max_seq_length=calibration_seq_length,
        use_kv_cache=False,
    )
    eval_results = evaluate_model(
        eval_wrapper,
        tasks=calibration_tasks,
        limit=calibration_limit,
    )
    for task, res in eval_results["results"].items():
        print(f"Reference result with hqq model: {task}: {res}")


class LinearActivationFakeQuant(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        # 16-bit (uint16 range) observers for the input and output activations
        self.input_activation_fake_quant = torch.quantization.FakeQuantize(
            observer=torch.quantization.MovingAverageMinMaxObserver,
            dtype=torch.int32,
            quant_min=torch.iinfo(torch.uint16).min,
            quant_max=torch.iinfo(torch.uint16).max,
        )
        self.output_activation_fake_quant = torch.quantization.FakeQuantize(
            observer=torch.quantization.MovingAverageMinMaxObserver,
            dtype=torch.int32,
            quant_min=torch.iinfo(torch.uint16).min,
            quant_max=torch.iinfo(torch.uint16).max,
        )

    def forward(self, x):
        x = self.input_activation_fake_quant(x)
        return self.output_activation_fake_quant(self.linear(x))


def get_quant_params(activation_fake_quant):
    # Extract (quant_min, quant_max, scale, zero_point) from a calibrated FakeQuantize module
    quant_min = activation_fake_quant.quant_min
    quant_max = activation_fake_quant.quant_max
    qparams = activation_fake_quant.calculate_qparams()
    scale = qparams[0]
    zero_point = qparams[1]
    return (quant_min, quant_max, scale, zero_point)


class LinearActivationQuant(torch.nn.Module):

    def __init__(self, linear_fake_quant):
        super().__init__()
        self.linear_fake_quant = linear_fake_quant
        (
            self.input_quant_min,
            self.input_quant_max,
            self.input_scale,
            self.input_zero_point,
        ) = get_quant_params(linear_fake_quant.input_activation_fake_quant)

        (
            self.output_quant_min,
            self.output_quant_max,
            self.output_scale,
            self.output_zero_point,
        ) = get_quant_params(linear_fake_quant.output_activation_fake_quant)

    def forward(self, x):
        # Manually quantize the input tensor using the observed min and max values
        q_tensor = torch.round(x / self.input_scale + self.input_zero_point)
        # Clip to ensure the values stay within [quant_min, quant_max]
        q_tensor = torch.clamp(q_tensor, self.input_quant_min, self.input_quant_max)
        # Dequantize back to the original scale
        dequantized_tensor = (q_tensor - self.input_zero_point) * self.input_scale

        linear_output = self.linear_fake_quant.linear(dequantized_tensor)

        # Quantize the linear output tensor
        q_linear_output = torch.round(
            linear_output / self.output_scale + self.output_zero_point
        )
        q_linear_output = torch.clamp(
            q_linear_output, self.output_quant_min, self.output_quant_max
        )
        # Dequantize the linear output tensor
        dq_linear_output = (
            q_linear_output - self.output_zero_point
        ) * self.output_scale

        return dq_linear_output


def _replace_linear_quant_activation(module: torch.nn.Module, stage: str):
    for name, child in module.named_children():
        if stage == "convert":
            if isinstance(child, LinearActivationFakeQuant):
                new_linear = LinearActivationQuant(child)
                setattr(module, name, new_linear)
            else:
                _replace_linear_quant_activation(child, stage)
        elif stage == "prepare":
            if isinstance(child, HQQLinear):
                new_linear = LinearActivationFakeQuant(child)
                setattr(module, name, new_linear)
            else:
                _replace_linear_quant_activation(child, stage)
        else:
            raise ValueError(f"Unsupported stage {stage}")


def replace_linear_quant_activation(module: torch.nn.Module, stage: str):
    _replace_linear_quant_activation(
        module,
        stage,
    )


def prepare(model):
    """
    Prepare the model for quantization by manually inserting the observers.
    """
    replace_linear_quant_activation(model, "prepare")


def convert(model):
    """
    Convert the observers to the actual quant/dequant nodes. In this implementation we manually
    call add, mul, and clamp for quick prototyping.
    """
    replace_linear_quant_activation(model, "convert")
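
To make the manual activation quant/dequant in LinearActivationQuant.forward concrete, here is a small standalone sketch with made-up scale/zero_point values (in the module they come from the calibrated FakeQuantize observers via calculate_qparams()):

```python
import torch

scale, zero_point = 1e-3, 32768  # illustrative values only
quant_min = torch.iinfo(torch.uint16).min  # 0
quant_max = torch.iinfo(torch.uint16).max  # 65535

x = torch.randn(4)

# Quantize: project onto the integer grid, then clamp to the uint16 range.
q = torch.clamp(torch.round(x / scale + zero_point), quant_min, quant_max)

# Dequantize: map back to float; the difference from x is the quantization error.
dq = (q - zero_point) * scale
print(x, dq, (x - dq).abs().max())
```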

examples/models/llama2/source_transformation/quantize.py

Lines changed: 39 additions & 0 deletions
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from executorch.examples.models.llama2.tokenizer.tokenizer import Tokenizer

 from sentencepiece import SentencePieceProcessor

@@ -127,6 +128,44 @@ def quantize(
             group_size,
         )
         model = gptq_quantizer.quantize(model, inputs)
+        return model
+    elif qmode == "16a4w-hqq":
+        try:
+            from executorch.examples.models.llama2.source_transformation import (
+                hqq_16a4w,
+            )
+        except ImportError:
+            print(
+                "Please follow the instructions at https://github.com/mobiusml/hqq to install the latest version."
+            )
+        if calibration_tasks is None:
+            calibration_tasks = ["wikitext"]
+        if calibration_limit is None:
+            calibration_limit = 5
+        if calibration_seq_length is None:
+            calibration_seq_length = 128
+        if tokenizer_path is None:
+            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = Tokenizer(model_path=str(tokenizer_path))  # pyre-ignore[28]
+
+        # Step 1: run HQQ quantization; the Linear layers inside the model are replaced with HQQLinear
+        hqq_16a4w.run_hqq_quantize(model)
+
+        # Step 2: static activation quantization on top of the HQQ-quantized weights
+        # Insert the observers
+        hqq_16a4w.prepare(model)
+        # Calibration
+        hqq_16a4w.calibrate(
+            model=model,
+            tokenizer=tokenizer,
+            calibration_tasks=calibration_tasks,
+            calibration_limit=calibration_limit,
+            calibration_seq_length=calibration_seq_length,
+        )
+        # Convert the observers to the fake-quantized modules
+        hqq_16a4w.convert(model)
+
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
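
The 16a4w-hqq branch above drives Step 1 through run_hqq_quantize. As a quick sanity check of that step in isolation, a toy sketch (not part of the commit, and assuming hqq is installed with its default 4-bit config) that quantizes a single Linear the same way and compares outputs:

```python
import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

linear = torch.nn.Linear(128, 128)
x = torch.randn(2, 128)
ref = linear(x)  # float reference, computed before the weights are replaced

# Same config and constructor arguments as run_hqq_quantize / _replace_linear_4w_hqq.
quant_config = BaseQuantizeConfig(
    quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)
hqq_linear = HQQLinear(
    linear, quant_config, compute_dtype=torch.float32, del_orig=True, device="cpu"
)

# The 4-bit weight quantization error shows up as a small difference from the reference.
print((ref - hqq_linear(x)).abs().max())
```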
