
Commit cec3120

main : add token healing
1 parent 0ddeff1 commit cec3120

File tree: 5 files changed, 249 additions and 6 deletions

common/common.cpp

Lines changed: 23 additions & 0 deletions
```diff
@@ -1058,6 +1058,25 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
         return true;
     }
+    if (arg == "-th" || arg == "--token-healing") {
+        CHECK_ARG
+        sparams.token_healing_enabled = true;
+        auto & th_type = sparams.token_healing_type;
+        auto & th_n_rollback = sparams.token_healing_n_rollback;
+        std::string value(argv[i]);
+        /**/ if (value == "0" ) { sparams.token_healing_enabled = false; }
+        else if (value == "1" ) { th_type = llama_token_healing_type::ROLLBACK_LAST; }
+        else if (value == "d1") { th_type = llama_token_healing_type::DYNAMIC_ONCE; }
+        else if (value == "d" ) { th_type = llama_token_healing_type::DYNAMIC_MULTI; }
+        else if (value[0] == 'r' ) {
+            th_type = llama_token_healing_type::ROLLBACK_MULTI;
+            th_n_rollback = std::stoi(value.substr(1));
+            if (th_n_rollback <= 0) {
+                sparams.token_healing_enabled = false;
+            }
+        } else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--override-kv") {
         CHECK_ARG
         if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
@@ -1455,6 +1474,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
         "set custom jinja chat template (default: template taken from model's metadata)\n"
         "only commonly used templates are accepted:\n"
         "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
+
+    options.push_back({ "main", "-th, --token-healing {0,1,d1,d,r{N}}",
+        "Token healing type. (default: 0, disabled)\n"
+        "1: replace one token, d1: replace longest suffix with one token, d: replace longest suffix, r{N}: roll back N tokens" });
     options.push_back({ "grammar" });
     options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
     options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
```

common/sampling.cpp

Lines changed: 145 additions & 3 deletions
```diff
@@ -2,6 +2,112 @@
 #include "sampling.h"
 #include <random>
 
+//
+// Token healing (internal)
+//
+
+static bool startswith(const std::string & str, const std::string & prefix) {
+    return str.rfind(prefix, 0) != std::string::npos;
+}
+
+static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static std::vector<llama_token> token_healing_find_prefix(
+        const llama_context * ctx_main,
+        const std::string & prefix,
+        const bool include_partial_prefix) {
+    // Example: prefix=" world" -> " world", " worldwide", ...
+    // If `include_partial_prefix`, include also: " w", " wo", ...
+    std::vector<llama_token> candidates;
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix) ||
+            (include_partial_prefix && startswith(prefix, token))) {
+            candidates.push_back(token_id);
+        }
+    }
+    return candidates;
+}
```
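
`token_healing_find_prefix` supports two matching modes: the strict mode keeps only vocab pieces that start with the whole healing prefix, while `include_partial_prefix` additionally keeps pieces that cover just the beginning of it (needed when the prefix is regenerated over several steps). A self-contained toy demonstration over a hard-coded vocabulary, illustration only, no llama API:

```cpp
// Toy demonstration of the two prefix-matching modes, over a small hard-coded
// "vocabulary" of byte strings instead of a real llama vocab.
#include <cstdio>
#include <string>
#include <vector>

static bool startswith(const std::string & str, const std::string & prefix) {
    return str.rfind(prefix, 0) == 0; // prefix matches at position 0
}

static std::vector<std::string> find_prefix_candidates(
        const std::vector<std::string> & vocab,
        const std::string & prefix,
        bool include_partial_prefix) {
    std::vector<std::string> candidates;
    for (const std::string & token : vocab) {
        if (startswith(token, prefix) ||                              // token covers the whole prefix
            (include_partial_prefix && startswith(prefix, token))) {  // token covers only part of it
            candidates.push_back(token);
        }
    }
    return candidates;
}

int main() {
    const std::vector<std::string> vocab = { " w", " wo", " world", " worldwide", " word", "hello" };
    const bool modes[] = { false, true };
    for (bool partial : modes) {
        printf("include_partial_prefix=%d:", partial);
        for (const auto & t : find_prefix_candidates(vocab, " worl", partial)) {
            printf(" '%s'", t.c_str());
        }
        printf("\n");
    }
    return 0;
}
```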
```diff
+
+//
+// Token healing (external)
+//
+
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove,
+        int * n_removed) {
+    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
+    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
+    if (n_removed != nullptr) {
+        *n_removed = 0;
+    }
+    if (tokens.size() <= 1) {
+        return "";
+    }
+    const llama_model * model = llama_get_model(ctx_main);
+    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
+    int removed = 0;
+    std::string prefix;
+    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
+    // and stop early if a special token is encountered.
+    // NB. This doesn't handle cases where a long token is split many times,
+    // e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically),
+    // then "abc" will not be returned even if "abcd" exists in the vocab.
+    while (removed < max_to_remove) {
+        const llama_token next_token_id = tokens[n_ctx - removed - 1];
+        if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) {
+            break; // Don't roll back e.g. <|endoftext|>
+        }
+        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
+        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
+            break;
+        }
+        removed += 1;
+        prefix = new_prefix;
+    }
+    if (removed == 0) { // E.g. if the last token is a special token
+        return "";
+    }
+    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
+    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
+    if (removed == 1 && candidates.size() == 1) {
+        LOG("token_healing: nothing to heal\n");
+        return "";
+    }
+    // Finalize outputs
+    if (n_removed != nullptr) {
+        *n_removed = removed;
+    }
+    tokens.resize(n_ctx - removed);
+    return prefix;
+}
```
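
The dynamic strategies keep rolling back as long as some single vocab piece could still complete the accumulated suffix. A toy walk-through of that loop with plain strings standing in for tokens (hypothetical pieces and vocabulary, illustration only):

```cpp
// Toy walk-through of the dynamic rollback loop (DYNAMIC_* strategies), using
// plain strings in place of llama tokens.
#include <cstdio>
#include <string>
#include <vector>

static bool startswith(const std::string & str, const std::string & prefix) {
    return str.rfind(prefix, 0) == 0;
}

// True if any vocab entry could still complete `prefix`.
static bool prefix_exists(const std::vector<std::string> & vocab, const std::string & prefix) {
    for (const auto & t : vocab) {
        if (startswith(t, prefix)) { return true; }
    }
    return false;
}

int main() {
    // Pretend the prompt ending in ' worl' was tokenized into these pieces:
    std::vector<std::string> pieces = { "He", " said", ":", " \"", " wor", "l" };
    const std::vector<std::string> vocab = { " world", " worldwide", " wor", "l", " said", ":", " \"", "He" };

    std::string prefix;
    int removed = 0;
    // Keep removing trailing pieces while the growing suffix can still be covered by one vocab token.
    while (removed < (int) pieces.size() - 1) {
        const std::string candidate = pieces[pieces.size() - removed - 1] + prefix;
        if (!prefix_exists(vocab, candidate)) {
            break;
        }
        prefix = candidate;
        removed += 1;
    }
    pieces.resize(pieces.size() - removed);
    // Expected: removed 2 trailing piece(s); healing prefix = ' worl'
    printf("removed %d trailing piece(s); healing prefix = '%s'\n", removed, prefix.c_str());
    return 0;
}
```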
```diff
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
+    ctx_sampling->token_healing_prefix = prefix;
+}
+
+//
+// Sampling
+//
+
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
@@ -72,6 +178,8 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
         ctx->grammar = grammar;
     }
 
+    ctx->token_healing_prefix.clear();
+
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
     ctx->n_valid = 0;
@@ -130,7 +238,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
 }
 
 std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
+    std::string result = "(Token healing) -> CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
             const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
@@ -392,8 +500,27 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     cur.clear();
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+    // Constrain tokens based on the remaining token healing prefix (if any)
+    const auto & th_type = params.token_healing_type;
+    const auto & th_prefix = ctx_sampling->token_healing_prefix;
+    if (params.token_healing_enabled && !th_prefix.empty()) {
+        const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
+                                   th_type == llama_token_healing_type::DYNAMIC_MULTI;
+        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+
+        LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
+        for (const llama_token token_id : th_candidates) {
+            LOG(" [%6d] '%s'\n", token_id, llama_token_to_piece(ctx_main, token_id).c_str());
+        }
+
+        // N.B. We could also set token constraints by setting rejected tokens' logits to -inf
+        for (const llama_token token_id : th_candidates) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
+    } else {
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
```
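
The `N.B.` comment above points out an alternative way to apply the constraint: keep all `n_vocab` candidates and push the logits of rejected tokens to `-inf` rather than building a reduced candidate list. A minimal sketch of that variant on plain vectors (not what this commit does):

```cpp
// Alternative constraint application: keep every token and set disallowed
// logits to -INFINITY so they get zero probability after softmax.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> logits  = { 1.5f, -0.2f, 0.7f, 2.1f, 0.0f }; // toy logits for 5 tokens
    std::vector<int>   allowed = { 1, 3 };                          // e.g. token-healing candidates

    std::vector<bool> keep(logits.size(), false);
    for (int id : allowed) { keep[id] = true; }

    for (size_t id = 0; id < logits.size(); ++id) {
        if (!keep[id]) {
            logits[id] = -INFINITY; // rejected token: can never be sampled
        }
        printf("token %zu: logit = %f\n", id, logits[id]);
    }
    return 0;
}
```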
```diff
@@ -456,4 +583,19 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
+
+    if (ctx_sampling->params.token_healing_enabled && apply_grammar) {
+        std::string & th_prefix = ctx_sampling->token_healing_prefix;
+        if (!th_prefix.empty()) {
+            const std::string new_token_piece = llama_token_to_piece(ctx_main, id);
+            if (new_token_piece.size() < th_prefix.size()) {
+                // Shift prefix constraint (for multi step token healing)
+                th_prefix = th_prefix.substr(new_token_piece.size());
+            } else {
+                // Prefix has been generated => no more constrained generation
+                th_prefix.clear();
+                LOG("token_healing: done\n");
+            }
+        }
+    }
 }
```
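
The bookkeeping in `llama_sampling_accept` is easiest to see on plain strings: each accepted piece consumes the front of the remaining healing prefix, and once the whole prefix has been regenerated the constraint is dropped. A standalone walk-through with toy pieces (illustration only):

```cpp
// Toy walk-through of the multi-step prefix shift done in llama_sampling_accept:
// each accepted piece consumes the front of the remaining healing prefix.
#include <cstdio>
#include <string>
#include <vector>

int main() {
    std::string th_prefix = " worl";                                // prefix produced by the rollback step
    const std::vector<std::string> accepted = { " wo", "rldwide" }; // pieces the sampler accepts

    for (const std::string & piece : accepted) {
        if (th_prefix.empty()) { break; }
        if (piece.size() < th_prefix.size()) {
            th_prefix = th_prefix.substr(piece.size());             // shift: still constrained by the rest
        } else {
            th_prefix.clear();                                      // prefix fully generated: unconstrained from now on
        }
        printf("accepted '%s' -> remaining prefix '%s'\n", piece.c_str(), th_prefix.c_str());
    }
    return 0;
}
```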

common/sampling.h

Lines changed: 28 additions & 0 deletions
```diff
@@ -19,6 +19,13 @@ enum class llama_sampler_type : char {
     TEMPERATURE = 't'
 };
 
+enum class llama_token_healing_type : uint8_t {
+    ROLLBACK_LAST,  // roll back last token with a single constrained decoding step
+    ROLLBACK_MULTI, // roll back a fixed amount of tokens, multiple constrained decoding steps
+    DYNAMIC_ONCE,   // dynamic roll back, single constrained decoding step
+    DYNAMIC_MULTI   // dynamic roll back, multiple constrained decoding steps
+};
+
 // sampling parameters
 typedef struct llama_sampling_params {
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -62,6 +69,10 @@ typedef struct llama_sampling_params {
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
+
+    llama_token_healing_type token_healing_type = llama_token_healing_type::ROLLBACK_LAST;
+    bool token_healing_enabled = false;
+    int token_healing_n_rollback = -1; // number of tokens to roll back
 } llama_sampling_params;
 
 // general sampler context
@@ -78,6 +89,8 @@ struct llama_sampling_context {
     // internal
     grammar_parser::parse_state parsed_grammar;
 
+    std::string token_healing_prefix; // remaining prefix to constrain sampling
+
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
@@ -158,3 +171,18 @@ void llama_sampling_accept(
         struct llama_context * ctx_main,
         llama_token id,
         bool apply_grammar);
+
+//
+// Token healing
+//
+
+// Roll back `tokens` for constrained generation according to the token healing
+// strategy. Returns the prefix for constrained generation.
+std::string llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove = -1,
+        int * n_removed = nullptr);
+
+void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
```
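
Put together, the intended call pattern (mirroring the `examples/main` integration below) is: roll back the already-tokenized prompt once before decoding, then hand the returned byte prefix to the sampling context so the prepare/accept path can enforce and later release the constraint. A compile-oriented sketch against the declarations above; decoding and error handling are elided, and the helper name is hypothetical:

```cpp
// Sketch of the intended call sequence for the new API (see examples/main/main.cpp
// in this commit for the real integration). Assumes an initialized llama_context
// and llama_sampling_context.
#include "sampling.h"

#include <string>
#include <vector>

static void apply_token_healing(llama_context * ctx,
                                llama_sampling_context * ctx_sampling,
                                llama_sampling_params & sparams,
                                std::vector<llama_token> & prompt_tokens) {
    if (!sparams.token_healing_enabled) {
        return;
    }
    int n_removed = 0;
    // Trim the tail of the prompt and get the byte prefix the next token(s) must regenerate.
    const std::string prefix = llama_token_healing_rollback(
        ctx, sparams.token_healing_type, prompt_tokens,
        sparams.token_healing_n_rollback, &n_removed);
    // Constrain subsequent sampling to tokens compatible with that prefix.
    llama_token_healing_set_prefix(ctx_sampling, prefix);
    // ... evaluate `prompt_tokens` and sample as usual; the sampling code above
    //     enforces the constraint and clears it once the prefix is regenerated.
}
```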

examples/main/README.md

Lines changed: 13 additions & 0 deletions
```diff
@@ -246,6 +246,19 @@ A more practical use case might be to prevent the generation of `\code{begin}` a
 
 Example usage: `--logit-bias 29905-inf`
 
+### Token healing
+
+- `-th {0,1,d1,d,r{N}}, --token-healing {0,1,d1,d,r{N}}`: Set the token healing strategy (default: 0, 0 = disabled).
+
+Token healing (a.k.a. token alignment) alleviates tokenization artifacts for text completion.
+
+- `-th 1`: Roll back the last token and constrain the bytes of the next token to start with the chopped off last token [0, 2].
+- `-th d1`: Roll back multiple tokens until there doesn't exist a token which can cover the prompt's suffix and do a single constrained decoding step [2].
+- `-th d`: Like `d1` but allow multiple decoding steps until the removed suffix is generated.
+- `-th r{N}`: Like `d` but roll back `N` tokens, where `-th r3` is recommended [1].
+
+Sources: [0](https://github.com/guidance-ai/guidance/blob/main/notebooks/art_of_prompt_design/prompt_boundaries_and_token_healing.ipynb), [1](https://arxiv.org/abs/2403.08688), [2](https://arxiv.org/abs/2402.01035).
+
 ### RNG Seed
 
 - `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
```

examples/main/main.cpp

Lines changed: 40 additions & 3 deletions
```diff
@@ -276,6 +276,17 @@ int main(int argc, char ** argv) {
         LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
+    if (sparams.token_healing_enabled && (params.conversation || !params.input_suffix.empty())) {
+        sparams.token_healing_enabled = false;
+        LOG("token_healing: disabled due to custom suffix/conversation mode");
+    }
+    std::string token_healing_prefix;
+    int token_healing_n_removed = 0;
+    if (!params.interactive_first && sparams.token_healing_enabled) {
+        token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                            sparams.token_healing_n_rollback, &token_healing_n_removed);
+    }
+
     // Should not run without any tokens
     if (embd_inp.empty()) {
         embd_inp.push_back(llama_token_bos(model));
@@ -295,7 +306,7 @@
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size();
+        original_prompt_len = original_inp.size() - token_healing_n_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -490,6 +501,7 @@
     int n_consumed = 0;
    int n_session_consumed = 0;
    int n_past_guidance = 0;
+    int n_bytes_to_skip = 0; // to skip printing when generating token healing prefix
 
     std::vector<int> input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@@ -516,6 +528,7 @@
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
@@ -732,7 +745,15 @@
                 const std::string token_str = llama_token_to_piece(ctx, id, params.special);
 
                 // Console/Stream Output
-                fprintf(stdout, "%s", token_str.c_str());
+                // Suppress printing while generating token healing prefix
+                if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int)token_str.size()) {
+                    fprintf(stdout, "%s", token_str.substr(n_bytes_to_skip).c_str());
+                    n_bytes_to_skip = 0;
+                } else if (n_bytes_to_skip > 0) {
+                    n_bytes_to_skip -= token_str.size();
+                } else {
+                    fprintf(stdout, "%s", token_str.c_str());
+                }
 
                 // Record Displayed Tokens To Log
                 // Note: Generated tokens are created one by one hence this check
```
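
Because the healed prefix was already displayed as part of the prompt, `n_bytes_to_skip` suppresses re-printing those bytes even when the regenerated token straddles the prefix boundary. A toy rendering of that logic (hypothetical pieces, illustration only):

```cpp
// Toy demonstration of the n_bytes_to_skip logic: the healed prefix was already
// shown as part of the prompt, so its bytes are not printed again.
#include <cstdio>
#include <string>
#include <vector>

int main() {
    // Pretend " worl" (5 bytes) was rolled back and the model regenerates " worldwide!" as three pieces.
    int n_bytes_to_skip = 5;
    const std::vector<std::string> pieces = { " wor", "ldwide", "!" };

    for (const std::string & token_str : pieces) {
        if (n_bytes_to_skip > 0 && n_bytes_to_skip < (int) token_str.size()) {
            printf("%s", token_str.substr(n_bytes_to_skip).c_str()); // print only the new bytes
            n_bytes_to_skip = 0;
        } else if (n_bytes_to_skip > 0) {
            n_bytes_to_skip -= token_str.size();                     // entire piece was already shown
        } else {
            printf("%s", token_str.c_str());
        }
    }
    printf("\n"); // expected output: "dwide!"
    return 0;
}
```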
```diff
@@ -824,6 +845,7 @@
             assistant_ss << llama_token_to_piece(ctx, id, false);
         }
 
+        token_healing_n_removed = 0;
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");
 
@@ -889,6 +911,17 @@
                 embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
                 embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
 
+                if (sparams.token_healing_enabled) {
+                    // Limit token healing rollback to new tokens only (otherwise would need to shift everything)
+                    const int n_new_tokens = embd_inp.size() - original_size;
+                    const int max_to_remove = sparams.token_healing_n_rollback < 0
+                                              ? n_new_tokens
+                                              : std::min(sparams.token_healing_n_rollback, n_new_tokens);
+                    token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                                        max_to_remove, &token_healing_n_removed);
+                    n_bytes_to_skip = token_healing_prefix.size();
+                }
+
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
                     const llama_token token = embd_inp[i];
                     output_tokens.push_back(token);
@@ -898,7 +931,7 @@
                 // reset assistant message
                 assistant_ss.str("");
 
-                n_remain -= line_inp.size();
+                n_remain -= line_inp.size() + token_healing_n_removed;
                 LOG("n_remain: %d\n", n_remain);
             } else {
                 LOG("empty line, passing control back\n");
@@ -910,6 +943,10 @@
         if (n_past > 0) {
             if (is_interacting) {
                 llama_sampling_reset(ctx_sampling);
+                if (token_healing_n_removed > 0) {
+                    // Set new prefix after an interaction
+                    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+                }
             }
             is_interacting = false;
         }
```
