@@ -6,6 +6,7 @@
 
 #include "common.h"
 #include "llama.h"
+#include "grammar-parser.h"
 
 #include <cmath>
 #include <cstdio>
@@ -109,6 +110,41 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;
 
+    // grammar stuff
+    struct llama_grammar * grammar_dft = NULL;
+    struct llama_grammar * grammar_tgt = NULL;
+
+    grammar_parser::parse_state parsed_grammar_dft;
+    grammar_parser::parse_state parsed_grammar_tgt;
+
+    std::vector<llama_grammar *> grammar_mem(n_draft, NULL);
+
+    if (!params.grammar.empty()) {
+        // dft
+        {
+            parsed_grammar_dft = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar_dft.rules.empty()) {
+                return 1;
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar_dft.c_rules());
+            grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_dft.symbol_ids.at("root"));
+        }
+
+        // tgt
+        {
+            parsed_grammar_tgt = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar_tgt.rules.empty()) {
+                return 1;
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar_tgt.c_rules());
+            grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_tgt.symbol_ids.at("root"));
+        }
+    }
+
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
@@ -117,7 +153,7 @@ int main(int argc, char ** argv) {
         // sample from the drafted tokens if any
         int i_dft = 0;
         while (true) {
-            const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
+            const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
 
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
@@ -144,13 +180,35 @@ int main(int argc, char ** argv) {
                 continue;
            }
 
+            if (i_dft < (int) drafted.size()) {
+                LOG("drafted token %d rejected\n", id);
+
+                if (grammar_mem[i_dft]) {
+                    grammar_dft = llama_grammar_copy(grammar_mem[i_dft]);
+                    LOG("restored grammar %d\n", i_dft);
+                }
+            }
+
+            for (auto & g : grammar_mem) {
+                if (g) {
+                    llama_grammar_free(g);
+                    g = NULL;
+                }
+            }
+
+            LOG("i_dft = %d, drafted.size() = %d\n", i_dft, (int) drafted.size());
+
             // the drafted token was rejected or we are out of drafted tokens
             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
             ++n_past_dft;
 
             drafted.clear();
             drafted.push_back(id);
 
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
+            }
+
             break;
         }
 
@@ -161,6 +219,11 @@ int main(int argc, char ** argv) {
         // sample n_draft tokens from the draft model picking the best token
         int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
+            // remember the grammar state
+            if (grammar_dft != NULL) {
+                grammar_mem[i] = llama_grammar_copy(grammar_dft);
+            }
+
             float * logits = llama_get_logits(ctx_dft);
 
             candidates.clear();
@@ -170,6 +233,10 @@ int main(int argc, char ** argv) {
 
             llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
 
+            if (grammar_dft != NULL) {
+                llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
+            }
+
             // computes softmax and sorts the candidates
             llama_sample_softmax(ctx_dft, &cur_p);
 
@@ -182,7 +249,13 @@ int main(int argc, char ** argv) {
                 break;
             }
 
-            drafted.push_back(cur_p.data[0].id);
+            const llama_token id = cur_p.data[0].id;
+
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
+            }
+
+            drafted.push_back(id);
             ++n_drafted;
 
             if (i < n_draft - 1) {
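
Taken together, the drafting-loop hunks above add a checkpoint/constrain/advance cycle for the draft grammar, and the acceptance-loop hunk rolls back to a checkpoint when the target model rejects a drafted token. The condensed sketch below is not part of the commit; it only restates that flow in one place, using the variables and grammar calls introduced in this diff (grammar_dft, grammar_mem, cur_p, drafted, i_dft):

    // sketch only: inside the draft loop, for drafted token i
    if (grammar_dft != NULL) {
        grammar_mem[i] = llama_grammar_copy(grammar_dft);    // checkpoint the state before drafting token i
        llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);  // mask candidates the grammar forbids
    }

    const llama_token id = cur_p.data[0].id;

    if (grammar_dft != NULL) {
        llama_grammar_accept_token(ctx_dft, grammar_dft, id); // advance the grammar past the drafted token
    }
    drafted.push_back(id);

    // sketch only: in the acceptance loop, when the target rejects drafted token i_dft,
    // roll the draft grammar back to the state saved before that token was drafted
    if (grammar_mem[i_dft]) {
        grammar_dft = llama_grammar_copy(grammar_mem[i_dft]);
    }
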
@@ -226,6 +299,10 @@ int main(int argc, char ** argv) {
     llama_free(ctx_dft);
     llama_free_model(model_dft);
 
+    if (grammar_dft != NULL) {
+        llama_grammar_free(grammar_dft);
+        llama_grammar_free(grammar_tgt);
+    }
     llama_backend_free();
 
     fprintf(stderr, "\n\n");
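
With this change the speculative example enforces the same grammar on both models: grammar_tgt constrains the tokens the target actually accepts, while grammar_dft keeps the draft model from proposing tokens the target would have to reject anyway. The grammar text arrives through params.grammar, which the shared command-line handling behind common.h already fills in, so the example needs no extra plumbing. As a purely illustrative input (not taken from this commit), a minimal GBNF grammar with a root rule such as

    root ::= "yes" | "no"

would restrict both the drafted and the accepted continuations to one of those two answers.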