
Commit a948952

Make sampling not throw exception
1 parent aa3094c commit a948952
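
As the diff below shows, llama_sampling_prepare no longer throws std::runtime_error when llama_get_logits_ith fails; it returns an empty llama_token_data_array ({NULL, 0, false}) instead, and llama_sampling_sample propagates the failure as a token id of -1. Each caller in the examples is updated to check for these sentinel values.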

File tree

common/sampling.cpp
examples/infill/infill.cpp
examples/llava/llava-cli.cpp
examples/lookahead/lookahead.cpp
examples/lookup/lookup.cpp
examples/main/main.cpp
examples/parallel/parallel.cpp
examples/server/server.cpp
examples/speculative/speculative.cpp

9 files changed: +36 −4 lines changed

common/sampling.cpp

Lines changed: 6 additions & 3 deletions

@@ -189,13 +189,16 @@ static llama_token llama_sampling_sample_impl(
 
     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    if (cur_p.data == NULL) {
+        return -1;
+    }
     if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
-    GGML_ASSERT(logits); // already checked in llama_sampling_prepare
+    GGML_ASSERT(logits); // already checked in llama_sampling_prepare
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -286,7 +289,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
     if (!logits) {
-        throw std::runtime_error("llama_get_logits_ith failed");
+        return {NULL, 0, false};
     }
 
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
@@ -303,7 +306,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (ctx_cfg) {
         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
         if (!logits_guidance) {
-            throw std::runtime_error("llama_get_logits_ith failed");
+            return {NULL, 0, false};
         }
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
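
Taken together, the two return paths define the new error contract for callers: check cur_p.data against NULL after llama_sampling_prepare, or the returned token against -1 after llama_sampling_sample. A minimal caller sketch of that convention follows; it is illustrative only, and the fprintf diagnostic and surrounding setup are assumptions, not part of this commit:

    // Sketch: consuming the new non-throwing sampling API.
    // ctx_sampling and ctx are assumed to be initialized as in the examples below.
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);
    if (id == -1) {
        // llama_sampling_prepare returned {NULL, 0, false} internally
        // (llama_get_logits_ith failed), so no token was sampled.
        fprintf(stderr, "sampling failed\n");
        return 1;
    }
    llama_sampling_accept(ctx_sampling, ctx, id, true);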

examples/infill/infill.cpp

Lines changed: 3 additions & 0 deletions

@@ -530,6 +530,9 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            if (id == -1) {
+                return 1;
+            }
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
                            struct llama_context * ctx_llama,
                            int * n_past) {
     const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
+    GGML_ASSERT(id != -1);
     llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {

examples/lookahead/lookahead.cpp

Lines changed: 9 additions & 0 deletions

@@ -159,6 +159,9 @@ int main(int argc, char ** argv) {
     // sample first token
     {
         id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+        if (id == -1) {
+            return 1;
+        }
 
         llama_sampling_accept(ctx_sampling, ctx, id, true);
 
@@ -284,6 +287,9 @@ int main(int argc, char ** argv) {
 
             // sample the next token
             id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
+            if (id == -1) {
+                return 1;
+            }
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 
@@ -361,6 +367,9 @@ int main(int argc, char ** argv) {
                 // sample from the last level
                 for (int i = 0; i < W; i++) {
                     tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                    if (tokens_j[N - 2][i] == -1) {
+                        return 1;
+                    }
                 }
             } else {
                 for (int i = 0; i < W; i++) {

examples/lookup/lookup.cpp

Lines changed: 1 addition & 0 deletions

@@ -131,6 +131,7 @@ int main(int argc, char ** argv){
     while (true) {
         // sample from the target model
         llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+        GGML_ASSERT(id != -1);
 
         llama_sampling_accept(ctx_sampling, ctx, id, true);
 

examples/main/main.cpp

Lines changed: 3 additions & 0 deletions

@@ -706,6 +706,9 @@ int main(int argc, char ** argv) {
             }
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            if (id == -1) {
+                return 1;
+            }
 
             llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
 

examples/parallel/parallel.cpp

Lines changed: 1 addition & 0 deletions

@@ -341,6 +341,7 @@ int main(int argc, char ** argv) {
             // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
             const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+            GGML_ASSERT(id != -1);
 
             llama_sampling_accept(client.ctx_sampling, ctx, id, true);
 

examples/server/server.cpp

Lines changed: 3 additions & 0 deletions

@@ -2257,6 +2257,9 @@ struct server_context {
 
                 completion_token_output result;
                 const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+                if (id == -1) {
+                    continue; // keep going, don't crash, already logged
+                }
 
                 llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
 
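
Note that the server takes a softer path than the CLI examples: a failed sample skips the current slot with continue instead of exiting, so a single bad batch index does not bring down the whole server process.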

examples/speculative/speculative.cpp

Lines changed: 9 additions & 1 deletion

@@ -229,6 +229,9 @@ int main(int argc, char ** argv) {
             // stochastic verification
 
             llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
+            if (dist_tgt.data == NULL) {
+                return 1;
+            }
             llama_sample_softmax(ctx_tgt, &dist_tgt);
             float p_tgt = 0, p_dft = 0;
 
@@ -337,6 +340,9 @@ int main(int argc, char ** argv) {
             // sample from the target model
             LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
             token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+            if (token_id == -1) {
+                return 1;
+            }
 
             llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
 
@@ -457,7 +463,9 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+            if (llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft) == -1) {
+                return -1;
+            }
 
             const auto & cur_p = drafts[s].ctx_sampling->cur;
 
