Skip to content

Commit f6ad32d

Browse files
Apply min_p to unsorted tokens
1 parent c9b316c commit f6ad32d

File tree

1 file changed

+45
-9
lines changed

1 file changed

+45
-9
lines changed

llama.cpp

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include <algorithm>
5353
#include <array>
5454
#include <cassert>
55+
#include <cfloat>
5556
#include <cinttypes>
5657
#include <climits>
5758
#include <cmath>
@@ -8007,21 +8008,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
80078008
return;
80088009
}
80098010

8010-
llama_sample_softmax(ctx, candidates);
8011-
80128011
const int64_t t_start_sample_us = ggml_time_us();
80138012

8014-
float scale = candidates->data[0].p; // scale by max prob
8015-
size_t i = 1; // first token always matches
8013+
bool min_p_applied = false;
8014+
8015+
// if the candidates aren't sorted, try the unsorted implementation first
8016+
if (!candidates->sorted) {
8017+
std::vector<llama_token_data> filtered_tokens;
8018+
8019+
float max_logit = -FLT_MAX;
8020+
for (size_t i = 0; i < candidates->size; ++i) {
8021+
max_logit = std::max(max_logit, candidates->data[i].logit);
8022+
}
8023+
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
8024+
8025+
for (size_t i = 0; i < candidates->size; ++i) {
8026+
if (candidates->data[i].logit >= min_logit) {
8027+
filtered_tokens.push_back(candidates->data[i]);
8028+
}
8029+
}
80168030

8017-
for (; i < candidates->size; ++i) {
8018-
if (candidates->data[i].p < p * scale && i >= min_keep) {
8019-
break; // prob too small
8031+
// if we have enough values the operation was a success
8032+
if (filtered_tokens.size() >= min_keep) {
8033+
memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
8034+
candidates->size = filtered_tokens.size();
8035+
min_p_applied = true;
80208036
}
80218037
}
80228038

8023-
// Resize the output vector to keep only the matching tokens
8024-
candidates->size = i;
8039+
// if the candidates are sorted or the unsorted implementation failed, use this implementation
8040+
if (!min_p_applied) {
8041+
// Sort the logits in descending order
8042+
if (!candidates->sorted) {
8043+
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
8044+
return a.logit > b.logit;
8045+
});
8046+
candidates->sorted = true;
8047+
}
8048+
8049+
const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
8050+
size_t i = 1; // first token always matches
8051+
8052+
for (; i < candidates->size; ++i) {
8053+
if (candidates->data[i].logit < min_logit && i >= min_keep) {
8054+
break; // prob too small
8055+
}
8056+
}
8057+
8058+
// Resize the output vector to keep only the matching tokens
8059+
candidates->size = i;
8060+
}
80258061

80268062
if (ctx) {
80278063
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;

0 commit comments

Comments
 (0)