
Commit f679349

Authored by namtranase, Trần Đức Nam, Le Hoang Anh, and ggerganov
llama : add AWQ for llama, llama2, mpt, and mistral models (#4593)
* update: awq support llama-7b model
* update: change order
* update: benchmark results for llama2-7b
* update: mistral 7b v1 benchmark
* update: support 4 models
* fix: Readme
* update: ready for PR
* update: readme
* fix: readme
* update: change order import
* black
* format code
* update: work for bot mpt and awqmpt
* update: readme
* Rename to llm_build_ffn_mpt_awq
* Formatted other files
* Fixed params count
* fix: remove code
* update: more detail for mpt
* fix: readme
* fix: readme
* update: change folder architecture
* fix: common.cpp
* fix: readme
* fix: remove ggml_repeat
* update: cicd
* update: cicd
* uppdate: remove use_awq arg
* update: readme
* llama : adapt plamo to new ffn

ggml-ci

---------

Co-authored-by: Trần Đức Nam <[email protected]>
Co-authored-by: Le Hoang Anh <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 879b690 commit f679349

File tree

8 files changed: +443 -5 lines


awq-py/README.md

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
# AWQ: Activation-aware Weight Quantization for LLM - a version adapted for llama.cpp
[[Paper](https://arxiv.org/abs/2306.00978)][[Original Repo](https://github.com/mit-han-lab/llm-awq)][[Easy-to-use Repo](https://github.com/casper-hansen/AutoAWQ)]

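AWQ searches for per-channel scaling factors from activation statistics and folds them into the weights, so the channels that matter most for the activations lose less precision under low-bit quantization. A minimal sketch of the core identity (illustrative only; the real scales come from the pre-computed AWQ search results referenced below):

```python
import torch

# Toy linear layer: y = x @ W.T
x = torch.randn(1, 4)          # activations
w = torch.randn(8, 4)          # weight (out_features x in_features)
s = torch.rand(4) + 0.5        # per-input-channel scales (assumed given by the AWQ search)

# Folding s into the weights and dividing it out of the activations leaves the
# fp16 output unchanged, but shifts magnitude into the weights so the salient
# channels survive quantization better.
y_ref = x @ w.t()
y_awq = (x / s) @ (w * s).t()
assert torch.allclose(y_ref, y_awq, atol=1e-5)
```
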
**Supported models:**

- [x] LLaMA
- [x] LLaMA 2
- [x] MPT
- [x] Mistral AI v0.1
- [ ] Bloom
- [ ] Mixtral MoE

**TODO:**
- [x] Update the version to work with both MPT and MPT-AWQ models
- [ ] Add OPT model
- [ ] Add Bloom model
- [ ] Add Mixtral MoE
- [ ] Support w3, w2

## Contents

- [Install](#install)
- [Convert](#convert)
- [Quantize](#quantize)
- [Test](#test)
- [Benchmark](#benchmark)
- [Results](#results)

## Install
Install the requirements:
```bash
pip install -r requirements.txt
```
Get the pre-computed AWQ search results for multiple model families, including LLaMA, LLaMA 2, MPT, and OPT:
```bash
git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
```

## Convert
Example for the supported models:
```bash
# For llama-7b and llama2 models
python convert.py models/llama-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --outfile models/llama_7b_fp16.gguf
# For mistral and mpt models (use the AWQ cache that matches the model, e.g. mpt-7b-w4-g128.pt)
python convert-hf-to-gguf.py models/mpt-7b/ --awq-path awq_cache/mpt-7b-w4-g128.pt --outfile models/mpt_7b_fp16.gguf
```

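Passing `--awq-path` applies the pre-computed scales and clipping bounds to the Hugging Face checkpoint before the usual GGUF conversion. Roughly, the wiring looks like the sketch below (a simplification; the exact temporary-path handling inside the convert scripts may differ):

```python
# Rough sketch of what --awq-path triggers before conversion (simplified;
# see awq-py/awq/apply_awq.py for the function added by this PR).
from awq.apply_awq import add_scale_weights

model_path = "models/llama-7b"               # original HF checkpoint
awq_path = "awq_cache/llama-7b-w4-g128.pt"   # pre-computed AWQ search results
tmp_path = "models/llama-7b-awq"             # scaled checkpoint is written here

add_scale_weights(model_path, awq_path, tmp_path)
# The converter then reads tmp_path instead of model_path to write the fp16 GGUF.
```
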
## Quantize
```bash
# We only benchmark and confirm the results on q4_0, q4_1, and q2_k types.
./quantize models/llama_7b_fp16.gguf models/llama_7b_q4_0.gguf q4_0
```

## Test
```bash
# For all models.
./build/bin/main -m models/llama_7b_q4_0.gguf -n 128 --prompt "Once upon a time"
```

## Benchmark
The perplexity measurements in the tables below are computed against the `wikitext2` test dataset (https://paperswithcode.com/dataset/wikitext-2), with a context length of 512.
```bash
# For llama, llama2, and mistral models.
./perplexity -m models/llama_7b_q4_0.gguf -f datasets/wikitext-2-raw/wiki.test.raw
```

## Results
Results were obtained with OpenBLAS (CPU) and cuBLAS (GPU) builds for a fair comparison.
We use three llama.cpp quantization types with our version: q4_0, q4_1, and q2_k.

### Llama 7B (Build with OpenBLAS)

| Model        | Measure     |    F16 |   Q4_0 |   Q4_1 |   Q2_K |
|-------------:|-------------|-------:|-------:|-------:|-------:|
| Llama 7B     | perplexity  | 5.9066 | 6.1214 | 6.0643 | 6.5808 |
| Llama 7B     | file size   |  12.9G |   3.5G |   3.9G |   2.7G |
| Llama 7B     | bits/weight |   16.0 |    4.5 |    5.0 |    2.6 |
| AWQ-LLama 7B | perplexity  | 5.9175 | 6.0252 | 5.9987 | 6.3692 |
| AWQ-LLama 7B | file size   |  12.9G |   3.5G |   3.9G |   2.7G |
| AWQ-LLama 7B | bits/weight |   16.0 |    4.5 |    5.0 |    2.6 |


### Llama2 7B (Build with CuBLAS)

| Model         | Measure     |    F16 |   Q4_0 |   Q4_1 |   Q2_K |
|--------------:|-------------|-------:|-------:|-------:|-------:|
| Llama2 7B     | perplexity  | 5.8664 | 6.0260 | 6.0656 | 6.4496 |
| Llama2 7B     | file size   |  12.9G |   3.5G |   3.9G |   2.7G |
| Llama2 7B     | bits/weight |   16.0 |    4.5 |    5.0 |    2.6 |
| AWQ-LLama2 7B | perplexity  | 5.8801 | 6.0054 | 5.9849 | 6.3650 |
| AWQ-LLama2 7B | file size   |  12.9G |   3.5G |   3.9G |   2.7G |
| AWQ-LLama2 7B | bits/weight |   16.0 |    4.5 |    5.0 |    2.6 |


### Mistral 7B v0.1 (Build with CuBLAS)

| Model          | Measure     |    F16 |   Q4_0 |   Q4_1 |   Q2_K |
|---------------:|-------------|-------:|-------:|-------:|-------:|
| Mistral 7B     | perplexity  | 5.6931 | 5.8202 | 5.8268 | 6.1645 |
| Mistral 7B     | file size   |  14.5G |   4.1G |   4.5G |   3.1G |
| Mistral 7B     | bits/weight |   16.0 |    4.5 |    5.0 |    2.6 |
| AWQ-Mistral 7B | perplexity  | 5.6934 | 5.8020 | 5.7691 | 6.0426 |
| AWQ-Mistral 7B | file size   |  14.5G |   4.1G |   4.5G |   3.1G |
| AWQ-Mistral 7B | bits/weight |   16.0 |    4.5 |    5.0 |    2.6 |

### MPT 7B (Build with OpenBLAS)

| Model      | Measure     |    F16 |   Q4_0 |   Q4_1 |    Q2_K |
|-----------:|-------------|-------:|-------:|-------:|--------:|
| MPT 7B     | perplexity  | 8.4369 | 8.7956 | 8.6265 | 11.4913 |
| MPT 7B     | file size   |  13.7G |   3.9G |   4.3G |    2.8G |
| MPT 7B     | bits/weight |   16.0 |    4.5 |    5.0 |     2.6 |
| AWQ-MPT 7B | perplexity  | 8.4944 | 8.7053 | 8.6750 | 10.2873 |
| AWQ-MPT 7B | file size   |  13.7G |   3.9G |   4.3G |    2.8G |
| AWQ-MPT 7B | bits/weight |   16.0 |    4.5 |    5.0 |     2.6 |

awq-py/awq/apply_awq.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
"""
Implements the AWQ for llama.cpp use cases.
Original paper: https://arxiv.org/abs/2306.00978

This code is based on versions of the AWQ implementation found in the following repositories:
* https://github.com/mit-han-lab/llm-awq
* https://github.com/casper-hansen/AutoAWQ
"""

import os
import torch
import torch.nn as nn

from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.activations import GELUActivation

class ScaledActivation(nn.Module):
    """
    ScaledActivation module wraps an existing activation function and applies a
    scale factor to its output.

    Args:
        module (nn.Module): The activation function to be scaled.
        scales (torch.Tensor): A tensor of size (num_features,) containing the initial
            scale factors for each feature.

    Returns:
        torch.Tensor: The scaled output of the activation function.
    """

    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)

def set_op_by_name(layer, name, new_module):
    """
    Set the new module for given module's name.

    Args:
        layer (nn.Module): The layer in which to replace the submodule.
        name (str): The path to the submodule to be replaced, using dot notation
            to access nested modules.
        new_module (nn.Module): The new module to replace the existing one.
    """
    levels = name.split(".")
    if len(levels) > 1:
        mod_ = layer
        for l_idx in range(len(levels) - 1):
            if levels[l_idx].isdigit():
                mod_ = mod_[int(levels[l_idx])]
            else:
                mod_ = getattr(mod_, levels[l_idx])
        setattr(mod_, levels[-1], new_module)
    else:
        setattr(layer, name, new_module)


def get_op_by_name(module, op_name):
    """
    Retrieves a submodule within a given layer based on its name.

    Args:
        module (nn.Module): The layer containing the submodule to find.
        op_name (str): The name of the submodule.

    Returns:
        nn.Module: The requested submodule found within the given layer.

    Raises:
        ValueError: If the specified submodule cannot be found within the layer.
    """
    for name, m in module.named_modules():
        if name == op_name:
            return m
    raise ValueError(f"Cannot find op {op_name} in module {module}")

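# Example (illustrative): for a Hugging Face LLaMA checkpoint,
#   get_op_by_name(model, "model.layers.0.mlp.gate_proj")
# walks named_modules() and returns that nn.Linear, while
#   set_op_by_name(model, "model.layers.0.mlp.act_fn", ScaledActivation(act, scales))
# resolves the dotted path one level at a time (treating "0" as a list index)
# and swaps the activation module in place.
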

@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    """
    Scales the weights of a LayerNorm and a list of fully-connected layers proportionally.

    Args:
        ln (nn.LayerNorm): The LayerNorm module to be scaled.
        fcs (List[nn.Linear]): A list of fully-connected layers to be scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).
    """

    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, "bias") and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0

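# Note (illustrative): dividing the norm weights by the scales and multiplying
# the following linear weights by the same scales leaves the block's output
# numerically unchanged; it only moves magnitude from activations into the
# weights ahead of quantization.
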

@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    """
    Scales the weights of two fully-connected layers in a specific pattern.

    Args:
        fc1 (nn.Linear): The first fully-connected layer to be scaled.
        fc2 (nn.Linear): The second fully-connected layer to be scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).
    """
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)

    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    """
    Scales the weight of a GELU activation and a fully-connected layer proportionally.

    Args:
        gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module to be scaled.
        fc (nn.Linear): The fully-connected layer to be scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).

    Raises:
        TypeError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`.
        TypeError: If the `fc` module is not of type `nn.Linear`.
    """
    assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation))
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0


def apply_scale(module, scales_list, input_feat_dict=None):
    """
    Applies different scaling strategies to layers based on their type and hierarchy within a given module.

    Args:
        module (nn.Module): The module containing the layers to be scaled.
        scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing:
            * prev_op_name (str): The name of the preceding operation or module,
              relative to which the layers to be scaled are located.
            * layer_names (List[str]): A list of names of the layers to be scaled, relative to the preceding operation.
            * scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
        input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding
            input features (optional).
    """
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)) or "rmsnorm" in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales.cpu()

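# Note (illustrative): each scales_list entry pairs a preceding op with the
# layers it feeds plus a 1D scale tensor from the AWQ search results, e.g.
#   ("model.layers.0.post_attention_layernorm",
#    ["model.layers.0.mlp.gate_proj", "model.layers.0.mlp.up_proj"],
#    scales)
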

@torch.no_grad()
def apply_clip(module, clip_list):
    """
    Applies element-wise clipping to the weight of a specific layer within a given module.

    Args:
        module (nn.Module): The module containing the layer to be clipped.
        clip_list (List[Tuple[str, torch.Tensor]]): A list of tuples containing:
            * name (str): The name of the layer to be clipped, relative to the root of the module.
            * max_val (torch.Tensor): A 1D or 2D tensor defining the upper bound for each element of the layer's weight.
    """
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()


def add_scale_weights(model_path, scale_path, tmp_path):
    """
    Adds pre-computed Activation Weight Quantization (AWQ) results to a model,
    including scaling factors and clipping bounds.

    Args:
        model_path (str): Path to the pre-trained model to be equipped with AWQ.
        scale_path (str): Path to the AWQ scale factors (.pt file).
        tmp_path (str): Path to the temporary directory where the equipped model will be saved.
    """
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True
    )
    model.eval()
    awq_results = torch.load(str(scale_path), map_location="cpu")
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])
    model.save_pretrained(str(tmp_path))
    os.system(f"cp {str(model_path)}/tokenizer* {str(tmp_path)}")
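
For reference, the AWQ cache `.pt` file loaded by `add_scale_weights` is a plain dictionary whose `scale` and `clip` entries feed `apply_scale` and `apply_clip` above. A quick way to inspect one (illustrative; the path is a placeholder):

```python
# Inspect an AWQ cache file (illustrative).
import torch

awq_results = torch.load("awq_cache/llama-7b-w4-g128.pt", map_location="cpu")
print(list(awq_results.keys()))  # should include 'scale' and 'clip'
for prev_op_name, layer_names, scales in awq_results["scale"][:2]:
    print(prev_op_name, layer_names, tuple(scales.shape))
for name, max_val in awq_results["clip"][:2]:
    print(name, tuple(max_val.shape))
```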

awq-py/requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
torch>=2.0.0
transformers>=4.32.0

0 commit comments
