Skip to content

Commit 84767a7

Browse files
committed
Cleaning up integration tests to share code between tests and make it simpler to add new tests.
1 parent ab9a324 commit 84767a7

File tree

1 file changed

+33
-80
lines changed

1 file changed

+33
-80
lines changed

tests/test-grammar-integration.cpp

Lines changed: 33 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,8 @@
1111
#include <cassert>
1212
#include <string>
1313

14-
static void test_simple_grammar() {
15-
// Test case for a simple grammar
16-
const std::string grammar_str = R"""(root ::= expr
17-
expr ::= term ("+" term)*
18-
term ::= number
19-
number ::= [0-9]+)""";
20-
21-
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
14+
static llama_grammar* build_grammar(const std::string & grammar_str) {
15+
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
2216

2317
// Ensure we parsed correctly
2418
assert(!parsed_grammar.rules.empty());
@@ -30,28 +24,45 @@ number ::= [0-9]+)""";
3024
llama_grammar* grammar = llama_grammar_init(
3125
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
3226

33-
std::string input = "123+456";
27+
return grammar;
28+
}
3429

30+
static bool match_string(const std::string & input, llama_grammar* grammar) {
3531
auto decoded = decode_utf8(input, {});
3632

3733
const auto & code_points = decoded.first;
3834

3935
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
4036
auto prev_stacks = grammar->stacks;
4137
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
42-
assert(!grammar->stacks.empty());
38+
if (grammar->stacks.empty()) {
39+
// no stacks means that the grammar failed to match at this point
40+
return false;
41+
}
4342
}
4443

45-
bool completed_grammar = false;
46-
4744
for (const auto & stack : grammar->stacks) {
4845
if (stack.empty()) {
49-
completed_grammar = true;
50-
break;
46+
// An empty stack means that the grammar has been completed
47+
return true;
5148
}
5249
}
5350

54-
assert(completed_grammar);
51+
return false;
52+
}
53+
54+
static void test_simple_grammar() {
55+
// Test case for a simple grammar
56+
const std::string grammar_str = R"""(root ::= expr
57+
expr ::= term ("+" term)*
58+
term ::= number
59+
number ::= [0-9]+)""";
60+
61+
auto grammar = build_grammar(grammar_str);
62+
63+
bool matched = match_string("123+456", grammar);
64+
65+
assert(matched);
5566

5667
// Clean up allocated memory
5768
llama_grammar_free(grammar);
@@ -68,17 +79,7 @@ variable ::= [a-zA-Z_][a-zA-Z0-9_]*
6879
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
6980
ws ::= [ \t\n\r]?)""";
7081

71-
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
72-
73-
// Ensure we parsed correctly
74-
assert(!parsed_grammar.rules.empty());
75-
76-
// Ensure we have a root node
77-
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
78-
79-
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
80-
llama_grammar* grammar = llama_grammar_init(
81-
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
82+
auto grammar = build_grammar(grammar_str);
8283

8384
// Save the original grammar stacks so that we can reset after every new string we want to test
8485
auto original_stacks = grammar->stacks;
@@ -130,68 +131,19 @@ ws ::= [ \t\n\r]?)""";
130131

131132
// Passing strings
132133
for (const auto & test_string : test_strings_pass) {
133-
auto decoded = decode_utf8(test_string, {});
134-
135-
const auto & code_points = decoded.first;
136-
137-
int pos = 0;
138-
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
139-
++pos;
140-
auto prev_stacks = grammar->stacks;
141-
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
142-
143-
// Expect that each code point will not cause the grammar to fail
144-
if (grammar->stacks.empty()) {
145-
fprintf(stdout, "Error at position %d\n", pos);
146-
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
147-
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
148-
}
149-
assert(!grammar->stacks.empty());
150-
}
151-
152-
bool completed_grammar = false;
153-
154-
for (const auto & stack : grammar->stacks) {
155-
if (stack.empty()) {
156-
completed_grammar = true;
157-
break;
158-
}
159-
}
134+
bool matched = match_string(test_string, grammar);
160135

161-
assert(completed_grammar);
136+
assert(matched);
162137

163138
// Reset the grammar stacks
164139
grammar->stacks = original_stacks;
165140
}
166141

167142
// Failing strings
168143
for (const auto & test_string : test_strings_fail) {
169-
auto decoded = decode_utf8(test_string, {});
170-
171-
const auto & code_points = decoded.first;
172-
bool parse_failed = false;
173-
174-
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
175-
auto prev_stacks = grammar->stacks;
176-
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
177-
if (grammar->stacks.empty()) {
178-
parse_failed = true;
179-
break;
180-
}
181-
assert(!grammar->stacks.empty());
182-
}
183-
184-
bool completed_grammar = false;
185-
186-
for (const auto & stack : grammar->stacks) {
187-
if (stack.empty()) {
188-
completed_grammar = true;
189-
break;
190-
}
191-
}
144+
bool matched = match_string(test_string, grammar);
192145

193-
// Ensure that the grammar is not completed, or that each string failed to match as-expected
194-
assert((!completed_grammar) || parse_failed);
146+
assert(!matched);
195147

196148
// Reset the grammar stacks
197149
grammar->stacks = original_stacks;
@@ -231,13 +183,14 @@ number ::= [0-9]+)""";
231183
// Ensure we did NOT parse correctly
232184
assert(parsed_grammar.rules.empty());
233185

234-
fprintf(stderr, "End of expected error. Test successful.\n");
186+
fprintf(stderr, "End of expected error.\n");
235187
}
236188

237189
int main() {
238190
test_simple_grammar();
239191
test_complex_grammar();
240192
test_failure_missing_root();
241193
test_failure_missing_reference();
194+
fprintf(stdout, "All tests passed.\n");
242195
return 0;
243196
}

0 commit comments

Comments
 (0)