
Commit 880c40f

Check for llama_get_logits_ith() errors
Embeddings models like BERT don't have logits. This caused llamafile to crash for users who tried to run inference with mxbai-embed-large-v1. This change should help prevent the server from crashing: since llama_get_logits_ith() can fail, having callers check its result is sound defensive coding. The older exception-handling code has also been refactored away, since it's no longer needed.
1 parent 201cc11 commit 880c40f
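
With this change, llama_get_logits_ith() and llama_get_embeddings_ith() report failure by returning nullptr instead of throwing from inside the library. A minimal sketch of the caller-side check this commit applies across the tree, assuming a llama_context * ctx and a batch position idx are already in scope (whether to throw, as the library-adjacent code below does, or to exit nonzero, as the example programs do, is the call site's choice):

    // An embeddings-only model such as BERT produces no logits, so this
    // call can legitimately fail; check the result before dereferencing.
    float * logits = llama_get_logits_ith(ctx, idx);
    if (!logits) {
        throw std::runtime_error("llama_get_logits_ith failed");
    }
    // logits now safely points at the row for output slot idx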

File tree

8 files changed: +88 -61 lines

common/sampling.cpp
examples/batched/batched.cpp
examples/gritlm/gritlm.cpp
examples/llama.android/app/src/main/cpp/llama-android.cpp
examples/passkey/passkey.cpp
examples/perplexity/perplexity.cpp
examples/simple/simple.cpp
llama.cpp

common/sampling.cpp

Lines changed: 9 additions & 0 deletions

@@ -195,6 +195,9 @@ static llama_token llama_sampling_sample_impl(
     llama_token id = 0;
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -284,6 +287,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
         GGML_ASSERT(original_logits != NULL);
@@ -298,6 +304,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     if (ctx_cfg) {
         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        if (!logits_guidance) {
+            throw std::runtime_error("llama_get_logits_ith failed");
+        }
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
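
Since the sampling code can now surface a missing-logits condition as std::runtime_error, code higher up the stack can contain the failure instead of crashing the process. A hedged sketch of such a caller; the recovery policy shown is illustrative only and not part of this commit:

    try {
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx_main, nullptr, idx);
        // ... accept the sampled token and continue decoding ...
    } catch (const std::runtime_error & err) {
        // e.g. a server might fail this one request rather than abort:
        fprintf(stderr, "sampling failed: %s\n", err.what());
    }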

examples/batched/batched.cpp

Lines changed: 3 additions & 0 deletions

@@ -169,6 +169,9 @@ int main(int argc, char ** argv) {
 
     auto n_vocab = llama_n_vocab(model);
     auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
+    if (!logits) {
+        return 1;
+    }
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);

examples/gritlm/gritlm.cpp

Lines changed: 6 additions & 0 deletions

@@ -58,6 +58,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     // sum up all token embeddings
     for (int32_t k = n_inst; k < n_toks; k++) {
         float * emb = llama_get_embeddings_ith(ctx, k);
+        if (!emb) {
+            return 1;
+        }
         for (uint64_t j = 0; j < n_embd; j++) {
             emb_unorm[j] += emb[j];
         }
@@ -114,6 +117,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
 
     llama_decode(ctx, bat);
     auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
     auto n_candidates = (int32_t)candidates.size();

examples/llama.android/app/src/main/cpp/llama-android.cpp

Lines changed: 3 additions & 0 deletions

@@ -394,6 +394,9 @@ Java_com_example_llama_Llm_completion_1loop(
 
     auto n_vocab = llama_n_vocab(model);
     auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);

examples/passkey/passkey.cpp

Lines changed: 3 additions & 0 deletions

@@ -239,6 +239,9 @@ int main(int argc, char ** argv) {
     {
         auto n_vocab = llama_n_vocab(model);
         auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+        if (!logits) {
+            return 1;
+        }
 
         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);

examples/perplexity/perplexity.cpp

Lines changed: 3 additions & 0 deletions

@@ -638,6 +638,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
     for (int seq = 0; seq < n_seq_batch; seq++) {
         const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+        if (!all_logits) {
+            return 1;
+        }
 
         llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
         if (!params.logits_file.empty()) {

examples/simple/simple.cpp

Lines changed: 3 additions & 0 deletions

@@ -120,6 +120,9 @@ int main(int argc, char ** argv) {
     {
         auto n_vocab = llama_n_vocab(model);
         auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+        if (!logits) {
+            return 1;
+        }
 
         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);

llama.cpp

Lines changed: 58 additions & 61 deletions

@@ -17301,42 +17301,39 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits;
 }
 
+static float * llama_get_logits_ith_fail(int i, std::string reason) {
+    LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, reason.c_str());
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
+}
+
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;
     llama_synchronize(ctx);
-
-    try {
-        if (ctx->logits == nullptr) {
-            throw std::runtime_error("no logits");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
+    if (ctx->logits == nullptr) {
+        // this can happen for embeddings models like bert
+        return llama_get_logits_ith_fail(i, "no logits");
+    }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+            return llama_get_logits_ith_fail(i, format("negative index out of range [0, %d)", ctx->n_outputs));
        }
-
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
+    } else if ((size_t) i >= ctx->output_ids.size()) {
+        return llama_get_logits_ith_fail(i, format("out of range [0, %lu)", ctx->output_ids.size()));
+    } else {
+        j = ctx->output_ids[i];
     }
+    if (j < 0) {
+        return llama_get_logits_ith_fail(i, format("batch.logits[%d] != true", i));
+    }
+    if (j >= ctx->n_outputs) {
+        // This should not happen
+        return llama_get_logits_ith_fail(i, format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+    }
+    return ctx->logits + j*ctx->model.hparams.n_vocab;
 }
 
 float * llama_get_embeddings(struct llama_context * ctx) {
@@ -17345,43 +17342,43 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embd;
 }
 
+static float * llama_get_embeddings_ith_fail(int i, std::string reason) {
+    LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason.c_str());
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
+}
+
 float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;
-
     llama_synchronize(ctx);
-
-    try {
-        if (ctx->embd == nullptr) {
-            throw std::runtime_error("no embeddings");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
+    if (ctx->embd == nullptr) {
+        return llama_get_embeddings_ith_fail(i, "no embeddings");
+    }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+            return llama_get_embeddings_ith_fail(
+                i, format("negative index out of range [0, %d)", ctx->n_outputs));
        }
-
-        return ctx->embd + j*ctx->model.hparams.n_embd;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
+    } else if ((size_t) i >= ctx->output_ids.size()) {
+        return llama_get_embeddings_ith_fail(
+            i, format("out of range [0, %lu)", ctx->output_ids.size()));
+    } else {
+        j = ctx->output_ids[i];
    }
+    if (j < 0) {
+        return llama_get_embeddings_ith_fail(
+            i, format("batch.logits[%d] != true", i));
+    }
+    if (j >= ctx->n_outputs) {
+        // This should not happen
+        return llama_get_embeddings_ith_fail(
+            i, format("corrupt output buffer (j=%d, n_outputs=%d)",
+                      j, ctx->n_outputs));
+    }
+    return ctx->embd + j*ctx->model.hparams.n_embd;
 }
 
 float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
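
Both refactored functions share the same index-resolution logic, which this change preserves: a negative i counts back from the last computed output (j = n_outputs + i, so i = -1 addresses the most recent output), a non-negative i is mapped through output_ids, and four distinct failure cases now return nullptr via the *_fail helpers instead of throwing. A standalone toy of that resolution rule, for illustration only (resolve_output_row is a made-up name, not llama.cpp API):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Toy model of the index resolution shared by llama_get_logits_ith()
    // and llama_get_embeddings_ith(): returns the resolved output row j,
    // or -1 for each case the real functions report through *_fail.
    static int32_t resolve_output_row(int32_t i, int32_t n_outputs,
                                      const std::vector<int32_t> & output_ids) {
        int32_t j;
        if (i < 0) {
            j = n_outputs + i;       // negative index counts from the end
            if (j < 0) {
                return -1;           // "negative index out of range"
            }
        } else if ((std::size_t) i >= output_ids.size()) {
            return -1;               // "out of range"
        } else {
            j = output_ids[i];       // maps batch position to output row
        }
        if (j < 0) {
            return -1;               // "batch.logits[i] != true"
        }
        if (j >= n_outputs) {
            return -1;               // "corrupt output buffer"
        }
        return j;                    // valid row into the output buffer
    }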
