Skip to content

Commit 650adf1

Browse files
committed
grammar : clean-up
ggml-ci
1 parent 29f712e commit 650adf1

File tree

6 files changed

+43
-48
lines changed

6 files changed

+43
-48
lines changed

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "ggml.h"
22
#include "llama.h"
3-
#include "llama-vocab.h" // TMP
43
#include "llama-grammar.h"
54
#include "unicode.h"
65

@@ -11,7 +10,7 @@
1110
#include <string>
1211
#include <vector>
1312

14-
static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
13+
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
1514
auto decoded = decode_utf8(input_str, {});
1615
const auto & code_points = decoded.first;
1716

@@ -22,7 +21,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
2221
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
2322
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
2423

25-
llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
24+
cur_stacks = llama_grammar_accept(rules, prev_stacks, *it);
2625

2726
if (cur_stacks.empty()) {
2827
error_pos = pos;
@@ -84,8 +83,7 @@ int main(int argc, char** argv) {
8483
grammar_str = buffer.str();
8584
}
8685

87-
llama_vocab vocab; // TMP
88-
llama_grammar * grammar = llama_grammar_init_impl(vocab, grammar_str.c_str(), "root");
86+
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
8987
if (grammar == nullptr) {
9088
throw std::runtime_error("Failed to initialize llama_grammar");
9189
}
@@ -102,7 +100,7 @@ int main(int argc, char** argv) {
102100
// Validate the input string against the grammar
103101
size_t error_pos;
104102
std::string error_msg;
105-
bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg);
103+
bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);
106104

107105
if (is_valid) {
108106
fprintf(stdout, "Input string is valid according to the grammar.\n");

src/llama-grammar.cpp

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -754,10 +754,10 @@ static bool llama_grammar_detect_left_recursion(
754754
}
755755

756756
struct llama_grammar * llama_grammar_init_impl(
757-
const struct llama_vocab & vocab,
757+
const struct llama_vocab * vocab,
758758
const llama_grammar_element ** rules,
759-
size_t n_rules,
760-
size_t start_rule_index) {
759+
size_t n_rules,
760+
size_t start_rule_index) {
761761
const llama_grammar_element * pos;
762762

763763
// copy rule definitions into vectors
@@ -808,10 +808,10 @@ struct llama_grammar * llama_grammar_init_impl(
808808
// Important: vec_rules has to be moved here, not copied, because stacks contains
809809
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
810810
// then the pointers would be invalidated when the local vec_rules goes out of scope.
811-
return new llama_grammar{ vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
811+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
812812
}
813813

814-
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
814+
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
815815
llama_grammar_parser parser;
816816

817817
// if there is a grammar, parse it
@@ -886,15 +886,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab,
886886
// Important: vec_rules has to be moved here, not copied, because stacks contains
887887
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
888888
// then the pointers would be invalidated when the local vec_rules goes out of scope.
889-
return new llama_grammar{ vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
889+
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
890890
}
891891

892892
void llama_grammar_free_impl(struct llama_grammar * grammar) {
893893
delete grammar;
894894
}
895895

896896
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) {
897-
llama_grammar * result = new llama_grammar{ grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 };
897+
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 };
898898

899899
// redirect elements in stacks to point to new rules
900900
for (size_t is = 0; is < result->stacks.size(); is++) {
@@ -913,6 +913,8 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & gram
913913
}
914914

915915
void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * candidates) {
916+
GGML_ASSERT(grammar.vocab != nullptr);
917+
916918
bool allow_eog = false;
917919
for (const auto & stack : grammar.stacks) {
918920
if (stack.empty()) {
@@ -929,9 +931,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
929931

930932
for (size_t i = 0; i < candidates->size; ++i) {
931933
const llama_token id = candidates->data[i].id;
932-
const std::string & piece = grammar.vocab.cache_token_to_piece.at(id);
934+
const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
933935

934-
if (llama_token_is_eog_impl(grammar.vocab, id)) {
936+
if (llama_token_is_eog_impl(*grammar.vocab, id)) {
935937
if (!allow_eog) {
936938
candidates->data[i].logit = -INFINITY;
937939
}
@@ -950,7 +952,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
950952
}
951953

952954
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
953-
if (llama_token_is_eog_impl(grammar.vocab, token)) {
955+
GGML_ASSERT(grammar.vocab != nullptr);
956+
957+
if (llama_token_is_eog_impl(*grammar.vocab, token)) {
954958
for (const auto & stack : grammar.stacks) {
955959
if (stack.empty()) {
956960
return;
@@ -959,16 +963,15 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
959963
GGML_ABORT("fatal error");
960964
}
961965

962-
const std::string & piece = grammar.vocab.cache_token_to_piece.at(token);
966+
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
963967

964968
// Note terminating 0 in decoded string
965969
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
966970
const auto & code_points = decoded.first;
967971

968-
llama_grammar_stacks tmp_new_stacks;
969972
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
970-
llama_grammar_accept(grammar.rules, grammar.stacks, *it, tmp_new_stacks);
971-
grammar.stacks = tmp_new_stacks;
973+
llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it);
974+
grammar.stacks = std::move(new_stacks);
972975
}
973976

974977
grammar.partial_utf8 = decoded.second;
@@ -1045,12 +1048,12 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
10451048
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
10461049
}
10471050

1048-
void llama_grammar_accept(
1051+
llama_grammar_stacks llama_grammar_accept(
10491052
const llama_grammar_rules & rules,
10501053
const llama_grammar_stacks & stacks,
1051-
const uint32_t chr,
1052-
llama_grammar_stacks & new_stacks) {
1053-
new_stacks.clear();
1054+
const uint32_t chr) {
1055+
llama_grammar_stacks result;
1056+
result.reserve(stacks.size());
10541057

10551058
for (const auto & stack : stacks) {
10561059
if (stack.empty()) {
@@ -1066,9 +1069,11 @@ void llama_grammar_accept(
10661069
if (!llama_grammar_is_end_of_sequence(pos)) {
10671070
new_stack.push_back(pos);
10681071
}
1069-
llama_grammar_advance_stack(rules, new_stack, new_stacks);
1072+
llama_grammar_advance_stack(rules, new_stack, result);
10701073
}
10711074
}
1075+
1076+
return result;
10721077
}
10731078

10741079
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(

src/llama-grammar.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,10 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
7171
// be positioned at a character range (see `llama_grammar_advance_stack`), and
7272
// produces the N possible stacks if the given char is accepted at those
7373
// positions
74-
void llama_grammar_accept(
74+
llama_grammar_stacks llama_grammar_accept(
7575
const llama_grammar_rules & rules,
7676
const llama_grammar_stacks & stacks,
77-
const uint32_t chr,
78-
llama_grammar_stacks & new_stacks);
77+
uint32_t chr);
7978

8079
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
8180
const llama_grammar_rules & rules,
@@ -113,7 +112,8 @@ struct llama_grammar_parser {
113112
};
114113

115114
struct llama_grammar {
116-
const llama_vocab & vocab;
115+
// note: allow null vocab for testing (not great)
116+
const llama_vocab * vocab;
117117

118118
const llama_grammar_rules rules; // TODO: shared ptr
119119
llama_grammar_stacks stacks;
@@ -131,14 +131,14 @@ struct llama_grammar {
131131
// internal API
132132
//
133133

134-
// TODO: temporary until the tests are fixed
134+
// note: needed for tests (not great)
135135
struct llama_grammar * llama_grammar_init_impl(
136-
const struct llama_vocab & vocab,
136+
const struct llama_vocab * vocab,
137137
const llama_grammar_element ** rules,
138138
size_t n_rules,
139139
size_t start_rule_index);
140140

141-
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
141+
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
142142

143143
void llama_grammar_free_impl(struct llama_grammar * grammar);
144144

src/llama-sampling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void llama_sampling_reset_impl(struct llama_sampling & smpl) {
7676
}
7777

7878
if (!smpl.grammar_str.empty()) {
79-
smpl.grammar = llama_grammar_init_impl(smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
79+
smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data());
8080
}
8181

8282
smpl.prev.clear();
@@ -100,7 +100,7 @@ void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char *
100100
smpl.grammar_str = grammar_str;
101101
smpl.grammar_root = grammar_root;
102102

103-
smpl.grammar = llama_grammar_init_impl(smpl.vocab, grammar_str, grammar_root);
103+
smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root);
104104
} else {
105105
smpl.grammar_str.clear();
106106
smpl.grammar_root.clear();

tests/test-grammar-integration.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@
22
#undef NDEBUG
33
#endif
44

5-
#include "ggml.h"
6-
#include "llama.h"
7-
#include "llama-vocab.h" // TMP
85
#include "llama-grammar.h"
9-
#include "unicode.h"
106
#include "json-schema-to-grammar.h"
117

128
#include <cassert>
@@ -15,10 +11,8 @@
1511

1612
using json = nlohmann::ordered_json;
1713

18-
llama_vocab vocab; // TMP
19-
2014
static llama_grammar * build_grammar(const std::string & grammar_str) {
21-
return llama_grammar_init_impl(vocab, grammar_str.c_str(), "root");
15+
return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
2216
}
2317

2418
static bool test_build_grammar_fails(const std::string & grammar_str) {
@@ -45,7 +39,7 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
4539
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
4640
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
4741

48-
llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
42+
cur_stacks = llama_grammar_accept(rules, prev_stacks, *it);
4943

5044
if (cur_stacks.empty()) {
5145
// no stacks means that the grammar failed to match at this point

tests/test-llama-grammar.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#endif
44

55
#include "llama.h"
6-
#include "llama-vocab.h" // TMP
76
#include "llama-grammar.h"
87

98
#include <cassert>
@@ -117,8 +116,7 @@ int main()
117116
llama_grammar * grammar = NULL;
118117
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
119118

120-
llama_vocab vocab; // TMP
121-
grammar = llama_grammar_init_impl(vocab, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
119+
grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
122120
if (grammar == nullptr)
123121
{
124122
throw std::runtime_error("Failed to initialize llama_grammar");
@@ -175,13 +173,13 @@ int main()
175173
}};
176174

177175
auto index = 0;
178-
for (auto stack : llama_grammar_get_stacks(grammar))
176+
for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
179177
{
180178
// compare stack to expected_stack
181179
for (uint32_t i = 0; i < stack.size(); i++)
182180
{
183-
auto element = stack[i];
184-
auto expected_element = expected_stacks[index][i];
181+
const llama_grammar_element * element = stack[i];
182+
const llama_grammar_element & expected_element = expected_stacks[index][i];
185183

186184
// pretty print error message before asserting
187185
if (expected_element.type != element->type || expected_element.value != element->value)

0 commit comments

Comments (0)