Skip to content

Commit c8b698e

Browse files
JohannesGaessler authored and hodlen committed
Apply min_p to unsorted tokens (ggml-org#5115)
1 parent 0280fa9 commit c8b698e

File tree

1 file changed

+45
-9
lines changed

1 file changed

+45
-9
lines changed

llama.cpp

Lines changed: 45 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@
5252
#include <algorithm>
5353
#include <array>
5454
#include <cassert>
55+
#include <cfloat>
5556
#include <cinttypes>
5657
#include <climits>
5758
#include <cmath>
@@ -8246,21 +8247,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
82468247
return;
82478248
}
82488249

8249-
llama_sample_softmax(ctx, candidates);
8250-
82518250
const int64_t t_start_sample_us = ggml_time_us();
82528251

8253-
float scale = candidates->data[0].p; // scale by max prob
8254-
size_t i = 1; // first token always matches
8252+
bool min_p_applied = false;
8253+
8254+
// if the candidates aren't sorted, try the unsorted implementation first
8255+
if (!candidates->sorted) {
8256+
std::vector<llama_token_data> filtered_tokens;
8257+
8258+
float max_logit = -FLT_MAX;
8259+
for (size_t i = 0; i < candidates->size; ++i) {
8260+
max_logit = std::max(max_logit, candidates->data[i].logit);
8261+
}
8262+
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
8263+
8264+
for (size_t i = 0; i < candidates->size; ++i) {
8265+
if (candidates->data[i].logit >= min_logit) {
8266+
filtered_tokens.push_back(candidates->data[i]);
8267+
}
8268+
}
82558269

8256-
for (; i < candidates->size; ++i) {
8257-
if (candidates->data[i].p < p * scale && i >= min_keep) {
8258-
break; // prob too small
8270+
// if we have enough values the operation was a success
8271+
if (filtered_tokens.size() >= min_keep) {
8272+
memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
8273+
candidates->size = filtered_tokens.size();
8274+
min_p_applied = true;
82598275
}
82608276
}
82618277

8262-
// Resize the output vector to keep only the matching tokens
8263-
candidates->size = i;
8278+
// if the candidates are sorted or the unsorted implementation failed, use this implementation
8279+
if (!min_p_applied) {
8280+
// Sort the logits in descending order
8281+
if (!candidates->sorted) {
8282+
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
8283+
return a.logit > b.logit;
8284+
});
8285+
candidates->sorted = true;
8286+
}
8287+
8288+
const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
8289+
size_t i = 1; // first token always matches
8290+
8291+
for (; i < candidates->size; ++i) {
8292+
if (candidates->data[i].logit < min_logit && i >= min_keep) {
8293+
break; // prob too small
8294+
}
8295+
}
8296+
8297+
// Resize the output vector to keep only the matching tokens
8298+
candidates->size = i;
8299+
}
82648300

82658301
if (ctx) {
82668302
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;

0 commit comments

Comments
 (0)