 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import sys
 import argparse
 import copy
 import json
+import torch
+from functools import partial
+
+from lm_eval.evaluator import simple_evaluate

 from typing import List, Optional, Tuple

⋮

 from pytorch_tokenizers import get_tokenizer

+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
+from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_linear_16a8w_in_affine_layer,
+    annotate_matmul_16a8w,
+)
+
+
+import logging
+sys.setrecursionlimit(4096)
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logging.getLogger().setLevel(logging.INFO)
+
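Note: the patch raises Python's recursion limit before doing anything else; presumably (my inference, not stated in the commit) because exporting and tracing a deep Llama graph can blow past the interpreter's default:

    import sys
    print(sys.getrecursionlimit())   # 1000 by default; the patch bumps it to 4096
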

 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask

     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)


 def gen_eval_wrapper(model_name, args):
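Note: the padding change above is a behavioral fix, not a style edit. `torch.nn.functional.pad` takes the last dimension's padding as `(before, after)`, so the old call put the zeros in front of the prompt tokens; the new `(0, pad_len)` order appends them after the prompt instead. A standalone check of the semantics:

    import torch
    import torch.nn.functional as F

    tokens = torch.tensor([[1, 2, 3]])
    print(F.pad(tokens, (2, 0)))   # tensor([[0, 0, 1, 2, 3]]) -- old: pads before
    print(F.pad(tokens, (0, 2)))   # tensor([[1, 2, 3, 0, 0]]) -- new: pads after
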
@@ -119,14 +145,73 @@ def permute(w, heads):
         layer.feed_forward.prepare_feedfoward_conv()

     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)
+
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
+    model = convert_linear_to_conv2d(model)
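Note: `convert_linear_to_conv2d` swaps every `nn.Linear` for an equivalent 1x1 `nn.Conv2d`; my understanding (not stated in the patch) is that the Qualcomm quantizer and HTP backend handle convolutions better than linears. The equivalence itself is easy to verify in plain PyTorch:

    import torch

    lin = torch.nn.Linear(8, 16)
    conv = torch.nn.Conv2d(8, 16, kernel_size=1)
    conv.weight.data = lin.weight.data.view(16, 8, 1, 1)   # (O, I) -> (O, I, 1, 1)
    conv.bias.data = lin.bias.data

    x = torch.randn(2, 8)
    assert torch.allclose(lin(x), conv(x.view(2, 8, 1, 1)).view(2, 16), atol=1e-6)
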

-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = (annotate_matmul_16a8w,)
+        if args.llama_model == "stories110m":
+            custom_annotations = custom_annotations + (
+                annotate_linear_16a8w_in_affine_layer,
+            )
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
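Note: as I read the `QuantDtype` names (the patch does not spell this out), `8a8w`/`16a8w`/`16a4w` encode activation and weight bit-widths, e.g. `use_16a8w` means 16-bit activations with 8-bit weights. `MinMaxObserver` is the stock pt2e observer that derives scale and zero-point from the running min/max it records during calibration:

    import torch
    from torchao.quantization.pt2e import MinMaxObserver

    obs = MinMaxObserver()
    obs(torch.tensor([-1.0, 0.5, 2.0]))           # records running min/max
    scale, zero_point = obs.calculate_qparams()
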
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(
+                model, inputs, strict=True
+            ).module()
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [
+                    n for n in model.graph.nodes if "conv" in n.name
+                ]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
+            model = prepare_pt2e(model, quantizer)
+
+        logging.info("Quantizing the model...")
+
+        calibrate(
+            inputs,
+            'Once upon a time',
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )
+
+        model = convert_pt2e(model)
+
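Note: the sequence above is the standard pt2e post-training-quantization flow: `torch.export` the model, `prepare_pt2e` to insert observers, run calibration data through it (here via the repo's `calibrate` helper), then `convert_pt2e` to fold the observed ranges into quantize/dequantize ops. For the `16a4w_block` case, a block size of `(1, 64, 1, 1)` against the `(out_ch, in_ch, kH, kW)` conv weight layout should mean one weight scale per 64 input channels, per output channel (my reading of the block-size convention; shapes below are invented for illustration):

    import torch

    w = torch.randn(4096, 4096, 1, 1)   # (O, I, kH, kW) after linear -> 1x1 conv
    block = (1, 64, 1, 1)               # one quant group per 64 input channels
    print([w.shape[d] // block[d] for d in range(4)])   # [4096, 64, 1, 1] blocks
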
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )

     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
@@ -167,6 +252,7 @@ def main() -> None:
     modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
+    args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True

@@ -177,7 +263,15 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length

+    # Use fewer samples for faster evaluation
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.set_default_device(args.device)
+
+    args.ptq = '8a8w'
+
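Note: two of the hard-coded overrides above are easy to misread. `torch.set_default_device` makes later factory calls (`torch.zeros`, `torch.randn`, ...) allocate on that device without an explicit `device=` argument:

    import torch

    torch.set_default_device("cuda")
    print(torch.zeros(2).device)   # cuda:0, no device= argument needed

And lm-eval interprets a float `limit` in (0, 1) as a fraction of each task's examples (an integer is an absolute count), so `args.limit = 0.1` evaluates on roughly 10% of wikitext.
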

     eval_llama(modelname, args)