Skip to content

Commit 490d06f

Browse files
committed
Expanding tests to add chained-ambiguity stress tests as well as simple timing metrics.
1 parent f4183af commit 490d06f

File tree

1 file changed

+128
-4
lines changed

1 file changed

+128
-4
lines changed

tests/test-grammar-integration.cpp

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,107 @@ ws ::= [ \t\n\r]?)""";
201201
llama_grammar_free(grammar);
202202
}
203203

204+
static void test_chained_ambiguity() {
205+
// Test case for a grammar that has chained ambiguity
206+
const std::string grammar_str = R"""(root ::= [0-9] ("a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? "a"? [0-9])*)""";
207+
// const std::string grammar_str = R"""(root ::= [0-9] (("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a")?)?)?)?)?)?)?)?)?)? [0-9])*)""";
208+
209+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
210+
211+
// Ensure we parsed correctly
212+
assert(!parsed_grammar.rules.empty());
213+
214+
// Ensure we have a root node
215+
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
216+
217+
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
218+
llama_grammar* grammar = llama_grammar_init(
219+
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
220+
221+
std::string input = "1aa2aa3aa4aa5";
222+
223+
auto decoded = decode_utf8(input, {});
224+
225+
const auto & code_points = decoded.first;
226+
227+
size_t cnt = 0;
228+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
229+
//fprintf(stderr, "Parsing character %zu ('%c'), stack size %zu\n", cnt, input[cnt], grammar->stacks.size());
230+
++cnt;
231+
232+
auto prev_stacks = grammar->stacks;
233+
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
234+
if (grammar->stacks.empty()) {
235+
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
236+
}
237+
assert(!grammar->stacks.empty());
238+
}
239+
240+
bool completed_grammar = false;
241+
242+
for (const auto & stack : grammar->stacks) {
243+
if (stack.empty()) {
244+
completed_grammar = true;
245+
break;
246+
}
247+
}
248+
249+
assert(completed_grammar);
250+
251+
// Clean up allocated memory
252+
llama_grammar_free(grammar);
253+
}
254+
255+
static void test_chained_ambiguity_grouped() {
256+
// Test case for a grammar that has chained ambiguity
257+
const std::string grammar_str = R"""(root ::= [0-9] (("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a" ("a")?)?)?)?)?)?)?)?)?)? [0-9])*)""";
258+
259+
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
260+
261+
// Ensure we parsed correctly
262+
assert(!parsed_grammar.rules.empty());
263+
264+
// Ensure we have a root node
265+
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
266+
267+
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
268+
llama_grammar* grammar = llama_grammar_init(
269+
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
270+
271+
std::string input = "1aa2aa3aa4aa5";
272+
273+
auto decoded = decode_utf8(input, {});
274+
275+
const auto & code_points = decoded.first;
276+
277+
size_t cnt = 0;
278+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
279+
//fprintf(stderr, "Parsing character %zu ('%c'), stack size %zu\n", cnt, input[cnt], grammar->stacks.size());
280+
++cnt;
281+
282+
auto prev_stacks = grammar->stacks;
283+
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
284+
if (grammar->stacks.empty()) {
285+
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
286+
}
287+
assert(!grammar->stacks.empty());
288+
}
289+
290+
bool completed_grammar = false;
291+
292+
for (const auto & stack : grammar->stacks) {
293+
if (stack.empty()) {
294+
completed_grammar = true;
295+
break;
296+
}
297+
}
298+
299+
assert(completed_grammar);
300+
301+
// Clean up allocated memory
302+
llama_grammar_free(grammar);
303+
}
304+
204305
static void test_failure_missing_root() {
205306
// Test case for a grammar that is missing a root rule
206307
const std::string grammar_str = R"""(rot ::= expr
@@ -234,10 +335,33 @@ number ::= [0-9]+)""";
234335
fprintf(stderr, "End of expected error. Test successful.\n");
235336
}
236337

338+
static std::vector<int64_t> times;
339+
static std::vector<std::string> time_labels;
340+
341+
typedef void (*bench_func)(void);
342+
343+
static void bench(bench_func func, const char* label = "") {
344+
func();
345+
times.push_back(ggml_time_us());
346+
time_labels.push_back(label);
347+
}
348+
237349
int main() {
238-
test_simple_grammar();
239-
test_complex_grammar();
240-
test_failure_missing_root();
241-
test_failure_missing_reference();
350+
ggml_time_init();
351+
times.push_back(ggml_time_us());
352+
time_labels.push_back("Start");
353+
bench(test_simple_grammar, "Simple grammar");
354+
bench(test_complex_grammar, "Complex grammar");
355+
bench(test_chained_ambiguity, "Chained ambiguity");
356+
bench(test_chained_ambiguity_grouped, "Chained ambiguity (grouped)");
357+
bench(test_failure_missing_root, "Failure missing root");
358+
bench(test_failure_missing_reference, "Failure missing reference");
359+
360+
// Print timings
361+
fprintf(stdout, "\nTimings:\n");
362+
for (size_t i = 1; i < times.size(); ++i) {
363+
fprintf(stdout, "%s: %lld us\n", time_labels[i].c_str(), times[i] - times[i - 1]);
364+
}
365+
242366
return 0;
243367
}

0 commit comments

Comments
 (0)