
Commit 7d8661e

token healing : change dynamic rollback
Dynamic rollback now rolls back up to the length of the longest vocabulary token before checking which prefixes exist, rather than checking the prefix after every removed token.
1 parent cec3120 commit 7d8661e
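
To illustrate the change, below is a minimal standalone sketch of the new two-phase search on a toy string vocabulary. The vocabulary, tokens, and variable names are hypothetical and do not use the llama.cpp API; the example mirrors the case from the removed comment, where "abc" is tokenized as ["a", "b", "c"], "bc" covers no token, but "abcd" is in the vocab.

// Toy illustration of the new dynamic rollback (plain strings, not the llama.cpp API).
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
    const std::vector<std::string> vocab  = { "x", "a", "b", "c", "ab", "abcd" }; // hypothetical vocab
    const std::vector<std::string> tokens = { "x", "a", "b", "c" };               // prompt tokens
    const size_t max_to_remove = tokens.size() - 1;                               // 1 token must remain

    // Bound: the rolled-back bytes can never exceed the longest token in the vocab.
    size_t longest = 0;
    for (const auto & t : vocab) longest = std::max(longest, t.size());

    // Phase 1: roll back whole tokens while staying within the byte bound.
    size_t removed = 0, len = 0;
    while (removed < max_to_remove) {
        const std::string & next = tokens[tokens.size() - removed - 1];
        if (len + next.size() > longest) break;
        len += next.size();
        removed += 1;
    }

    // Phase 2: shrink until the removed suffix is a prefix of at least one vocab entry.
    std::string prefix;
    while (removed > 0) {
        prefix.clear();
        for (size_t i = tokens.size() - removed; i < tokens.size(); ++i) prefix += tokens[i];
        const bool covered = std::any_of(vocab.begin(), vocab.end(),
            [&](const std::string & t) { return t.compare(0, prefix.size(), prefix) == 0; });
        if (covered) break;              // longest healable prefix found
        removed -= 1;
    }

    std::cout << "removed " << removed << " token(s), prefix = '" << prefix << "'\n";
    // Prints: removed 3 token(s), prefix = 'abc' (healable by e.g. "abcd").
    // The previous token-by-token check would have stopped at 'c', since "bc" covers no token.
    return 0;
}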

2 files changed: +95 −43 lines

common/sampling.cpp

Lines changed: 94 additions & 42 deletions
@@ -13,14 +13,15 @@ static bool startswith(const std::string & str, const std::string & prefix) {
 static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
     const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
     for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
-        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix)) {
             return true;
         }
     }
     return false;
 }
 
-static std::vector<llama_token> token_healing_find_prefix(
+static std::vector<llama_token> token_healing_get_candidates(
         const llama_context * ctx_main,
         const std::string & prefix,
         const bool include_partial_prefix) {
@@ -38,6 +39,85 @@ static std::vector<llama_token> token_healing_find_prefix(
     return candidates;
 }
 
+static size_t get_max_token_length(const llama_context * ctx_main) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    size_t len = 0;
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        len = std::max(len, token.size());
+    }
+    return len;
+}
+
+struct token_healing_info {
+    std::string prefix;
+    int n_tokens_removed;
+};
+
+token_healing_info llama_token_healing_get_prefix(
+        const llama_context * ctx_main,
+        const llama_token_healing_type th_type,
+        const std::vector<llama_token> & tokens,
+        int max_to_remove) {
+    if (tokens.size() <= 1) {
+        return {"", 0};
+    }
+
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
+
+    int removed = 0;
+    std::string prefix;
+
+    const llama_model * model = llama_get_model(ctx_main);
+    auto is_special_token = [&](const llama_token token_id) {
+        return llama_token_is_control(model, token_id) || llama_token_is_eog(model, token_id);
+    };
+
+    if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) {
+        // The number of bytes to roll back cannot exceed the length of the longest token.
+        const size_t n_longest_token = get_max_token_length(ctx_main);
+        size_t len = 0;
+        while (removed < max_to_remove) {
+            const llama_token next_token_id = tokens[n_ctx - removed - 1];
+            if (is_special_token(next_token_id)) {
+                break;
+            }
+            const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size();
+            if (len + next_token_size > n_longest_token) {
+                break;
+            }
+            len += next_token_size;
+            removed += 1;
+        }
+
+        while (removed > 0) {
+            prefix.clear();
+            for (int i = n_ctx - removed; i < n_ctx; i++) {
+                prefix += llama_token_to_piece(ctx_main, tokens[i]);
+            }
+            if (token_healing_prefix_exists(ctx_main, prefix)) {
+                break; // Stop on longest valid prefix
+            }
+            removed -= 1;
+        }
+    } else {
+        // Roll back tokens a fixed amount and stop early if a special token is encountered.
+        while (removed < max_to_remove) {
+            const llama_token next_token_id = tokens[n_ctx - removed - 1];
+            if (is_special_token(next_token_id)) {
+                break;
+            }
+            removed += 1;
+        }
+        for (int i = n_ctx - removed; i < n_ctx; i++) {
+            prefix += llama_token_to_piece(ctx_main, tokens[i]);
+        }
+    }
+    return {prefix, removed};
+}
+
 //
 // Token healing (external)
 //
@@ -48,56 +128,28 @@ std::string llama_token_healing_rollback(
         std::vector<llama_token> & tokens,
         int max_to_remove,
         int * n_removed) {
-    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
-    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
     if (n_removed != nullptr) {
         *n_removed = 0;
     }
-    if (tokens.size() <= 1) {
-        return "";
-    }
-    const llama_model * model = llama_get_model(ctx_main);
-    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
-    const int n_ctx = tokens.size();
-    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
-    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
-    int removed = 0;
-    std::string prefix;
-    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
-    // and stop early if a special token is encountered.
-    // NB. This doesn't handle cases where a long token is split many times,
-    // e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically),
-    // then "abc" will not be returned even if "abcd" exists in the vocab.
-    while (removed < max_to_remove) {
-        const llama_token next_token_id = tokens[n_ctx - removed - 1];
-        if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) {
-            break; // Don't roll back e.g. <|endoftext|>
-        }
-        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
-        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
-            break;
-        }
-        removed += 1;
-        prefix = new_prefix;
-    }
-    if (removed == 0) { // E.g. if the last token is a special token
-        return "";
-    }
-    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
+    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
+    token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
+
+    // If constrained decoding would give back the original prompt, there is no need to modify the prompt.
     const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                                th_type == llama_token_healing_type::DYNAMIC_MULTI;
-    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
-    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
-    if (removed == 1 && candidates.size() == 1) {
+    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed);
+    if (info.n_tokens_removed == 1 && candidates.size() == 1) {
         LOG("token_healing: nothing to heal\n");
         return "";
     }
     // Finalize outputs
     if (n_removed != nullptr) {
-        *n_removed = removed;
+        *n_removed = info.n_tokens_removed;
     }
-    tokens.resize(n_ctx - removed);
-    return prefix;
+    tokens.resize(tokens.size() - info.n_tokens_removed);
+    return info.prefix;
 }
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
@@ -506,7 +558,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (params.token_healing_enabled && !th_prefix.empty()) {
         const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                                    th_type == llama_token_healing_type::DYNAMIC_MULTI;
-        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+        std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
 
         LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
         for (const llama_token token_id : th_candidates) {
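
For context, a rough sketch of how a caller (e.g. examples/main) might drive the reworked rollback. The wrapper function, the llama_tokenize helper, and the sparams field names are assumptions for illustration; only llama_token_healing_rollback and llama_token_healing_set_prefix appear in this diff.

// Hypothetical caller-side fragment, not part of this commit.
static std::string apply_token_healing(
        llama_context               * ctx,
        llama_sampling_context      * ctx_sampling,
        const llama_sampling_params & sparams,
        std::vector<llama_token>    & embd_inp) {
    int n_removed = 0;
    // Trims healable tokens off the end of embd_inp and returns the string they spelled.
    const std::string th_prefix = llama_token_healing_rollback(
            ctx, sparams.token_healing_type, embd_inp, sparams.token_healing_n_rollback, &n_removed);
    if (n_removed > 0) {
        // Constrain sampling so that generation starts by re-spelling the removed prefix.
        llama_token_healing_set_prefix(ctx_sampling, th_prefix);
    }
    return th_prefix;
}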

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion
@@ -278,7 +278,7 @@ int main(int argc, char ** argv) {
 
     if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
         sparams.token_healing_enabled = false;
-        LOG("token_healing: disabled due to custom suffix/conversation mode");
+        LOG("token healing: disabled due to custom suffix/conversation mode");
     }
     std::string token_healing_prefix;
     int token_healing_n_removed = 0;
