Skip to content

llama : fix indentation in llama-grammar [no ci] #11943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
320 changes: 160 additions & 160 deletions src/llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
size_t last_sym_start = rule.size();
const char * pos = src;

auto handle_repetitions = [&](int min_times, int max_times) {
auto handle_repetitions = [&](int min_times, int max_times) {

if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}
if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
}

// apply transformation to previous symbol (last_sym_start to end) according to
// the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...)
// S'(1) ::= S |
// S{m,} --> S S S (m times) S'
// S' ::= S S' |
// S* --> S{0,}
// --> S' ::= S S' |
// S+ --> S{1,}
// --> S S'
// S' ::= S S' |
// S? --> S{0,1}
// --> S'
// S' ::= S |

llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
if (min_times == 0) {
rule.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
}
// apply transformation to previous symbol (last_sym_start to end) according to
// the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...)
// S'(1) ::= S |
// S{m,} --> S S S (m times) S'
// S' ::= S S' |
// S* --> S{0,}
// --> S' ::= S S' |
// S+ --> S{1,}
// --> S S'
// S' ::= S S' |
// S? --> S{0,1}
// --> S'
// S' ::= S |

llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
if (min_times == 0) {
rule.resize(last_sym_start);
} else {
// Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
}
}

uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times;

llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
};
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};

while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = rule.size();
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = rule.size();
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
auto char_pair = parse_char(pos);
pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = rule.size();
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
last_sym_start = rule.size();
while (*pos != ']') {
if (!*pos) {
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;

rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;

rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
throw std::runtime_error("unexpected end of input");
}
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);

if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);

int max_times = -1;

if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);

if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);

int max_times = -1;

if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);

if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}

if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
handle_repetitions(min_times, max_times);
pos = parse_space(pos + 1, is_nested);
} else {
break;
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
handle_repetitions(min_times, max_times);
} else {
break;
}
return pos;
}
return pos;
}

const char * llama_grammar_parser::parse_rule(const char * src) {
const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(src, name_len);
const std::string name(src, name_len);

if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos);
}
pos = parse_space(pos + 3, true);
const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(src, name_len);
const std::string name(src, name_len);

if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos);
}
pos = parse_space(pos + 3, true);

pos = parse_alternates(pos, name, rule_id, false);
pos = parse_alternates(pos, name, rule_id, false);

if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
}

bool llama_grammar_parser::parse(const char * src) {
try {
Expand Down