|
52 | 52 | #include <algorithm>
|
53 | 53 | #include <array>
|
54 | 54 | #include <cassert>
|
| 55 | +#include <cfloat> |
55 | 56 | #include <cinttypes>
|
56 | 57 | #include <climits>
|
57 | 58 | #include <cmath>
|
@@ -8246,21 +8247,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
|
8246 | 8247 | return;
|
8247 | 8248 | }
|
8248 | 8249 |
|
8249 |
| - llama_sample_softmax(ctx, candidates); |
8250 |
| - |
8251 | 8250 | const int64_t t_start_sample_us = ggml_time_us();
|
8252 | 8251 |
|
8253 |
| - float scale = candidates->data[0].p; // scale by max prob |
8254 |
| - size_t i = 1; // first token always matches |
| 8252 | + bool min_p_applied = false; |
| 8253 | + |
| 8254 | + // if the candidates aren't sorted, try the unsorted implementation first |
| 8255 | + if (!candidates->sorted) { |
| 8256 | + std::vector<llama_token_data> filtered_tokens; |
| 8257 | + |
| 8258 | + float max_logit = -FLT_MAX; |
| 8259 | + for (size_t i = 0; i < candidates->size; ++i) { |
| 8260 | + max_logit = std::max(max_logit, candidates->data[i].logit); |
| 8261 | + } |
| 8262 | + const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max |
| 8263 | + |
| 8264 | + for (size_t i = 0; i < candidates->size; ++i) { |
| 8265 | + if (candidates->data[i].logit >= min_logit) { |
| 8266 | + filtered_tokens.push_back(candidates->data[i]); |
| 8267 | + } |
| 8268 | + } |
8255 | 8269 |
|
8256 |
| - for (; i < candidates->size; ++i) { |
8257 |
| - if (candidates->data[i].p < p * scale && i >= min_keep) { |
8258 |
| - break; // prob too small |
| 8270 | + // if we have enough values the operation was a success |
| 8271 | + if (filtered_tokens.size() >= min_keep) { |
| 8272 | + memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); |
| 8273 | + candidates->size = filtered_tokens.size(); |
| 8274 | + min_p_applied = true; |
8259 | 8275 | }
|
8260 | 8276 | }
|
8261 | 8277 |
|
8262 |
| - // Resize the output vector to keep only the matching tokens |
8263 |
| - candidates->size = i; |
| 8278 | + // if the candidates are sorted or the unsorted implementation failed, use this implementation |
| 8279 | + if (!min_p_applied) { |
| 8280 | + // Sort the logits in descending order |
| 8281 | + if (!candidates->sorted) { |
| 8282 | + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { |
| 8283 | + return a.logit > b.logit; |
| 8284 | + }); |
| 8285 | + candidates->sorted = true; |
| 8286 | + } |
| 8287 | + |
| 8288 | + const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max |
| 8289 | + size_t i = 1; // first token always matches |
| 8290 | + |
| 8291 | + for (; i < candidates->size; ++i) { |
| 8292 | + if (candidates->data[i].logit < min_logit && i >= min_keep) { |
| 8293 | + break; // prob too small |
| 8294 | + } |
| 8295 | + } |
| 8296 | + |
| 8297 | + // Resize the output vector to keep only the matching tokens |
| 8298 | + candidates->size = i; |
| 8299 | + } |
8264 | 8300 |
|
8265 | 8301 | if (ctx) {
|
8266 | 8302 | ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
0 commit comments