@@ -46,8 +46,12 @@ namespace grammar_parser {
46
46
state.rules [rule_id] = rule;
47
47
}
48
48
49
+ static bool is_digit_char (char c) {
50
+ return ' 0' <= c && c <= ' 9' ;
51
+ }
52
+
49
53
static bool is_word_char (char c) {
50
- return (' a' <= c && c <= ' z' ) || (' A' <= c && c <= ' Z' ) || c == ' -' || ( ' 0 ' <= c && c <= ' 9 ' );
54
+ return (' a' <= c && c <= ' z' ) || (' A' <= c && c <= ' Z' ) || c == ' -' || is_digit_char (c );
51
55
}
52
56
53
57
static std::pair<uint32_t , const char *> parse_hex (const char * src, int size) {
@@ -99,6 +103,17 @@ namespace grammar_parser {
99
103
return pos;
100
104
}
101
105
106
+ static const char * parse_int (const char * src) {
107
+ const char * pos = src;
108
+ while (is_digit_char (*pos)) {
109
+ pos++;
110
+ }
111
+ if (pos == src) {
112
+ throw std::runtime_error (std::string (" expecting name at " ) + src);
113
+ }
114
+ return pos;
115
+ }
116
+
102
117
static std::pair<uint32_t , const char *> parse_char (const char * src) {
103
118
if (*src == ' \\ ' ) {
104
119
switch (src[1 ]) {
@@ -137,6 +152,81 @@ namespace grammar_parser {
137
152
bool is_nested) {
138
153
size_t last_sym_start = out_elements.size ();
139
154
const char * pos = src;
155
+
156
+ auto handle_repetitions = [&](size_t min_times, int max_times) {
157
+
158
+ if (last_sym_start == out_elements.size ()) {
159
+ throw std::runtime_error (std::string (" expecting preceding item to */+/?/{ at " ) + pos);
160
+ }
161
+
162
+ // S* --> S{0,}
163
+ // S+ --> S{1,}
164
+ // S? --> S{0,1}
165
+ // S{m,n} --> S' ::= Scopy Scopy Scopy... (m times) S(n-m)
166
+ // Scopy ::= S
167
+ // S(x) ::= Scopy S(x-1) |
168
+ // S(x-1) ::= Scopy S(x-2) |
169
+ // S(1) ::= Scopy |
170
+ // S{m,} --> S' ::= Scopy Scopy Scopy (m times) Sstar
171
+ // Scopy ::= S
172
+ // Sstar ::= Scopy Sstar |
173
+
174
+ uint32_t content_rule_id = 0 ;
175
+ if (out_elements[last_sym_start].type == LLAMA_GRETYPE_RULE_REF) {
176
+ // The repeated content is already a rule ref, no need to copy it
177
+ content_rule_id = out_elements[last_sym_start].value ;
178
+ } else {
179
+ content_rule_id = generate_symbol_id (state, rule_name);
180
+ // add preceding symbol to generated copy rule
181
+ std::vector<llama_grammar_element> copy_rule (out_elements.begin () + last_sym_start, out_elements.end ());
182
+ copy_rule.push_back ({LLAMA_GRETYPE_END, 0 });
183
+ add_rule (state, content_rule_id, copy_rule);
184
+ }
185
+
186
+ uint32_t sub_rule_id = generate_symbol_id (state, rule_name);
187
+ std::vector<llama_grammar_element> sub_rule;
188
+ for (size_t i = 0 ; i < min_times; i++) {
189
+ sub_rule.push_back ({LLAMA_GRETYPE_RULE_REF, content_rule_id});
190
+ }
191
+ if (max_times < 0 ) {
192
+ uint32_t star_rule_id = generate_symbol_id (state, rule_name + " _star" );
193
+ add_rule (state, star_rule_id, {
194
+ {LLAMA_GRETYPE_RULE_REF, content_rule_id},
195
+ {LLAMA_GRETYPE_RULE_REF, star_rule_id},
196
+ {LLAMA_GRETYPE_ALT, 0 },
197
+ {LLAMA_GRETYPE_END, 0 }
198
+ });
199
+ sub_rule.push_back ({LLAMA_GRETYPE_RULE_REF, star_rule_id});
200
+ } else {
201
+ uint32_t last_rec_rule_id = 0 ;
202
+ for (int i = 0 , n = max_times - min_times; i < n; i++) {
203
+ uint32_t rec_rule_id = generate_symbol_id (state, rule_name + " _" + std::to_string (i + 1 ));
204
+ if (i == 0 ) {
205
+ add_rule (state, rec_rule_id, {
206
+ {LLAMA_GRETYPE_RULE_REF, content_rule_id},
207
+ {LLAMA_GRETYPE_ALT, 0 },
208
+ {LLAMA_GRETYPE_END, 0 }
209
+ });
210
+ } else {
211
+ add_rule (state, rec_rule_id, {
212
+ {LLAMA_GRETYPE_RULE_REF, content_rule_id},
213
+ {LLAMA_GRETYPE_RULE_REF, last_rec_rule_id},
214
+ {LLAMA_GRETYPE_ALT, 0 },
215
+ {LLAMA_GRETYPE_END, 0 }
216
+ });
217
+ }
218
+ last_rec_rule_id = rec_rule_id;
219
+ }
220
+ sub_rule.push_back ({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
221
+ }
222
+ sub_rule.push_back ({LLAMA_GRETYPE_END, 0 });
223
+ add_rule (state, sub_rule_id, sub_rule);
224
+
225
+ // in original rule, replace previous symbol with reference to generated rule
226
+ out_elements.resize (last_sym_start);
227
+ out_elements.push_back ({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
228
+ };
229
+
140
230
while (*pos) {
141
231
if (*pos == ' "' ) { // literal string
142
232
pos++;
@@ -188,40 +278,38 @@ namespace grammar_parser {
188
278
throw std::runtime_error (std::string (" expecting ')' at " ) + pos);
189
279
}
190
280
pos = parse_space (pos + 1 , is_nested);
191
- } else if (*pos == ' *' || *pos == ' +' || *pos == ' ?' ) { // repetition operator
192
- if (last_sym_start == out_elements.size ()) {
193
- throw std::runtime_error (std::string (" expecting preceding item to */+/? at " ) + pos);
281
+ } else if (*pos == ' *' ) {
282
+ pos = parse_space (pos + 1 , is_nested);
283
+ handle_repetitions (0 , -1 );
284
+ } else if (*pos == ' +' ) {
285
+ pos = parse_space (pos + 1 , is_nested);
286
+ handle_repetitions (1 , -1 );
287
+ } else if (*pos == ' ?' ) {
288
+ pos = parse_space (pos + 1 , is_nested);
289
+ handle_repetitions (0 , 1 );
290
+ } else if (*pos == ' {' ) {
291
+ pos = parse_space (pos + 1 , is_nested);
292
+ size_t min_times = 0 ;
293
+ int max_times = -1 ;
294
+ if (is_digit_char (*pos)) {
295
+ const char * int_end = parse_int (pos);
296
+ min_times = std::stoul (std::string (pos, int_end - pos));
297
+ pos = parse_space (int_end, is_nested);
194
298
}
195
-
196
- // apply transformation to previous symbol (last_sym_start to end) according to
197
- // rewrite rules:
198
- // S* --> S' ::= S S' |
199
- // S+ --> S' ::= S S' | S
200
- // S? --> S' ::= S |
201
- uint32_t sub_rule_id = generate_symbol_id (state, rule_name);
202
- std::vector<llama_grammar_element> sub_rule;
203
- // add preceding symbol to generated rule
204
- sub_rule.insert (
205
- sub_rule.end (), out_elements.begin () + last_sym_start, out_elements.end ());
206
- if (*pos == ' *' || *pos == ' +' ) {
207
- // cause generated rule to recurse
208
- sub_rule.push_back ({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
299
+ if (*pos != ' ,' ) {
300
+ throw std::runtime_error (std::string (" expecting ',' at " ) + pos);
209
301
}
210
- // mark start of alternate def
211
- sub_rule.push_back ({LLAMA_GRETYPE_ALT, 0 });
212
- if (*pos == ' +' ) {
213
- // add preceding symbol as alternate only for '+' (otherwise empty)
214
- sub_rule.insert (
215
- sub_rule.end (), out_elements.begin () + last_sym_start, out_elements.end ());
302
+ pos = parse_space (pos + 1 , is_nested);
303
+ if (is_digit_char (*pos)) {
304
+ const char * int_end = parse_int (pos);
305
+ max_times = std::stoul (std::string (pos, int_end - pos));
306
+ pos = parse_space (int_end, is_nested);
307
+ }
308
+ if (*pos != ' }' ) {
309
+ throw std::runtime_error (std::string (" expecting '}' at " ) + pos);
216
310
}
217
- sub_rule.push_back ({LLAMA_GRETYPE_END, 0 });
218
- add_rule (state, sub_rule_id, sub_rule);
219
-
220
- // in original rule, replace previous symbol with reference to generated rule
221
- out_elements.resize (last_sym_start);
222
- out_elements.push_back ({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
223
-
224
311
pos = parse_space (pos + 1 , is_nested);
312
+ handle_repetitions (min_times, max_times);
225
313
} else {
226
314
break ;
227
315
}
0 commit comments