 import copy
 import json

-from typing import List, Optional, Tuple
+import logging
+import sys
+
+from typing import List, Tuple

 import torch
 import torch.nn as nn
+from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_linear_16a8w_in_affine_layer,
+    annotate_matmul_16a8w,
+)
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d

 from executorch.examples.models.llama.eval_llama_lib import (
     build_args_parser,
     GraphModuleEvalWrapper,
 )

+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
 )
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
 from lm_eval.evaluator import simple_evaluate

 from pytorch_tokenizers import get_tokenizer

+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+sys.setrecursionlimit(4096)
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logging.getLogger().setLevel(logging.INFO)
+

 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(
+        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
+    ):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask

     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)


 def gen_eval_wrapper(model_name, args):
@@ -119,14 +143,69 @@ def permute(w, heads):
         layer.feed_forward.prepare_feedfoward_conv()

     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)

-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
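+    # Static example inputs; reused below as sample inputs for torch.export and for calibration.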
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
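+    # Replace nn.Linear modules with equivalent conv2d layers (generally better handled by the QNN backend).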
+    model = convert_linear_to_conv2d(model)
+
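+    # PT2E post-training quantization: export the model, insert observers with prepare_pt2e,
+    # calibrate on a sample prompt, then convert_pt2e folds the observed ranges into
+    # quantize/dequantize ops.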
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = (annotate_matmul_16a8w,)
+        if args.llama_model == "stories110m":
+            custom_annotations = custom_annotations + (
+                annotate_linear_16a8w_in_affine_layer,
+            )
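+        # Per-channel weight quantization for conv/linear; MinMaxObserver tracks activation ranges.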
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(model, inputs, strict=True).module()
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
+            model = prepare_pt2e(model, quantizer)
+
+        logging.info("Quantizing the model...")
+
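+        # Calibration: run the sample prompt through the observed model to collect activation statistics.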
+        calibrate(
+            inputs,
+            "Once upon a time",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )
+
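+        # Replace observers with quantize/dequantize ops based on the calibrated ranges.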
+        model = convert_pt2e(model)
+
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )

     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
@@ -167,6 +246,7 @@ def main() -> None:
     modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
+    args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True

@@ -177,7 +257,14 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length

+    # Use fewer samples for faster evaluation
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.set_default_device(args.device)
+
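+    # "8a8w" selects QuantDtype.use_8a8w: 8-bit activations and 8-bit weights.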
+    args.ptq = "8a8w"

     eval_llama(modelname, args)
