Skip to content

Commit 0160469

Browse files
author
Olivier Chafik
committed
grammars: x{min,max} repetition operator + tweak +/*/? to avoid duplication of original over alternates
1 parent 4cc120c commit 0160469

File tree

1 file changed

+119
-31
lines changed

1 file changed

+119
-31
lines changed

common/grammar-parser.cpp

Lines changed: 119 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@ namespace grammar_parser {
4646
state.rules[rule_id] = rule;
4747
}
4848

49+
// Reports whether c is one of the ASCII decimal digits '0' through '9'.
static bool is_digit_char(char c) {
    if (c < '0') {
        return false;
    }
    return c <= '9';
}
52+
4953
// Returns true when c is an ASCII letter (a-z, A-Z), a '-' or an ASCII digit (0-9).
// Diff view: the line marked '-' is the pre-commit form with the digit range spelled
// out inline; the line marked '+' delegates the digit test to is_digit_char(). Both
// forms accept the same set of characters.
static bool is_word_char(char c) {
50-
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
54+
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
5155
}
5256

5357
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
@@ -99,6 +103,17 @@ namespace grammar_parser {
99103
return pos;
100104
}
101105

106+
static const char * parse_int(const char * src) {
107+
const char * pos = src;
108+
while (is_digit_char(*pos)) {
109+
pos++;
110+
}
111+
if (pos == src) {
112+
throw std::runtime_error(std::string("expecting name at ") + src);
113+
}
114+
return pos;
115+
}
116+
102117
static std::pair<uint32_t, const char *> parse_char(const char * src) {
103118
if (*src == '\\') {
104119
switch (src[1]) {
@@ -137,6 +152,81 @@ namespace grammar_parser {
137152
bool is_nested) {
138153
size_t last_sym_start = out_elements.size();
139154
const char * pos = src;
155+
156+
auto handle_repetitions = [&](size_t min_times, int max_times) {
157+
158+
if (last_sym_start == out_elements.size()) {
159+
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
160+
}
161+
162+
// S* --> S{0,}
163+
// S+ --> S{1,}
164+
// S? --> S{0,1}
165+
// S{m,n} --> S' ::= Scopy Scopy Scopy... (m times) S(n-m)
166+
// Scopy ::= S
167+
// S(x) ::= Scopy S(x-1) |
168+
// S(x-1) ::= Scopy S(x-2) |
169+
// S(1) ::= Scopy |
170+
// S{m,} --> S' ::= Scopy Scopy Scopy (m times) Sstar
171+
// Scopy ::= S
172+
// Sstar ::= Scopy Sstar |
173+
174+
uint32_t content_rule_id = 0;
175+
if (out_elements[last_sym_start].type == LLAMA_GRETYPE_RULE_REF) {
176+
// The repeated content is already a rule ref, no need to copy it
177+
content_rule_id = out_elements[last_sym_start].value;
178+
} else {
179+
content_rule_id = generate_symbol_id(state, rule_name);
180+
// add preceding symbol to generated copy rule
181+
std::vector<llama_grammar_element> copy_rule(out_elements.begin() + last_sym_start, out_elements.end());
182+
copy_rule.push_back({LLAMA_GRETYPE_END, 0});
183+
add_rule(state, content_rule_id, copy_rule);
184+
}
185+
186+
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
187+
std::vector<llama_grammar_element> sub_rule;
188+
for (size_t i = 0; i < min_times; i++) {
189+
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, content_rule_id});
190+
}
191+
if (max_times < 0) {
192+
uint32_t star_rule_id = generate_symbol_id(state, rule_name + "_star");
193+
add_rule(state, star_rule_id, {
194+
{LLAMA_GRETYPE_RULE_REF, content_rule_id},
195+
{LLAMA_GRETYPE_RULE_REF, star_rule_id},
196+
{LLAMA_GRETYPE_ALT, 0},
197+
{LLAMA_GRETYPE_END, 0}
198+
});
199+
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, star_rule_id});
200+
} else {
201+
uint32_t last_rec_rule_id = 0;
202+
for (int i = 0, n = max_times - min_times; i < n; i++) {
203+
uint32_t rec_rule_id = generate_symbol_id(state, rule_name + "_" + std::to_string(i + 1));
204+
if (i == 0) {
205+
add_rule(state, rec_rule_id, {
206+
{LLAMA_GRETYPE_RULE_REF, content_rule_id},
207+
{LLAMA_GRETYPE_ALT, 0},
208+
{LLAMA_GRETYPE_END, 0}
209+
});
210+
} else {
211+
add_rule(state, rec_rule_id, {
212+
{LLAMA_GRETYPE_RULE_REF, content_rule_id},
213+
{LLAMA_GRETYPE_RULE_REF, last_rec_rule_id},
214+
{LLAMA_GRETYPE_ALT, 0},
215+
{LLAMA_GRETYPE_END, 0}
216+
});
217+
}
218+
last_rec_rule_id = rec_rule_id;
219+
}
220+
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
221+
}
222+
sub_rule.push_back({LLAMA_GRETYPE_END, 0});
223+
add_rule(state, sub_rule_id, sub_rule);
224+
225+
// in original rule, replace previous symbol with reference to generated rule
226+
out_elements.resize(last_sym_start);
227+
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
228+
};
229+
140230
while (*pos) {
141231
if (*pos == '"') { // literal string
142232
pos++;
@@ -188,40 +278,38 @@ namespace grammar_parser {
188278
throw std::runtime_error(std::string("expecting ')' at ") + pos);
189279
}
190280
pos = parse_space(pos + 1, is_nested);
191-
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
192-
if (last_sym_start == out_elements.size()) {
193-
throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos);
281+
} else if (*pos == '*') {
282+
pos = parse_space(pos + 1, is_nested);
283+
handle_repetitions(0, -1);
284+
} else if (*pos == '+') {
285+
pos = parse_space(pos + 1, is_nested);
286+
handle_repetitions(1, -1);
287+
} else if (*pos == '?') {
288+
pos = parse_space(pos + 1, is_nested);
289+
handle_repetitions(0, 1);
290+
} else if (*pos == '{') {
291+
pos = parse_space(pos + 1, is_nested);
292+
size_t min_times = 0;
293+
int max_times = -1;
294+
if (is_digit_char(*pos)) {
295+
const char * int_end = parse_int(pos);
296+
min_times = std::stoul(std::string(pos, int_end - pos));
297+
pos = parse_space(int_end, is_nested);
194298
}
195-
196-
// apply transformation to previous symbol (last_sym_start to end) according to
197-
// rewrite rules:
198-
// S* --> S' ::= S S' |
199-
// S+ --> S' ::= S S' | S
200-
// S? --> S' ::= S |
201-
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
202-
std::vector<llama_grammar_element> sub_rule;
203-
// add preceding symbol to generated rule
204-
sub_rule.insert(
205-
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
206-
if (*pos == '*' || *pos == '+') {
207-
// cause generated rule to recurse
208-
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
299+
if (*pos != ',') {
300+
throw std::runtime_error(std::string("expecting ',' at ") + pos);
209301
}
210-
// mark start of alternate def
211-
sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
212-
if (*pos == '+') {
213-
// add preceding symbol as alternate only for '+' (otherwise empty)
214-
sub_rule.insert(
215-
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
302+
pos = parse_space(pos + 1, is_nested);
303+
if (is_digit_char(*pos)) {
304+
const char * int_end = parse_int(pos);
305+
max_times = std::stoul(std::string(pos, int_end - pos));
306+
pos = parse_space(int_end, is_nested);
307+
}
308+
if (*pos != '}') {
309+
throw std::runtime_error(std::string("expecting '}' at ") + pos);
216310
}
217-
sub_rule.push_back({LLAMA_GRETYPE_END, 0});
218-
add_rule(state, sub_rule_id, sub_rule);
219-
220-
// in original rule, replace previous symbol with reference to generated rule
221-
out_elements.resize(last_sym_start);
222-
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
223-
224311
pos = parse_space(pos + 1, is_nested);
312+
handle_repetitions(min_times, max_times);
225313
} else {
226314
break;
227315
}

0 commit comments

Comments
 (0)