@@ -73,37 +73,49 @@ def _kv_calibrate(
     max_seq_len=512,
 ):
     sp_model = get_tokenizer(tokenizer_model_path)
-    _, atten_mask, _, k_caches, v_caches = example_inputs
 
     # TODO: change criteria & support batch inputs if necessary
-    pos = torch.tensor(0, dtype=torch.int32)
     max_cache_len = max_seq_len - 1
-    token_list = sp_model.encode(user_prompts, bos=True, eos=False)
 
-    with torch.no_grad():
-        while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
-            logits, new_k_caches, new_v_caches = module(
-                torch.full((1, 1), token_list[pos], dtype=torch.int32),
-                atten_mask,
-                torch.full((1, 1), pos),
-                *k_caches,
-                *v_caches,
-            )
-            k_caches = [
-                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
-                for i, k_cache in enumerate(k_caches)
-            ]
-            v_caches = [
-                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
-                for i, v_cache in enumerate(v_caches)
-            ]
-
-            pos += 1
-            atten_mask[0][-pos - 1] = 0
-            if pos >= len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
-
-    print(f"calibration data:\n{sp_model.decode(token_list)}")
+
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+
+    for token_list in user_token_list:
+        _, atten_mask, _, k_caches, v_caches = copy.deepcopy(example_inputs)
+        pos = torch.tensor(0, dtype=torch.int32)
+        with torch.no_grad():
+            while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
+                logits, new_k_caches, new_v_caches = module(
+                    torch.full((1, 1), token_list[pos], dtype=torch.int32),
+                    atten_mask,
+                    torch.full((1, 1), pos),
+                    *k_caches,
+                    *v_caches,
+                )
+                k_caches = [
+                    torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                    for i, k_cache in enumerate(k_caches)
+                ]
+                v_caches = [
+                    torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                    for i, v_cache in enumerate(v_caches)
+                ]
+
+                pos += 1
+                atten_mask[0][-pos - 1] = 0
+                if pos >= len(token_list):
+                    token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+
+        logging.info(f"calibration data:\n{sp_model.decode(token_list)}")
 
 
 def _prefill_calibrate(
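
Note on the decode loop above: each calibration step shifts the caches by one slot before appending the new entry, with keys laid out time-last and values time-first. Below is a minimal standalone sketch of that update convention; the shapes are illustrative assumptions, not taken from the model definition in this commit.

import torch

# Assumed shapes: keys are (batch, head_dim, cache_len), values are (batch, cache_len, head_dim).
k_cache = torch.zeros(1, 4, 7)
v_cache = torch.zeros(1, 7, 4)
new_k = torch.ones(1, 4, 1)  # one new key column per decoded token
new_v = torch.ones(1, 1, 4)  # one new value row per decoded token

# Drop the oldest slot and append the newest, mirroring the calibration loop.
k_cache = torch.cat([k_cache[:, :, 1:], new_k], dim=-1)
v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)
assert k_cache.shape == (1, 4, 7) and v_cache.shape == (1, 7, 4)
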
@@ -114,32 +126,44 @@ def _prefill_calibrate(
     max_seq_len=512,
 ):
     sp_model = get_tokenizer(tokenizer_model_path)
-    _, atten_mask = example_inputs
     max_cache_len = max_seq_len - 1
 
     # TODO: change criteria & support batch inputs if necessary
-    token_list = sp_model.encode(user_prompts, bos=True, eos=False)
-    token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
-    last_prompt_pos = token_list.numel()
-    if last_prompt_pos < max_cache_len:
-        token_list = torch.cat(
-            [
-                token_list,
-                torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
-            ],
-            dim=1,
-        )
-    else:
-        token_list = token_list[:, :max_cache_len]
-
-    with torch.no_grad():
-        logits, new_k_caches, new_v_caches = module(
-            token_list,
-            atten_mask,
-        )
-    predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+
+    for token_list in user_token_list:
+        _, atten_mask = copy.deepcopy(example_inputs)
+        token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
+        last_prompt_pos = token_list.numel()
+        if last_prompt_pos < max_cache_len:
+            token_list = torch.cat(
+                [
+                    token_list,
+                    torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
+                ],
+                dim=1,
+            )
+        else:
+            token_list = token_list[:, :max_cache_len]
 
-    print(f"calibration data:\n{sp_model.decode(predict)}")
+        with torch.no_grad():
+            logits, new_k_caches, new_v_caches = module(
+                token_list,
+                atten_mask,
+            )
+        predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+
+        logging.info(f"calibration data:\n{sp_model.decode(predict)}")
 
 
 def calibrate(
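
In the prefill path above, the zero-padded prompt is run through the module once and only the logits at index last_prompt_pos - 1 are read, since that row holds the prediction for the token that follows the prompt. A small sketch with assumed shapes:

import torch

vocab_size, max_cache_len = 11, 7
last_prompt_pos = 3  # prompt occupies positions 0..2; the rest is padding

# Assumed output layout: one row of logits per input position.
logits = torch.randn(1, max_cache_len, vocab_size)
next_token = torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()
print(next_token)
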
@@ -249,7 +273,17 @@ def quantize(self, quant_dtype, args, custom_annotations=()):
             max_seq_len=self.llama_meta["get_max_seq_len"],
         )
 
-        self.llama_model = convert_pt2e(fx_graph_module)
+        fx_graph_module = convert_pt2e(fx_graph_module)
+
+        logging.info("Evaluating the converted model...")
+        calibrate(
+            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
+            args.prompt,
+            fx_graph_module,
+            tokenizer_model_path=args.tokenizer_model,
+            max_seq_len=self.llama_meta["get_max_seq_len"],
+        )
+        self.llama_model = fx_graph_module
 
     def lowering_modules(
         self,
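
The hardcoded user_token_list entries replace on-the-fly encoding of user_prompts; judging by the inline comments, each list appears to be a chat-template encoding of one short question (the 128000/128006/128007/128009 IDs look like Llama 3 special tokens, which is an assumption here, not something stated in the commit). A quick sanity check, assuming the same get_tokenizer helper, tokenizer_model_path, and user_token_list already defined in the calibration functions:

sp_model = get_tokenizer(tokenizer_model_path)
for ids in user_token_list:
    # Should read back as the commented prompt wrapped in chat-template markers.
    print(sp_model.decode(ids))
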