@@ -201,6 +201,107 @@ ws ::= [ \t\n\r]?)""";
201
201
llama_grammar_free (grammar);
202
202
}
203
203
204
+ static void test_chained_ambiguity () {
205
+ // Test case for a grammar that has chained ambiguity
206
+ const std::string grammar_str = R"""( root ::= [0-9] ("a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? [0-9])*)""" ;
207
+ // const std::string grammar_str = R"""(root ::= [0-9] (("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a")?)?)?)?)?)?)?)?)?)? [0-9])*)""";
208
+
209
+ grammar_parser::parse_state parsed_grammar = grammar_parser::parse (grammar_str.c_str ());
210
+
211
+ // Ensure we parsed correctly
212
+ assert (!parsed_grammar.rules .empty ());
213
+
214
+ // Ensure we have a root node
215
+ assert (!(parsed_grammar.symbol_ids .find (" root" ) == parsed_grammar.symbol_ids .end ()));
216
+
217
+ std::vector<const llama_grammar_element*> grammar_rules (parsed_grammar.c_rules ());
218
+ llama_grammar* grammar = llama_grammar_init (
219
+ grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
220
+
221
+ std::string input = " 1aa2aa3aa4aa5" ;
222
+
223
+ auto decoded = decode_utf8 (input, {});
224
+
225
+ const auto & code_points = decoded.first ;
226
+
227
+ size_t cnt = 0 ;
228
+ for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
229
+ // fprintf(stderr, "Parsing character %zu ('%c'), stack size %zu\n", cnt, input[cnt], grammar->stacks.size());
230
+ ++cnt;
231
+
232
+ auto prev_stacks = grammar->stacks ;
233
+ grammar->stacks = llama_grammar_accept (grammar->rules , grammar->stacks , *it);
234
+ if (grammar->stacks .empty ()) {
235
+ fprintf (stderr, " Unexpected character '%s'\n " , unicode_cpt_to_utf8 (*it).c_str ());
236
+ }
237
+ assert (!grammar->stacks .empty ());
238
+ }
239
+
240
+ bool completed_grammar = false ;
241
+
242
+ for (const auto & stack : grammar->stacks ) {
243
+ if (stack.empty ()) {
244
+ completed_grammar = true ;
245
+ break ;
246
+ }
247
+ }
248
+
249
+ assert (completed_grammar);
250
+
251
+ // Clean up allocated memory
252
+ llama_grammar_free (grammar);
253
+ }
254
+
255
+ static void test_chained_ambiguity_grouped () {
256
+ // Test case for a grammar that has chained ambiguity
257
+ const std::string grammar_str = R"""( root ::= [0-9] (("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a")?)?)?)?)?)?)?)?)?)? [0-9])*)""" ;
258
+
259
+ grammar_parser::parse_state parsed_grammar = grammar_parser::parse (grammar_str.c_str ());
260
+
261
+ // Ensure we parsed correctly
262
+ assert (!parsed_grammar.rules .empty ());
263
+
264
+ // Ensure we have a root node
265
+ assert (!(parsed_grammar.symbol_ids .find (" root" ) == parsed_grammar.symbol_ids .end ()));
266
+
267
+ std::vector<const llama_grammar_element*> grammar_rules (parsed_grammar.c_rules ());
268
+ llama_grammar* grammar = llama_grammar_init (
269
+ grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
270
+
271
+ std::string input = " 1aa2aa3aa4aa5" ;
272
+
273
+ auto decoded = decode_utf8 (input, {});
274
+
275
+ const auto & code_points = decoded.first ;
276
+
277
+ size_t cnt = 0 ;
278
+ for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
279
+ // fprintf(stderr, "Parsing character %zu ('%c'), stack size %zu\n", cnt, input[cnt], grammar->stacks.size());
280
+ ++cnt;
281
+
282
+ auto prev_stacks = grammar->stacks ;
283
+ grammar->stacks = llama_grammar_accept (grammar->rules , grammar->stacks , *it);
284
+ if (grammar->stacks .empty ()) {
285
+ fprintf (stderr, " Unexpected character '%s'\n " , unicode_cpt_to_utf8 (*it).c_str ());
286
+ }
287
+ assert (!grammar->stacks .empty ());
288
+ }
289
+
290
+ bool completed_grammar = false ;
291
+
292
+ for (const auto & stack : grammar->stacks ) {
293
+ if (stack.empty ()) {
294
+ completed_grammar = true ;
295
+ break ;
296
+ }
297
+ }
298
+
299
+ assert (completed_grammar);
300
+
301
+ // Clean up allocated memory
302
+ llama_grammar_free (grammar);
303
+ }
304
+
204
305
static void test_failure_missing_root () {
205
306
// Test case for a grammar that is missing a root rule
206
307
const std::string grammar_str = R"""( rot ::= expr
@@ -234,10 +335,33 @@ number ::= [0-9]+)""";
234
335
fprintf (stderr, " End of expected error. Test successful.\n " );
235
336
}
236
337
338
+ static std::vector<int64_t > times;
339
+ static std::vector<std::string> time_labels;
340
+
341
+ typedef void (*bench_func)(void );
342
+
343
+ static void bench (bench_func func, const char * label = " " ) {
344
+ func ();
345
+ times.push_back (ggml_time_us ());
346
+ time_labels.push_back (label);
347
+ }
348
+
237
349
// Runs every grammar test through bench() and prints how long each one took.
int main() {
    ggml_time_init();

    // Baseline timestamp: the first test's duration is measured against it.
    times.push_back(ggml_time_us());
    time_labels.push_back(" Start");

    bench(test_simple_grammar, " Simple grammar");
    bench(test_complex_grammar, " Complex grammar");
    bench(test_chained_ambiguity, " Chained ambiguity");
    bench(test_chained_ambiguity_grouped, " Chained ambiguity (grouped)");
    bench(test_failure_missing_root, " Failure missing root");
    bench(test_failure_missing_reference, " Failure missing reference");

    // Print timings
    fprintf(stdout, " \n Timings:\n ");
    for (size_t i = 1; i < times.size(); ++i) {
        // int64_t is not `long long` on every platform (it is `long` on LP64),
        // so convert explicitly to the type that %lld expects.
        const long long elapsed_us = static_cast<long long>(times[i] - times[i - 1]);
        fprintf(stdout, " %s: %lld us\n ", time_labels[i].c_str(), elapsed_us);
    }

    return 0;
}
0 commit comments