
Commit 8de8d9b

rohansjoshi authored and facebook-github-bot committed
Added quantization for evaluation script (#11822)
Summary: Added quantization to the evaluation script. Quantization causes a deterioration in accuracy on the wikitext task:

| Model Name | max_seq_len | ptq | word_perplexity |
|------------|-------------|-----|-----------------|
| Llama 3.2-1B Instruct | 128 | 16a4w | 5821003.055178451 |
| Llama 3.2-1B Instruct | 128 | 16a4w_block | 5396240.078572427 |
| Llama 3.2-1B Instruct | 128 | 8a8w | 533154.970440251 |

Reviewed By: cccclai

Differential Revision: D76837572
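For context on the word_perplexity column: lm_eval's wikitext task reports word-level perplexity, roughly exp(total negative log-likelihood / number of reference words), so lower is better and an unquantized Llama-class model typically lands in the low double digits. A minimal sketch (not part of this commit) of why the quantized numbers are so large:

```python
import math

# word_perplexity ~= exp(total NLL in nats / number of reference words)
def word_perplexity(total_nll: float, num_words: int) -> float:
    return math.exp(total_nll / num_words)

# Illustrative numbers only: if quantization pushes the average per-word NLL
# from ~2.5 to ~15, perplexity jumps from ~12 to ~3.3 million -- the same
# order of magnitude as the table above.
print(word_perplexity(2.5 * 1000, 1000))   # ~12.2
print(word_perplexity(15.0 * 1000, 1000))  # ~3.3e6
```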
1 parent 496022e commit 8de8d9b

File tree

2 files changed, +110 -13 lines changed


examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 3 additions & 0 deletions
@@ -49,6 +49,9 @@ python_binary(
     name = "eval_llama_qnn",
     srcs = ["eval_llama_qnn.py"],
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main",
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:model_sharding_py",
+    ],
     deps = [
         ":llama_lib",
         "//executorch/examples/models/llama:eval_library",

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 107 additions & 13 deletions
@@ -4,9 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sys
 import argparse
 import copy
 import json
+import torch
+from functools import partial
+
+from lm_eval.evaluator import simple_evaluate
 
 from typing import List, Optional, Tuple

@@ -26,32 +31,53 @@
 
 from pytorch_tokenizers import get_tokenizer
 
+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
+from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_linear_16a8w_in_affine_layer,
+    annotate_matmul_16a8w,
+)
+
+
+import logging
+sys.setrecursionlimit(4096)
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logging.getLogger().setLevel(logging.INFO)
+
 
 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask
 
     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)
 
 
 def gen_eval_wrapper(model_name, args):
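The forward() change above also flips the padding side: torch.nn.functional.pad takes (pad_left, pad_right) for the last dimension, so (0, max_seq_len - T) now appends padding after the prompt instead of before it, presumably keeping the real tokens aligned with the positions the static attention mask covers. A quick standalone illustration (toy token ids, not from the script):

```python
import torch

tokens = torch.tensor([[101, 2023, 2003]])  # shape (1, 3), toy token ids
max_seq_len = 8
pad = max_seq_len - tokens.shape[1]

right_padded = torch.nn.functional.pad(tokens, (0, pad))  # new behavior: pad on the right
left_padded = torch.nn.functional.pad(tokens, (pad, 0))   # old behavior: pad on the left

print(right_padded)  # tensor([[ 101, 2023, 2003,    0,    0,    0,    0,    0]])
print(left_padded)   # tensor([[   0,    0,    0,    0,    0,  101, 2023, 2003]])
```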
@@ -119,14 +145,73 @@ def permute(w, heads):
         layer.feed_forward.prepare_feedfoward_conv()
 
     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)
+
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
+    model = convert_linear_to_conv2d(model)
 
-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = (annotate_matmul_16a8w,)
+        if args.llama_model == "stories110m":
+            custom_annotations = custom_annotations + (
+                annotate_linear_16a8w_in_affine_layer,
+            )
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(
+                model, inputs, strict=True
+            ).module()
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [
+                    n for n in model.graph.nodes if "conv" in n.name
+                ]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
+            model = prepare_pt2e(model, quantizer)
+
+        logging.info("Quantizing the model...")
+
+        calibrate(
+            inputs,
+            'Once upon a time',
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )
+
+        model = convert_pt2e(model)
+
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )
 
     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
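Condensed, the PTQ path added above follows the standard PT2E flow: export the model, attach a QNN quantizer, prepare (insert observers), calibrate on sample text, then convert. A hedged sketch using the same helpers the script imports; `model`, `inputs`, and `run_calibration` are placeholders, and the exact annotations and block-size handling in the real script are richer than shown here:

```python
import torch
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import make_quantizer


def quantize_for_eval(model, inputs, run_calibration, ptq="8a8w"):
    # Build the QNN quantizer for the requested scheme, e.g. use_8a8w,
    # use_16a4w, or use_16a4w_block (the block variant additionally needs a
    # block_size_map for conv weights, as in the diff above).
    quantizer = make_quantizer(
        quant_dtype=getattr(QuantDtype, f"use_{ptq}"),
        per_channel_conv=True,
        per_channel_linear=True,
        act_observer=MinMaxObserver,
    )

    with torch.no_grad():
        # Capture a static graph; prepare_pt2e inserts observers at the
        # annotated activations and weights.
        graph = torch.export.export(model, inputs, strict=True).module()
        prepared = prepare_pt2e(graph, quantizer)

    run_calibration(prepared)      # run representative prompts to collect ranges
    return convert_pt2e(prepared)  # fold observed ranges into quant/dequant ops
```

In the script, the calibration step corresponds to the calibrate(...) helper with the 'Once upon a time' prompt, and the converted module is what WrappedLlamaModel and GraphModuleEvalWrapper then evaluate.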
@@ -167,6 +252,7 @@ def main() -> None:
     modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
+    args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True

@@ -177,7 +263,15 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
+    # To do fewer samples for faster evaluation
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.set_default_device(args.device)
+
+    args.ptq = '8a8w'
+
 
     eval_llama(modelname, args)
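On the args.limit = 0.1 override: in lm_eval, a fractional limit is treated as a share of each task's examples rather than an absolute count, which speeds up the run at the cost of a noisier perplexity estimate. A hedged, standalone example of the same knob through lm_eval's public API (the tiny HF checkpoint is just a placeholder, not what this script evaluates):

```python
from lm_eval.evaluator import simple_evaluate

results = simple_evaluate(
    model="hf",                                   # built-in Hugging Face adapter
    model_args="pretrained=sshleifer/tiny-gpt2",  # placeholder checkpoint
    tasks=["wikitext"],
    limit=0.1,                                    # evaluate on 10% of the wikitext samples
)
print(results["results"]["wikitext"])             # metrics include word_perplexity
```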
