@@ -177,14 +177,15 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
     )


-def calibrate(
+def _kv_calibrate(
     example_inputs,
     user_prompts,
     module: torch.fx.GraphModule,
     tokenizer_model_path="tokenizer.model",
+    max_seq_len=512,
 ):
     sp_model = SentencePieceProcessor(model_file=tokenizer_model_path)
-    _, _, atten_mask, k_caches, v_caches = example_inputs
+    _, atten_mask, _, k_caches, v_caches = example_inputs

     # TODO: change criteria & support batch inputs if necessary
     pos = torch.tensor(0, dtype=torch.int32)
@@ -202,11 +203,11 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
         return probs_indices.gather(dim=-1, index=next_token)

     with torch.no_grad():
-        while token_list[-1] != sp_model.eos_id() and pos < 128:
+        while token_list[-1] != sp_model.eos_id() and pos < max_seq_len - 1:
             logits, new_k_caches, new_v_caches = module(
                 torch.full((1, 1), token_list[pos]),
-                torch.full((1, 1), pos),
                 atten_mask,
+                torch.full((1, 1), pos),
                 *k_caches,
                 *v_caches,
             )
@@ -228,15 +229,84 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
     print(f"calibration data:\n{sp_model.decode(token_list)}")


+def _batch_prefill_calibrate(
+    example_inputs,
+    user_prompts,
+    module: torch.fx.GraphModule,
+    tokenizer_model_path="tokenizer.model",
+    max_seq_len=512,
+):
+    sp_model = SentencePieceProcessor(model_file=tokenizer_model_path)
+    _, atten_mask = example_inputs
+    max_cache_len = max_seq_len - 1
+
+    # TODO: change criteria & support batch inputs if necessary
+    token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+    token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
+    last_prompt_pos = token_list.numel()
+    if last_prompt_pos < max_cache_len:
+        token_list = torch.cat(
+            [
+                token_list,
+                torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
+            ],
+            dim=1,
+        )
+    else:
+        token_list = token_list[:, :max_cache_len]
+
+    with torch.no_grad():
+        logits, new_k_caches, new_v_caches = module(
+            token_list,
+            atten_mask,
+        )
+        predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+
+    print(f"calibration data:\n{sp_model.decode(predict)}")
+
+
+def calibrate(
+    example_inputs,
+    user_prompts,
+    module: torch.fx.GraphModule,
+    tokenizer_model_path="tokenizer.model",
+    max_seq_len=512,
+):
+    if len(example_inputs) == 2:
+        _batch_prefill_calibrate(
+            example_inputs,
+            user_prompts,
+            module,
+            tokenizer_model_path,
+            max_seq_len,
+        )
+    elif len(example_inputs) == 5:
+        _kv_calibrate(
+            example_inputs,
+            user_prompts,
+            module,
+            tokenizer_model_path,
+            max_seq_len,
+        )
+    else:
+        raise RuntimeError("Get wrong inputs")
+
+
 class SingleLlama:
     def __init__(self, llama_model) -> None:
         super().__init__()
         self.llama_model = llama_model
         self.quant_dtype = None
         self.llama_meta = self.llama_model.get_metadata()
         self.has_quant_io = False
-        tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs()
-        self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches)
+        if self.llama_meta["get_use_kv_cache"]:
+            tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs(
+                use_kv_cache=True
+            )
+            self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches)
+        else:
+            tokens, atten_mask = self.get_example_inputs(use_kv_cache=False)
+            self.inputs = (tokens, atten_mask)

     def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
         if not self.has_quant_io:
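Note on the dispatch above: `calibrate` now keys off the arity of `example_inputs`. The batch_prefill graph takes just (tokens, atten_mask), while the kv graph takes (tokens, atten_mask, pos_ids, k_caches, v_caches). Below is a minimal sketch of that rule only; the sizes and layer count are hypothetical placeholders, and in the real script the tuples come from `LlamaModel.get_example_inputs`.

import torch

# Hypothetical sizes; the real values come from the model config.
max_seq_len, head_dim, n_layers = 512, 64, 2

# batch_prefill mode: two top-level inputs -> (tokens, atten_mask)
prefill_inputs = (
    torch.zeros((1, max_seq_len - 1), dtype=torch.int32),  # tokens for the whole prompt
    torch.zeros((1, max_seq_len - 1, max_seq_len - 1)),     # causal attention mask
)

# kv mode: five top-level inputs -> (tokens, atten_mask, pos_ids, k_caches, v_caches)
kv_inputs = (
    torch.zeros((1, 1), dtype=torch.int32),                 # one token per decode step
    torch.zeros((1, 1, max_seq_len - 1)),                   # attention mask over the cache
    torch.zeros((1, 1), dtype=torch.int32),                 # position id
    tuple(torch.zeros((1, head_dim, max_seq_len - 1)) for _ in range(n_layers)),  # k caches
    tuple(torch.zeros((1, max_seq_len - 1, head_dim)) for _ in range(n_layers)),  # v caches
)

for inputs in (prefill_inputs, kv_inputs):
    mode = {2: "batch_prefill", 5: "kv"}.get(len(inputs), "unsupported")
    print(f"{len(inputs)} top-level inputs -> {mode} calibration path")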
@@ -256,11 +326,17 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
                 n.meta[QCOM_QUANTIZED_IO] = kv_type
             elif n.op == "output":
                 for a in n.args[0]:
+                    # single head, kv mode
                     if (
                         a.meta["val"].flatten().size()[0]
                         == self.llama_meta["get_head_dim"]
                     ):
                         a.meta[QCOM_QUANTIZED_IO] = kv_type
+                    # single head, batch_prefill mode
+                    elif a.meta["val"].flatten().size()[0] == self.llama_meta[
+                        "get_head_dim"
+                    ] * (self.llama_meta["get_max_seq_len"] - 1):
+                        a.meta[QCOM_QUANTIZED_IO] = kv_type

     def quantize(self, quant_dtype, custom_annotations=()):
         self.quant_dtype = quant_dtype
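The new `elif` in `_tag_kv_ios` tells the two graph variants apart purely by the flattened size of each KV output: in kv mode a single-head cache update carries `head_dim` values, while in batch_prefill mode it carries one `head_dim` slice per prompt position, i.e. `head_dim * (max_seq_len - 1)`. A quick check of that arithmetic with assumed sizes:

import torch

head_dim, max_seq_len = 64, 512  # assumed values for illustration

kv_mode_out = torch.zeros(1, 1, head_dim)                # one new cache entry per decode step
prefill_out = torch.zeros(1, max_seq_len - 1, head_dim)  # cache for the whole prompt at once

assert kv_mode_out.flatten().size()[0] == head_dim
assert prefill_out.flatten().size()[0] == head_dim * (max_seq_len - 1)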
@@ -281,11 +357,13 @@ def quantize(self, quant_dtype, custom_annotations=()):
         ).module()
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
         print("Quantizing the model...")
+
         calibrate(
-            self.get_example_inputs(),
+            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
             args.prompt,
             fx_graph_module,
             tokenizer_model_path=args.tokenizer_model,
+            max_seq_len=args.seq_len,
         )

         self.llama_model = convert_pt2e(fx_graph_module)
@@ -328,18 +406,29 @@ def lowering_modules(
         with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
             exec_prog_mgr.write_to_file(file)

-    def get_example_inputs(self):
-        return self.llama_model.get_example_inputs()
+    def get_example_inputs(self, use_kv_cache=True):
+        return self.llama_model.get_example_inputs(use_kv_cache)


 def compile(args):
     os.makedirs(args.artifact, exist_ok=True)
     start_ts = time.time()
+
+    if args.model_mode == "kv":
+        use_kv_cache = output_new_cache_only = True
+    elif args.model_mode == "batch_prefill" or args.model_mode == "hybrid":
+        raise NotImplementedError(
+            f"model_mode {args.model_mode} is not implemented yet."
+        )
+    else:
+        raise RuntimeError(f"No such model_mode {args.model_mode}.")
+
     with open(args.params) as f:
         config = ModelArgs(**json.load(f))
     # TODO: support batch inputs if necessary
     config.max_batch_size = 1
-    config.max_seq_len = 1024
+    config.max_seq_len = args.seq_len
+    config.use_kv_cache = use_kv_cache
     state_dict = torch.load(
         args.checkpoint, weights_only=True, map_location="cpu", mmap=True
     )
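One relationship worth keeping in mind when reading `compile`: `args.seq_len` now drives both the model config and the calibration code added earlier, which consistently reserves one position, so the effective cache length is `max_seq_len - 1`. A small sketch of how the pieces line up (argument values are assumed, and only kv mode is actually wired up in this change; the other modes raise before reaching this point):

# Stand-in for the parsed argparse namespace; values are assumptions for illustration.
class _Args:
    seq_len = 512
    model_mode = "kv"

args = _Args()

use_kv_cache = output_new_cache_only = args.model_mode == "kv"

max_seq_len = args.seq_len        # becomes config.max_seq_len
decode_cap = max_seq_len - 1      # upper bound on pos in _kv_calibrate
prefill_cache = max_seq_len - 1   # max_cache_len in _batch_prefill_calibrate

print(use_kv_cache, decode_cap, prefill_cache)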
@@ -348,7 +437,7 @@ def compile(args):

     llama_instance = None
     with torch.device("meta"):
-        llama_instance = LlamaModel(config, output_new_cache_only=True)
+        llama_instance = LlamaModel(config, output_new_cache_only=output_new_cache_only)
     if "model" in state_dict:
         state_dict = state_dict["model"]
     llama_instance.load_state_dict(
@@ -398,6 +487,11 @@ def compile(args):
 def inference(args, pre_gen_pte=""):
     workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"

+    if args.model_mode != "kv":
+        raise NotImplementedError(
+            f"model_mode {args.model_mode} is not implemented yet."
+        )
+
     runner_args = " ".join(
         [
             f"--model_path {pte_filename}.pte",
@@ -550,6 +644,21 @@ def post_process():
         type=str,
     )

+    parser.add_argument(
+        "--num_sharding",
+        type=int,
+        default=0,
+        help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
+    )
+
+    parser.add_argument(
+        "--model_mode",
+        help="Export and inference batch_prefill mode, kv mode or hybrid(TBD) mode",
+        default="kv",
+        choices=["batch_prefill", "kv", "hybrid"],
+        type=str,
+    )
+
     args = parser.parse_args()
     if args.compile_only and args.pre_gen_pte:
         exit("Cannot set both compile_only and pre_gen_pte as true")