 #include "sampling.h"
 #include <random>
 
+//
+// Token healing (internal)
+//
+
+static bool startswith(const std::string & str, const std::string & prefix) {
+    return str.rfind(prefix, 0) != std::string::npos;
+}
+
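+// Returns true if any token in the vocabulary begins with `prefix` (linear scan over all token pieces).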
+static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_find_prefix(
+        const llama_context * ctx_main,
+        const std::string & prefix,
+        const bool include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, include also: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix) ||
+            (include_partial_prefix && startswith(prefix, token))) {
+            candidates.push_back(token_id);
+        }
+    }
+    return candidates;
+}
+
+//
+// Token healing (external)
+//
+
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove,
+        int * n_removed) {
+    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
+    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
+    if (n_removed != nullptr) {
+        *n_removed = 0;
+    }
+    if (tokens.size() <= 1) {
+        return "";
+    }
+    const llama_model * model = llama_get_model(ctx_main);
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
+    int removed = 0;
+    std::string prefix;
+    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt,
+    // and stop early if a special token is encountered.
+    // NB. This doesn't handle cases where a long token is split many times,
+    // e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically),
+    // then "abc" will not be returned even if "abcd" exists in the vocab.
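+    // E.g. (hypothetical tokenization) if the prompt ends with [" wo", "rl"], a dynamic rollback first
+    // checks that some vocab token starts with "rl", then that some vocab token starts with " worl";
+    // if both checks pass, both tokens are removed and the prefix grows to " worl".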
+    while (removed < max_to_remove) {
+        const llama_token next_token_id = tokens[n_ctx - removed - 1];
+        if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) {
+            break; // Don't roll back e.g. <|endoftext|>
+        }
+        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
+            break;
+        }
+        removed += 1;
+        prefix = new_prefix;
+    }
+    if (removed == 0) { // E.g. if the last token is a special token
+        return "";
+    }
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
+    if (removed == 1 && candidates.size() == 1) {
+        LOG("token_healing: nothing to heal\n");
+        return "";
+    }
+    // Finalize outputs
+    if (n_removed != nullptr) {
+        *n_removed = removed;
+    }
+    tokens.resize(n_ctx - removed);
+    return prefix;
+}
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
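+
+// Minimal sketch of the intended caller flow (illustrative only; variable names are hypothetical):
+//
+//     int n_removed = 0;
+//     // prompt_tokens = tokenized prompt
+//     const std::string prefix = llama_token_healing_rollback(ctx, th_type, prompt_tokens, -1, &n_removed);
+//     // ... evaluate the shortened prompt_tokens ...
+//     llama_token_healing_set_prefix(ctx_sampling, prefix); // constrain the next sampling steps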
+
+//
+// Sampling
+//
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -72,6 +178,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
         ctx->grammar = grammar;
     }
 
+    ctx->token_healing_prefix.clear();
+
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
     ctx->n_valid = 0;
@@ -130,7 +238,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }
 
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+    std::string result = "(Token healing) -> CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
             const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
@@ -392,8 +500,27 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     cur.clear();
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    // Constrain tokens based on the remaining token healing prefix (if any)
+    const auto & th_type = params.token_healing_type;
+    const auto & th_prefix = ctx_sampling->token_healing_prefix;
+    if (params.token_healing_enabled && !th_prefix.empty()) {
+        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
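+        // E.g. th_prefix = " wor": single-step healing keeps only tokens starting with " wor";
+        // multi-step healing additionally keeps shorter candidates such as " w" and " wo",
+        // so the prefix can be consumed over several sampling steps.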
+        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+
+        LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
+        for (const llama_token token_id : th_candidates) {
+            LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
+        }
+
+        // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
+        for (const llama_token token_id : th_candidates) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
+    } else {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -456,4 +583,19 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
+
+    if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
+        std::string & th_prefix = ctx_sampling->token_healing_prefix;
+        if (!th_prefix.empty()) {
+            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
+            if (new_token_piece.size() < th_prefix.size()) {
+                // Shift prefix constraint (for multi step token healing)
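+                // E.g. th_prefix = " world" and the sampled piece is " wo" -> remaining prefix is "rld"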
+                th_prefix = th_prefix.substr(new_token_piece.size());
+            } else {
+                // Prefix has been generated => no more constrained generation
+                th_prefix.clear();
+                LOG("token_healing: done\n");
+            }
+        }
+    }
 }