11
11
#include < cassert>
12
12
#include < string>
13
13
14
- static void test_simple_grammar () {
15
- // Test case for a simple grammar
16
- const std::string grammar_str = R"""( root ::= expr
17
- expr ::= term ("+" term)*
18
- term ::= number
19
- number ::= [0-9]+)""" ;
20
-
21
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse (grammar_str.c_str ());
14
+ static llama_grammar* build_grammar (const std::string & grammar_str) {
15
+ auto parsed_grammar = grammar_parser::parse (grammar_str.c_str ());
22
16
23
17
// Ensure we parsed correctly
24
18
assert (!parsed_grammar.rules .empty ());
@@ -30,28 +24,45 @@ number ::= [0-9]+)""";
30
24
llama_grammar* grammar = llama_grammar_init (
31
25
grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
32
26
33
- std::string input = " 123+456" ;
27
+ return grammar;
28
+ }
34
29
30
+ static bool match_string (const std::string & input, llama_grammar* grammar) {
35
31
auto decoded = decode_utf8 (input, {});
36
32
37
33
const auto & code_points = decoded.first ;
38
34
39
35
for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
40
36
auto prev_stacks = grammar->stacks ;
41
37
llama_grammar_accept (grammar->rules , prev_stacks, *it, grammar->stacks );
42
- assert (!grammar->stacks .empty ());
38
+ if (grammar->stacks .empty ()) {
39
+ // no stacks means that the grammar failed to match at this point
40
+ return false ;
41
+ }
43
42
}
44
43
45
- bool completed_grammar = false ;
46
-
47
44
for (const auto & stack : grammar->stacks ) {
48
45
if (stack.empty ()) {
49
- completed_grammar = true ;
50
- break ;
46
+ // An empty stack means that the grammar has been completed
47
+ return true ;
51
48
}
52
49
}
53
50
54
- assert (completed_grammar);
51
+ return false ;
52
+ }
53
+
54
+ static void test_simple_grammar () {
55
+ // Test case for a simple grammar
56
+ const std::string grammar_str = R"""( root ::= expr
57
+ expr ::= term ("+" term)*
58
+ term ::= number
59
+ number ::= [0-9]+)""" ;
60
+
61
+ auto grammar = build_grammar (grammar_str);
62
+
63
+ bool matched = match_string (" 123+456" , grammar);
64
+
65
+ assert (matched);
55
66
56
67
// Clean up allocated memory
57
68
llama_grammar_free (grammar);
@@ -68,17 +79,7 @@ variable ::= [a-zA-Z_][a-zA-Z0-9_]*
68
79
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
69
80
ws ::= [ \t\n\r]?)""" ;
70
81
71
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse (grammar_str.c_str ());
72
-
73
- // Ensure we parsed correctly
74
- assert (!parsed_grammar.rules .empty ());
75
-
76
- // Ensure we have a root node
77
- assert (!(parsed_grammar.symbol_ids .find (" root" ) == parsed_grammar.symbol_ids .end ()));
78
-
79
- std::vector<const llama_grammar_element*> grammar_rules (parsed_grammar.c_rules ());
80
- llama_grammar* grammar = llama_grammar_init (
81
- grammar_rules.data (), grammar_rules.size (), parsed_grammar.symbol_ids .at (" root" ));
82
+ auto grammar = build_grammar (grammar_str);
82
83
83
84
// Save the original grammar stacks so that we can reset after every new string we want to test
84
85
auto original_stacks = grammar->stacks ;
@@ -130,68 +131,19 @@ ws ::= [ \t\n\r]?)""";
130
131
131
132
// Passing strings
132
133
for (const auto & test_string : test_strings_pass) {
133
- auto decoded = decode_utf8 (test_string, {});
134
-
135
- const auto & code_points = decoded.first ;
136
-
137
- int pos = 0 ;
138
- for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
139
- ++pos;
140
- auto prev_stacks = grammar->stacks ;
141
- llama_grammar_accept (grammar->rules , prev_stacks, *it, grammar->stacks );
142
-
143
- // Expect that each code point will not cause the grammar to fail
144
- if (grammar->stacks .empty ()) {
145
- fprintf (stdout, " Error at position %d\n " , pos);
146
- fprintf (stderr, " Unexpected character '%s'\n " , unicode_cpt_to_utf8 (*it).c_str ());
147
- fprintf (stderr, " Input string is %s:\n " , test_string.c_str ());
148
- }
149
- assert (!grammar->stacks .empty ());
150
- }
151
-
152
- bool completed_grammar = false ;
153
-
154
- for (const auto & stack : grammar->stacks ) {
155
- if (stack.empty ()) {
156
- completed_grammar = true ;
157
- break ;
158
- }
159
- }
134
+ bool matched = match_string (test_string, grammar);
160
135
161
- assert (completed_grammar );
136
+ assert (matched );
162
137
163
138
// Reset the grammar stacks
164
139
grammar->stacks = original_stacks;
165
140
}
166
141
167
142
// Failing strings
168
143
for (const auto & test_string : test_strings_fail) {
169
- auto decoded = decode_utf8 (test_string, {});
170
-
171
- const auto & code_points = decoded.first ;
172
- bool parse_failed = false ;
173
-
174
- for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
175
- auto prev_stacks = grammar->stacks ;
176
- llama_grammar_accept (grammar->rules , prev_stacks, *it, grammar->stacks );
177
- if (grammar->stacks .empty ()) {
178
- parse_failed = true ;
179
- break ;
180
- }
181
- assert (!grammar->stacks .empty ());
182
- }
183
-
184
- bool completed_grammar = false ;
185
-
186
- for (const auto & stack : grammar->stacks ) {
187
- if (stack.empty ()) {
188
- completed_grammar = true ;
189
- break ;
190
- }
191
- }
144
+ bool matched = match_string (test_string, grammar);
192
145
193
- // Ensure that the grammar is not completed, or that each string failed to match as-expected
194
- assert ((!completed_grammar) || parse_failed);
146
+ assert (!matched);
195
147
196
148
// Reset the grammar stacks
197
149
grammar->stacks = original_stacks;
@@ -231,13 +183,14 @@ number ::= [0-9]+)""";
231
183
// Ensure we did NOT parsed correctly
232
184
assert (parsed_grammar.rules .empty ());
233
185
234
- fprintf (stderr, " End of expected error. Test successful. \n " );
186
+ fprintf (stderr, " End of expected error.\n " );
235
187
}
236
188
237
189
int main () {
238
190
test_simple_grammar ();
239
191
test_complex_grammar ();
240
192
test_failure_missing_root ();
241
193
test_failure_missing_reference ();
194
+ fprintf (stdout, " All tests passed.\n " );
242
195
return 0 ;
243
196
}
0 commit comments