@@ -82,7 +82,6 @@ def _kv_calibrate(
     _, atten_mask, _, k_caches, v_caches = example_inputs

     # TODO: change criteria & support batch inputs if necessary
-    pos = torch.tensor(0, dtype=torch.int32)
     max_cache_len = max_seq_len - 1

     token_list = []
@@ -114,10 +113,42 @@ def _kv_calibrate(
                 for i, v_cache in enumerate(v_caches)
             ]

-            pos += 1
-            atten_mask[0][-pos - 1] = 0
-            if pos >= len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+
+    for token_list in user_token_list:
+        _, atten_mask, _, k_caches, v_caches = copy.deepcopy(example_inputs)
+        pos = torch.tensor(0, dtype=torch.int32)
+        with torch.no_grad():
+            while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
+                logits, new_k_caches, new_v_caches = module(
+                    torch.full((1, 1), token_list[pos], dtype=torch.int32),
+                    atten_mask,
+                    torch.full((1, 1), pos),
+                    *k_caches,
+                    *v_caches,
+                )
+                k_caches = [
+                    torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                    for i, k_cache in enumerate(k_caches)
+                ]
+                v_caches = [
+                    torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                    for i, v_cache in enumerate(v_caches)
+                ]
+
+                pos += 1
+                atten_mask[0][-pos - 1] = 0
+                if pos >= len(token_list):
+                    token_list.append(torch.argmax(logits[:, -1], dim=-1).item())

     print(f"kv calibration data:\n{tokenizer.decode(token_list)}")
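Note on the hunk above: the hard-coded ID lists appear to be Llama 3 chat-template encodings of the commented prompts (128000 = <|begin_of_text|>, 128006/128007 delimit a role header, 128009 = <|eot_id|>). A minimal sketch of how they could be regenerated instead of hard-coded, assuming sp_model.encode has the signature shown in the commented-out line and that the tokenizer renders the special header tokens; the template string below is an assumption, not part of the patch:

    # Hypothetical regeneration of user_token_list (sketch only).
    prompts = [
        "what is the capital of the united states",
        "what is 1 + 1",
        "what is the meaning of life",
    ]
    template = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "{q}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    # bos=True prepends <|begin_of_text|> (128000), matching the lists above.
    user_token_list = [
        sp_model.encode(template.format(q=q), bos=True, eos=False)
        for q in prompts
    ]

Inside the decode loop, each step slides the caches by dropping the oldest entry and appending the newest (dim=-1 for keys, dim=1 for values), so cache shapes stay fixed across steps, and atten_mask[0][-pos - 1] = 0 unmasks one more position per step. Once pos passes the prompt length, the argmax token is fed back in, so calibration also covers generated tokens.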
@@ -328,7 +359,17 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
             max_seq_len=self.llama_meta["get_max_seq_len"],
         )

-        self.llama_model = convert_pt2e(fx_graph_module)
+        fx_graph_module = convert_pt2e(fx_graph_module)
+
+        logging.info("Evaluating the converted model...")
+        calibrate(
+            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
+            args.prompt,
+            fx_graph_module,
+            tokenizer_model_path=args.tokenizer_model,
+            max_seq_len=self.llama_meta["get_max_seq_len"],
+        )
+        self.llama_model = fx_graph_module

     def lowering_modules(
         self,
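For orientation, a condensed sketch of the PT2E flow this hunk extends. prepare_pt2e/convert_pt2e come from torch.ao.quantization.quantize_pt2e; the quantizer and the first prepare/calibrate steps are assumed to happen earlier in quantize() and are not part of this hunk. The second calibrate pass added here exercises the converted graph and prints its decoded output, rather than collecting statistics:

    # Sketch only; `quantizer` is an assumption from earlier in quantize().
    example_inputs = self.get_example_inputs(self.llama_meta["get_use_kv_cache"])
    calib_args = dict(
        tokenizer_model_path=args.tokenizer_model,
        max_seq_len=self.llama_meta["get_max_seq_len"],
    )

    fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)             # insert observers
    calibrate(example_inputs, args.prompt, fx_graph_module, **calib_args)  # record activation ranges
    fx_graph_module = convert_pt2e(fx_graph_module)                        # fold observers into q/dq ops
    calibrate(example_inputs, args.prompt, fx_graph_module, **calib_args)  # sanity-check converted graph
    self.llama_model = fx_graph_module

Running calibrate again on the converted module makes quantization regressions visible immediately: the "kv calibration data" printout shows what the quantized graph actually generates before lowering.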