|
52 | 52 | #include <algorithm>
|
53 | 53 | #include <array>
|
54 | 54 | #include <cassert>
|
| 55 | +#include <cfloat> |
55 | 56 | #include <cinttypes>
|
56 | 57 | #include <climits>
|
57 | 58 | #include <cmath>
|
@@ -8007,21 +8008,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
|
8007 | 8008 | return;
|
8008 | 8009 | }
|
8009 | 8010 |
|
8010 |
| - llama_sample_softmax(ctx, candidates); |
8011 |
| - |
8012 | 8011 | const int64_t t_start_sample_us = ggml_time_us();
|
8013 | 8012 |
|
8014 |
| - float scale = candidates->data[0].p; // scale by max prob |
8015 |
| - size_t i = 1; // first token always matches |
| 8013 | + bool min_p_applied = false; |
| 8014 | + |
| 8015 | + // if the candidates aren't sorted, try the unsorted implementation first |
| 8016 | + if (!candidates->sorted) { |
| 8017 | + std::vector<llama_token_data> filtered_tokens; |
| 8018 | + |
| 8019 | + float max_logit = -FLT_MAX; |
| 8020 | + for (size_t i = 0; i < candidates->size; ++i) { |
| 8021 | + max_logit = std::max(max_logit, candidates->data[i].logit); |
| 8022 | + } |
| 8023 | + const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max |
| 8024 | + |
| 8025 | + for (size_t i = 0; i < candidates->size; ++i) { |
| 8026 | + if (candidates->data[i].logit >= min_logit) { |
| 8027 | + filtered_tokens.push_back(candidates->data[i]); |
| 8028 | + } |
| 8029 | + } |
8016 | 8030 |
|
8017 |
| - for (; i < candidates->size; ++i) { |
8018 |
| - if (candidates->data[i].p < p * scale && i >= min_keep) { |
8019 |
| - break; // prob too small |
| 8031 | + // if we have enough values the operation was a success |
| 8032 | + if (filtered_tokens.size() >= min_keep) { |
| 8033 | + memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); |
| 8034 | + candidates->size = filtered_tokens.size(); |
| 8035 | + min_p_applied = true; |
8020 | 8036 | }
|
8021 | 8037 | }
|
8022 | 8038 |
|
8023 |
| - // Resize the output vector to keep only the matching tokens |
8024 |
| - candidates->size = i; |
| 8039 | + // if the candidates are sorted or the unsorted implementation failed, use this implementation |
| 8040 | + if (!min_p_applied) { |
| 8041 | + // Sort the logits in descending order |
| 8042 | + if (!candidates->sorted) { |
| 8043 | + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { |
| 8044 | + return a.logit > b.logit; |
| 8045 | + }); |
| 8046 | + candidates->sorted = true; |
| 8047 | + } |
| 8048 | + |
| 8049 | + const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max |
| 8050 | + size_t i = 1; // first token always matches |
| 8051 | + |
| 8052 | + for (; i < candidates->size; ++i) { |
| 8053 | + if (candidates->data[i].logit < min_logit && i >= min_keep) { |
| 8054 | + break; // prob too small |
| 8055 | + } |
| 8056 | + } |
| 8057 | + |
| 8058 | + // Resize the output vector to keep only the matching tokens |
| 8059 | + candidates->size = i; |
| 8060 | + } |
8025 | 8061 |
|
8026 | 8062 | if (ctx) {
|
8027 | 8063 | ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
0 commit comments