Skip to content

Commit 35d9680

Browse files
committed
squash! llama : return nullptr from llama_grammar_init
Add checks for nullptr when calling llama_grammar_init. Signed-off-by: Daniel Bevenius <[email protected]>
1 parent 6189bce commit 35d9680

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

common/sampling.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
2828

2929
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
3030

31-
result->grammar = llama_grammar_init(
31+
struct llama_grammar * grammar = llama_grammar_init(
3232
grammar_rules.data(),
3333
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
34+
if (grammar == nullptr) {
35+
throw std::runtime_error("Failed to initialize llama_grammar");
36+
}
37+
result->grammar = grammar;
3438
}
3539

3640
result->prev.resize(params.n_prev);
@@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
5963
if (!ctx->parsed_grammar.rules.empty()) {
6064
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
6165

62-
ctx->grammar = llama_grammar_init(
66+
struct llama_grammar * grammar = llama_grammar_init(
6367
grammar_rules.data(),
6468
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
69+
if (grammar == nullptr) {
70+
throw std::runtime_error("Failed to initialize llama_grammar");
71+
}
72+
ctx->grammar = grammar;
6573
}
6674

6775
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ int main(int argc, char** argv) {
101101
auto grammar = llama_grammar_init(
102102
grammar_rules.data(),
103103
grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
104-
104+
if (grammar == nullptr) {
105+
throw std::runtime_error("Failed to initialize llama_grammar");
106+
}
105107
// Read the input file
106108
std::string input_str;
107109
{

tests/test-llama-grammar.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ int main()
116116
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
117117
grammar = llama_grammar_init(
118118
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
119+
if (grammar == nullptr)
120+
{
121+
throw std::runtime_error("Failed to initialize llama_grammar");
122+
}
119123

120124
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
121125
{

0 commit comments

Comments
 (0)