66
66
logging .getLogger ().setLevel (logging .INFO )
67
67
68
68
69
def smart_mask_updator(atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches):
    """Write the newly generated KV entries into slot ``pos`` of the
    pre-allocated caches and unmask that position.

    The caches are mutated in place; only the position cursor advances.

    Returns:
        Tuple of (atten_mask, pos, k_caches, v_caches) with ``pos``
        incremented by one.
    """
    # Each key cache is indexed along its last axis by `pos`; the new key
    # carries a singleton trailing axis that is dropped on assignment.
    for k_cache, new_k in zip(k_caches, new_k_caches):
        k_cache[:, :, pos] = new_k[:, :, 0]

    # Value caches are indexed along dim 1 by `pos`.
    for v_cache, new_v in zip(v_caches, new_v_caches):
        v_cache[:, pos, :] = new_v

    # The slot now holds valid data: open it in the attention mask,
    # then advance the write cursor.
    atten_mask[0][pos] = 0
    pos += 1
    return (atten_mask, pos, k_caches, v_caches)
79
+
80
+
81
def shift_pointer_updator(
    atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
):
    """Advance the KV caches one step using the shift-pointer scheme.

    Rather than writing in place, each cache is rebuilt by dropping its
    oldest entry and concatenating the new one at the end, so the most
    recent data always sits at the tail.

    Returns:
        Tuple of (atten_mask, pos, k_caches, v_caches) where the cache
        lists hold freshly concatenated tensors and ``pos`` is
        incremented by one.
    """
    # Shift each key cache left along its last axis and append the new key.
    shifted_k = []
    for old_k, new_k in zip(k_caches, new_k_caches):
        shifted_k.append(torch.cat([old_k[:, :, 1:], new_k], dim=-1))
    k_caches = shifted_k

    # Same shift for the value caches, which are concatenated along dim 1.
    shifted_v = []
    for old_v, new_v in zip(v_caches, new_v_caches):
        shifted_v.append(torch.cat([old_v[:, 1:, :], new_v], dim=1))
    v_caches = shifted_v

    pos += 1
    # Open one more mask slot counting back from the tail, where the
    # valid (most recent) cache entries live.
    atten_mask[0][-pos - 1] = 0
    return (atten_mask, pos, k_caches, v_caches)
96
+
97
+
69
98
def _kv_calibrate (
70
99
example_inputs ,
71
100
user_prompts ,
72
101
module : torch .fx .GraphModule ,
73
102
tokenizer_model_path = "tokenizer.model" ,
74
103
max_seq_len = 512 ,
104
+ updator = smart_mask_updator ,
75
105
):
76
106
sp_model = get_tokenizer (tokenizer_model_path )
77
107
_ , atten_mask , _ , k_caches , v_caches = example_inputs
@@ -92,17 +122,9 @@ def _kv_calibrate(
92
122
* k_caches ,
93
123
* v_caches ,
94
124
)
95
- k_caches = [
96
- torch .cat ([k_cache [:, :, 1 :], new_k_caches [i ]], dim = - 1 )
97
- for i , k_cache in enumerate (k_caches )
98
- ]
99
- v_caches = [
100
- torch .cat ([v_cache [:, 1 :, :], new_v_caches [i ]], dim = 1 )
101
- for i , v_cache in enumerate (v_caches )
102
- ]
103
-
104
- pos += 1
105
- atten_mask [0 ][- pos - 1 ] = 0
125
+ atten_mask , pos , k_caches , v_caches = updator (
126
+ atten_mask , pos , k_caches , v_caches , new_k_caches , new_v_caches
127
+ )
106
128
if pos >= len (token_list ):
107
129
token_list .append (torch .argmax (logits [:, - 1 ], dim = - 1 ).item ())
108
130
@@ -153,6 +175,7 @@ def calibrate(
153
175
module : torch .fx .GraphModule ,
154
176
tokenizer_model_path = "tokenizer.model" ,
155
177
max_seq_len = 512 ,
178
+ kv_updator = smart_mask_updator ,
156
179
):
157
180
if len (example_inputs ) == 2 :
158
181
_prefill_calibrate (
@@ -169,6 +192,7 @@ def calibrate(
169
192
module ,
170
193
tokenizer_model_path ,
171
194
max_seq_len ,
195
+ updator = kv_updator ,
172
196
)
173
197
else :
174
198
raise RuntimeError ("Get wrong inputs" )
@@ -298,13 +322,15 @@ def quantize(self, quant_dtype, args, custom_annotations=()):
298
322
self .llama_model , self .inputs , strict = True
299
323
).module ()
300
324
fx_graph_module = prepare_pt2e (fx_graph_module , quantizer )
325
+
301
326
logging .info ("Quantizing the model..." )
302
327
calibrate (
303
328
self .get_example_inputs (self .llama_meta ["get_use_kv_cache" ]),
304
329
args .prompt ,
305
330
fx_graph_module ,
306
331
tokenizer_model_path = args .tokenizer_model ,
307
332
max_seq_len = self .llama_meta ["get_max_seq_len" ],
333
+ kv_updator = args .kv_updator ,
308
334
)
309
335
310
336
self .llama_model = convert_pt2e (fx_graph_module )
@@ -316,6 +342,7 @@ def lowering_modules(
316
342
use_fp16 = False ,
317
343
soc_model = QcomChipset .SM8650 ,
318
344
num_sharding = 0 ,
345
+ shared_buffer = False ,
319
346
):
320
347
executorch_config = ExecutorchBackendConfig (
321
348
# For shared buffer, user must pass the memory address
@@ -336,7 +363,7 @@ def lowering_modules(
336
363
compiler_specs = generate_qnn_executorch_compiler_spec (
337
364
soc_model = soc_model ,
338
365
backend_options = backend_options ,
339
- shared_buffer = False ,
366
+ shared_buffer = shared_buffer ,
340
367
)
341
368
skip_node_op_set = {"llama.fallback.default" }
342
369
partitioner = QnnPartitioner (
@@ -366,7 +393,7 @@ def lowering_modules(
366
393
if num_sharding > 0 :
367
394
update_spill_fill_size (edge_prog_mgr .exported_program ())
368
395
exec_prog_mgr = edge_prog_mgr .to_executorch (config = executorch_config )
369
- with open (f"{ work_space } /{ pte_filename } .pte" , "wb" ) as file :
396
+ with open (f"{ work_space } /{ self . pte_filename } .pte" , "wb" ) as file :
370
397
exec_prog_mgr .write_to_file (file )
371
398
372
399
def get_example_inputs (self , use_kv_cache = True ):
@@ -491,6 +518,7 @@ def compile(args, pte_filename):
491
518
use_fp16 = use_fp16 ,
492
519
soc_model = get_soc_to_chipset_map ()[args .model ],
493
520
num_sharding = args .num_sharding ,
521
+ shared_buffer = args .shared_buffer ,
494
522
)
495
523
quant_attrs = llama_instance_list [0 ].get_quant_attrs ()
496
524
else :
@@ -525,7 +553,7 @@ def compile(args, pte_filename):
525
553
generate_qnn_executorch_compiler_spec (
526
554
soc_model = get_soc_to_chipset_map ()[args .model ],
527
555
backend_options = backend_options ,
528
- shared_buffer = True ,
556
+ shared_buffer = args . shared_buffer ,
529
557
multiple_graphs = True ,
530
558
graph_name = graph_name ,
531
559
)
@@ -697,6 +725,7 @@ def inference(args, quant_attrs, pte_filename, pre_gen_pte=""):
697
725
f"--system_prompt '{ args .system_prompt } '" ,
698
726
f"--logits_scale { quant_attrs ['scale' ]} " ,
699
727
f"--logits_offset { quant_attrs ['zero_point' ]} " ,
728
+ f"--kv_updator { 'SmartMask' if args .kv_updator == smart_mask_updator else 'ShiftPointer' } " ,
700
729
]
701
730
)
702
731
runner_cmd = " " .join (
@@ -862,6 +891,14 @@ def main():
862
891
type = int ,
863
892
)
864
893
894
+ parser .add_argument (
895
+ "--kv_updator" ,
896
+ help = "Choose how to update kv cache during runtime" ,
897
+ choices = ["smart_mask" , "shift_pointer" ],
898
+ default = "smart_mask" ,
899
+ type = str ,
900
+ )
901
+
865
902
args = parser .parse_args ()
866
903
if args .compile_only and args .pre_gen_pte :
867
904
exit ("Cannot set both compile_only and pre_gen_pte as true" )
@@ -878,6 +915,14 @@ def main():
878
915
else :
879
916
raise RuntimeError (f"No such model_mode { args .model_mode } ." )
880
917
918
+ if args .kv_updator == "smart_mask" :
919
+ args .shared_buffer = True
920
+ args .kv_updator = smart_mask_updator
921
+ elif args .kv_updator == "shift_pointer" :
922
+ args .kv_updator = shift_pointer_updator
923
+ else :
924
+ exit (f"Using an unkown kv update { args .kv_updator } " )
925
+
881
926
if args .pre_gen_pte :
882
927
quant_attrs = json .load (
883
928
open (f"{ args .pre_gen_pte } /{ pte_filename } _quant_attrs.txt" )
0 commit comments