@@ -116,8 +116,6 @@ int main(int argc, char ** argv) {
 
     grammar_parser::parse_state parsed_grammar;
 
-    std::vector<llama_grammar *> grammar_mem(n_draft, NULL);
-
     // if requested - load the grammar, error checking is omitted for brevity
     if (!params.grammar.empty()) {
         parsed_grammar = grammar_parser::parse(params.grammar.c_str());
@@ -127,7 +125,6 @@ int main(int argc, char ** argv) {
         }
 
         std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-        grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
         grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     }
 
@@ -173,11 +170,6 @@ int main(int argc, char ** argv) {
173
170
if (i_dft < (int ) drafted.size ()) {
174
171
LOG (" the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n " ,
175
172
i_dft, drafted[i_dft], llama_token_to_piece (ctx_dft, drafted[i_dft]).c_str (), id, token_str.c_str ());
176
-
177
- if (grammar_mem[i_dft]) {
178
- grammar_dft = llama_grammar_copy (grammar_mem[i_dft]);
179
- LOG (" restored draft grammar state %d\n " , i_dft);
180
- }
181
173
} else {
182
174
LOG (" out of drafted tokens\n " );
183
175
}
@@ -188,34 +180,25 @@ int main(int argc, char ** argv) {
188
180
drafted.clear ();
189
181
drafted.push_back (id);
190
182
191
- if (grammar_dft != NULL ) {
192
- llama_grammar_accept_token (ctx_dft, grammar_dft, id);
193
- }
183
+ break ;
184
+ }
194
185
186
+ if (n_predict > params.n_predict || has_eos) {
195
187
break ;
196
188
}
197
189
198
- for (int i = 0 ; i < (int ) grammar_mem.size (); ++i) {
199
- auto & g = grammar_mem[i];
200
- if (g) {
201
- LOG (" freeing grammar state %d\n " , i);
202
- llama_grammar_free (g);
203
- g = NULL ;
190
+ if (grammar_tgt) {
191
+ if (grammar_dft) {
192
+ llama_grammar_free (grammar_dft);
204
193
}
205
- }
194
+ grammar_dft = llama_grammar_copy (grammar_tgt);
206
195
207
- if (n_predict > params.n_predict || has_eos) {
208
- break ;
196
+ LOG (" copied target grammar to draft grammar\n " );
209
197
}
210
198
211
199
// sample n_draft tokens from the draft model using greedy decoding
212
200
int n_past_cur = n_past_dft;
213
201
for (int i = 0 ; i < n_draft; ++i) {
214
- // remember the grammar state
215
- if (grammar_dft != NULL ) {
216
- grammar_mem[i] = llama_grammar_copy (grammar_dft);
217
- }
218
-
219
202
float * logits = llama_get_logits (ctx_dft);
220
203
221
204
candidates.clear ();
@@ -238,17 +221,13 @@ int main(int argc, char ** argv) {
 
             // TODO: better logic?
             if (cur_p.data[0].p < 2*cur_p.data[1].p) {
-                LOG("stopping drafting, probability too low: %8.f < 2*%8.f\n", cur_p.data[0].p, cur_p.data[1].p);
+                LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
                 break;
             }
 
             // drafted token
             const llama_token id = cur_p.data[0].id;
 
-            if (grammar_dft != NULL) {
-                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
-            }
-
             drafted.push_back(id);
             ++n_drafted;
 
@@ -260,6 +239,10 @@ int main(int argc, char ** argv) {
             // evaluate the drafted token on the draft model
             llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
             ++n_past_cur;
+
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
+            }
         }
 
         // evaluate the target model on the drafted tokens