Skip to content

Commit ba199d8

Browse files
committed
speculative : avoid grammar_mem
1 parent eef008e commit ba199d8

File tree

1 file changed

+13
-30
lines changed

1 file changed

+13
-30
lines changed

examples/speculative/speculative.cpp

Lines changed: 13 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -116,8 +116,6 @@ int main(int argc, char ** argv) {
116116

117117
grammar_parser::parse_state parsed_grammar;
118118

119-
std::vector<llama_grammar *> grammar_mem(n_draft, NULL);
120-
121119
// if requested - load the grammar, error checking is omitted for brevity
122120
if (!params.grammar.empty()) {
123121
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
@@ -127,7 +125,6 @@ int main(int argc, char ** argv) {
127125
}
128126

129127
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
130-
grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
131128
grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
132129
}
133130

@@ -173,11 +170,6 @@ int main(int argc, char ** argv) {
173170
if (i_dft < (int) drafted.size()) {
174171
LOG("the %dth drafted token (%d, '%s') does not match the sampled target token (%d, '%s') - rejected\n",
175172
i_dft, drafted[i_dft], llama_token_to_piece(ctx_dft, drafted[i_dft]).c_str(), id, token_str.c_str());
176-
177-
if (grammar_mem[i_dft]) {
178-
grammar_dft = llama_grammar_copy(grammar_mem[i_dft]);
179-
LOG("restored draft grammar state %d\n", i_dft);
180-
}
181173
} else {
182174
LOG("out of drafted tokens\n");
183175
}
@@ -188,34 +180,25 @@ int main(int argc, char ** argv) {
188180
drafted.clear();
189181
drafted.push_back(id);
190182

191-
if (grammar_dft != NULL) {
192-
llama_grammar_accept_token(ctx_dft, grammar_dft, id);
193-
}
183+
break;
184+
}
194185

186+
if (n_predict > params.n_predict || has_eos) {
195187
break;
196188
}
197189

198-
for (int i = 0; i < (int) grammar_mem.size(); ++i) {
199-
auto & g = grammar_mem[i];
200-
if (g) {
201-
LOG("freeing grammar state %d\n", i);
202-
llama_grammar_free(g);
203-
g = NULL;
190+
if (grammar_tgt) {
191+
if (grammar_dft) {
192+
llama_grammar_free(grammar_dft);
204193
}
205-
}
194+
grammar_dft = llama_grammar_copy(grammar_tgt);
206195

207-
if (n_predict > params.n_predict || has_eos) {
208-
break;
196+
LOG("copied target grammar to draft grammar\n");
209197
}
210198

211199
// sample n_draft tokens from the draft model using greedy decoding
212200
int n_past_cur = n_past_dft;
213201
for (int i = 0; i < n_draft; ++i) {
214-
// remember the grammar state
215-
if (grammar_dft != NULL) {
216-
grammar_mem[i] = llama_grammar_copy(grammar_dft);
217-
}
218-
219202
float * logits = llama_get_logits(ctx_dft);
220203

221204
candidates.clear();
@@ -238,17 +221,13 @@ int main(int argc, char ** argv) {
238221

239222
// TODO: better logic?
240223
if (cur_p.data[0].p < 2*cur_p.data[1].p) {
241-
LOG("stopping drafting, probability too low: %8.f < 2*%8.f\n", cur_p.data[0].p, cur_p.data[1].p);
224+
LOG("stopping drafting, probability too low: %.3f < 2*%.3f\n", cur_p.data[0].p, cur_p.data[1].p);
242225
break;
243226
}
244227

245228
// drafted token
246229
const llama_token id = cur_p.data[0].id;
247230

248-
if (grammar_dft != NULL) {
249-
llama_grammar_accept_token(ctx_dft, grammar_dft, id);
250-
}
251-
252231
drafted.push_back(id);
253232
++n_drafted;
254233

@@ -260,6 +239,10 @@ int main(int argc, char ** argv) {
260239
// evaluate the drafted token on the draft model
261240
llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
262241
++n_past_cur;
242+
243+
if (grammar_dft != NULL) {
244+
llama_grammar_accept_token(ctx_dft, grammar_dft, id);
245+
}
263246
}
264247

265248
// evaluate the target model on the drafted tokens

0 commit comments

Comments (0)