@@ -114,47 +114,34 @@ int main(int argc, char ** argv) {
114
114
struct llama_grammar * grammar_dft = NULL ;
115
115
struct llama_grammar * grammar_tgt = NULL ;
116
116
117
- grammar_parser::parse_state parsed_grammar_dft;
118
- grammar_parser::parse_state parsed_grammar_tgt;
117
+ grammar_parser::parse_state parsed_grammar;
119
118
120
119
std::vector<llama_grammar *> grammar_mem (n_draft, NULL );
121
120
121
+ // if requested - load the grammar, error checking is omitted for brevity
122
122
if (!params.grammar .empty ()) {
123
- // dft
124
- {
125
- parsed_grammar_dft = grammar_parser::parse (params.grammar .c_str ());
126
- // will be empty (default) if there are parse errors
127
- if (parsed_grammar_dft.rules .empty ()) {
128
- return 1 ;
129
- }
130
-
131
- std::vector<const llama_grammar_element *> grammar_rules (parsed_grammar_dft.c_rules ());
132
- grammar_dft = llama_grammar_init (grammar_rules.data (), grammar_rules.size (), parsed_grammar_dft.symbol_ids .at (" root" ));
123
+ parsed_grammar = grammar_parser::parse (params.grammar .c_str ());
124
+ // will be empty (default) if there are parse errors
125
+ if (parsed_grammar.rules .empty ()) {
126
+ return 1 ;
133
127
}
134
128
135
- // tgt
136
- {
137
- parsed_grammar_tgt = grammar_parser::parse (params.grammar .c_str ());
138
- // will be empty (default) if there are parse errors
139
- if (parsed_grammar_tgt.rules .empty ()) {
140
- return 1 ;
141
- }
142
-
143
- std::vector<const llama_grammar_element *> grammar_rules (parsed_grammar_tgt.c_rules ());
144
- grammar_tgt = llama_grammar_init (grammar_rules.data (), grammar_rules.size (), parsed_grammar_tgt.symbol_ids .at (" root" ));
145
- }
129
+ std::vector<const llama_grammar_element *> grammar_rules (parsed_grammar.c_rules ());
130
+ grammar_dft = llama_grammar_init (grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
131
+ grammar_tgt = llama_grammar_init (grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
146
132
}
147
133
148
134
const auto t_dec_start = ggml_time_us ();
149
135
150
136
while (true ) {
151
137
LOG (" drafted: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx_dft, drafted));
152
138
153
- // sample from the drafted tokens if any
154
139
int i_dft = 0 ;
155
140
while (true ) {
141
+ // sample from the target model
156
142
const llama_token id = llama_sample_token (ctx_tgt, NULL , grammar_tgt, params, last_tokens, candidates, i_dft);
157
143
144
+ // remember which tokens were sampled - used for repetition penalties during sampling
158
145
last_tokens.erase (last_tokens.begin ());
159
146
last_tokens.push_back (id);
160
147
@@ -170,8 +157,9 @@ int main(int argc, char ** argv) {
170
157
171
158
++n_predict;
172
159
160
+ // check if the draft matches the target
173
161
if (i_dft < (int ) drafted.size () && id == drafted[i_dft]) {
174
- LOG (" drafted token %d accepted\n " , id );
162
+ LOG (" the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n " , i_dft, id, token_str. c_str () );
175
163
++n_accept;
176
164
++n_past_tgt;
177
165
++n_past_dft;
@@ -180,25 +168,20 @@ int main(int argc, char ** argv) {
180
168
continue ;
181
169
}
182
170
171
+ // the drafted token was rejected or we are out of drafted tokens
172
+
183
173
if (i_dft < (int ) drafted.size ()) {
184
- LOG (" drafted token %d rejected\n " , id);
174
+ LOG (" the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n " ,
175
+ i_dft, drafted[i_dft], llama_token_to_piece (ctx_dft, drafted[i_dft]).c_str (), id, token_str.c_str ());
185
176
186
177
if (grammar_mem[i_dft]) {
187
178
grammar_dft = llama_grammar_copy (grammar_mem[i_dft]);
188
- LOG (" restored grammar %d\n " , i_dft);
179
+ LOG (" restored draft grammar state %d\n " , i_dft);
189
180
}
181
+ } else {
182
+ LOG (" out of drafted tokens\n " );
190
183
}
191
184
192
- for (auto & g : grammar_mem) {
193
- if (g) {
194
- llama_grammar_free (g);
195
- g = NULL ;
196
- }
197
- }
198
-
199
- LOG (" i_dft = %d, drafted.size() = %d\n " , i_dft, (int ) drafted.size ());
200
-
201
- // the drafted token was rejected or we are out of drafted tokens
202
185
llama_eval (ctx_dft, &id, 1 , n_past_dft, params.n_threads );
203
186
++n_past_dft;
204
187
@@ -212,11 +195,20 @@ int main(int argc, char ** argv) {
212
195
break ;
213
196
}
214
197
198
+ for (int i = 0 ; i < (int ) grammar_mem.size (); ++i) {
199
+ auto & g = grammar_mem[i];
200
+ if (g) {
201
+ LOG (" freeing grammar state %d\n " , i);
202
+ llama_grammar_free (g);
203
+ g = NULL ;
204
+ }
205
+ }
206
+
215
207
if (n_predict > params.n_predict || has_eos) {
216
208
break ;
217
209
}
218
210
219
- // sample n_draft tokens from the draft model picking the best token
211
+ // sample n_draft tokens from the draft model using greedy decoding
220
212
int n_past_cur = n_past_dft;
221
213
for (int i = 0 ; i < n_draft; ++i) {
222
214
// remember the grammar state
@@ -244,11 +236,13 @@ int main(int argc, char ** argv) {
244
236
LOG (" - draft candidate %3d: %6d (%8.3f) '%s'\n " , i, cur_p.data [i].id , cur_p.data [i].p , llama_token_to_piece (ctx_dft, cur_p.data [i].id ).c_str ());
245
237
}
246
238
247
- // too low probability, stop drafting
239
+ // TODO: better logic?
248
240
if (cur_p.data [0 ].p < 2 *cur_p.data [1 ].p ) {
241
+ LOG (" stopping drafting, probability too low: %8.f < 2*%8.f\n " , cur_p.data [0 ].p , cur_p.data [1 ].p );
249
242
break ;
250
243
}
251
244
245
+ // drafted token
252
246
const llama_token id = cur_p.data [0 ].id ;
253
247
254
248
if (grammar_dft != NULL ) {
@@ -258,17 +252,21 @@ int main(int argc, char ** argv) {
258
252
drafted.push_back (id);
259
253
++n_drafted;
260
254
261
- if (i < n_draft - 1 ) {
262
- // evaluate the drafted token on the draft model
263
- llama_eval (ctx_dft, &drafted.back (), 1 , n_past_cur, params.n_threads );
264
- ++n_past_cur;
255
+ // no need to evaluate the last drafted token, since we won't use the result
256
+ if (i == n_draft - 1 ) {
257
+ break ;
265
258
}
259
+
260
+ // evaluate the drafted token on the draft model
261
+ llama_eval (ctx_dft, &drafted.back (), 1 , n_past_cur, params.n_threads );
262
+ ++n_past_cur;
266
263
}
267
264
268
265
// evaluate the target model on the drafted tokens
269
266
llama_eval (ctx_tgt, drafted.data (), drafted.size (), n_past_tgt, params.n_threads );
270
267
++n_past_tgt;
271
268
269
+ // the first token is always proposed by the traget model before the speculation loop
272
270
drafted.erase (drafted.begin ());
273
271
}
274
272
0 commit comments