Commit 2442cae

token healing : refactor to return struct
1 parent 7d8661e commit 2442cae

File tree

3 files changed: +44 −52 lines changed

common/sampling.cpp
common/sampling.h
examples/main/main.cpp
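
At a glance, the refactor replaces the string return value and `int *` out-parameter of `llama_token_healing_rollback` with a single result struct. Before/after of the public declaration in common/sampling.h, condensed from the diff below for orientation:

// Before: prefix returned as a string, token count via an out-parameter.
std::string llama_token_healing_rollback(
        const llama_context * ctx_main,
        llama_token_healing_type th_type,
        std::vector<llama_token> & tokens,
        int max_to_remove = -1,
        int * n_removed = nullptr);

// After: both outputs bundled in one struct.
struct llama_token_healing_output {
    std::string prefix;
    int n_tokens_removed;
};

llama_token_healing_output llama_token_healing_rollback(
        const llama_context * ctx_main,
        llama_token_healing_type th_type,
        std::vector<llama_token> & tokens,
        int max_to_remove = -1);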

common/sampling.cpp

Lines changed: 20 additions & 31 deletions
@@ -49,18 +49,13 @@ static size_t get_max_token_length(const llama_context * ctx_main) {
     return len;
 }
 
-struct token_healing_info {
-    std::string prefix;
-    int n_tokens_removed;
-};
-
-token_healing_info llama_token_healing_get_prefix(
-        const llama_context * ctx_main,
-        const llama_token_healing_type th_type,
-        const std::vector<llama_token> & tokens,
-        int max_to_remove) {
+static llama_token_healing_output llama_token_healing_get_prefix(
+        const llama_context * ctx_main,
+        const llama_token_healing_type th_type,
+        const std::vector<llama_token> & tokens,
+        int max_to_remove) {
     if (tokens.size() <= 1) {
-        return {"", 0};
+        return {};
     }
 
     const int n_ctx = tokens.size();
@@ -122,34 +117,28 @@ token_healing_info llama_token_healing_get_prefix(
 // Token healing (external)
 //
 
-std::string llama_token_healing_rollback(
-        const llama_context * ctx_main,
-        llama_token_healing_type th_type,
-        std::vector<llama_token> & tokens,
-        int max_to_remove,
-        int * n_removed) {
-    if (n_removed != nullptr) {
-        *n_removed = 0;
-    }
+llama_token_healing_output llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove) {
     // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
     // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
-    token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
+    llama_token_healing_output out = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
 
     // If constrained decoding would give back the original prompt, there is no need to modify the prompt.
    const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                               th_type == llama_token_healing_type::DYNAMIC_MULTI;
-    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step);
-    LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed);
-    if (info.n_tokens_removed == 1 && candidates.size() == 1) {
+    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, out.prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", out.prefix.c_str(), out.n_tokens_removed);
+    if (out.n_tokens_removed == 1 && candidates.size() == 1) {
         LOG("token_healing: nothing to heal\n");
-        return "";
+        return {};
     }
-    // Finalize outputs
-    if (n_removed != nullptr) {
-        *n_removed = info.n_tokens_removed;
-    }
-    tokens.resize(tokens.size() - info.n_tokens_removed);
-    return info.prefix;
+
+    // Finally, trim prompt tokens
+    tokens.resize(tokens.size() - out.n_tokens_removed);
+    return out;
 }
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {

common/sampling.h

Lines changed: 12 additions & 8 deletions
@@ -176,13 +176,17 @@ void llama_sampling_accept(
 // Token healing
 //
 
-// Roll back `tokens` for constrained generation according to the token healing
-// strategy. Returns the prefix for constrained generation.
-std::string llama_token_healing_rollback(
-        const llama_context * ctx_main,
-        llama_token_healing_type th_type,
-        std::vector<llama_token> & tokens,
-        int max_to_remove = -1,
-        int * n_removed = nullptr);
+struct llama_token_healing_output {
+    std::string prefix;
+    int n_tokens_removed;
+};
+
+// Roll back `tokens` for constrained generation according to the token healing strategy.
+// Call `llama_token_healing_set_prefix` with the returned prefix before the first sampling.
+llama_token_healing_output llama_token_healing_rollback(
+        const llama_context * ctx_main,
+        llama_token_healing_type th_type,
+        std::vector<llama_token> & tokens,
+        int max_to_remove = -1);
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix);
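
For orientation, a minimal caller-side sketch of the new API (a hypothetical fragment, not part of this commit; it assumes `ctx`, `ctx_sampling`, `sparams`, and a tokenized prompt `embd_inp` already exist, mirroring how examples/main/main.cpp uses the functions below):

// Hypothetical caller fragment; names other than the API itself are assumptions.
llama_token_healing_output th_out{};
if (sparams.token_healing_enabled) {
    // Rolls back up to `token_healing_n_rollback` prompt tokens from the end of
    // `embd_inp` and returns the healing prefix plus how many tokens were removed.
    th_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
                                          sparams.token_healing_n_rollback);
}
// Register the prefix with the sampler before the first sampling call.
llama_token_healing_set_prefix(ctx_sampling, th_out.prefix);

The value-initialized `{}` default keeps `prefix` empty and `n_tokens_removed` at zero when healing is disabled, matching the `return {};` early-exit paths in common/sampling.cpp above.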

examples/main/main.cpp

Lines changed: 12 additions & 13 deletions
@@ -280,11 +280,10 @@ int main(int argc, char ** argv) {
         sparams.token_healing_enabled = false;
         LOG("token healing: disabled due to custom suffix/conversation mode");
     }
-    std::string token_healing_prefix;
-    int token_healing_n_removed = 0;
+    llama_token_healing_output token_healing_out{};
     if (!params.interactive_first && sparams.token_healing_enabled) {
-        token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
-                                                            sparams.token_healing_n_rollback, &token_healing_n_removed);
+        token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
+                                                         sparams.token_healing_n_rollback);
     }
 
     // Should not run without any tokens
@@ -306,7 +305,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
-        original_prompt_len = original_inp.size() - token_healing_n_removed;
+        original_prompt_len = original_inp.size() - token_healing_out.n_tokens_removed;
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
         LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
         LOG("guidance_offset: %s", log_tostr(guidance_offset));
@@ -528,7 +527,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
-    llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+    llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
@@ -845,7 +844,8 @@ int main(int argc, char ** argv) {
                 assistant_ss << llama_token_to_piece(ctx, id, false);
             }
 
-            token_healing_n_removed = 0;
+            token_healing_out = {};
+
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
@@ -917,9 +917,8 @@ int main(int argc, char ** argv) {
                     const int max_to_remove = sparams.token_healing_n_rollback < 0
                                                   ? n_new_tokens
                                                   : std::min(sparams.token_healing_n_rollback, n_new_tokens);
-                    token_healing_prefix = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp,
-                                                                        max_to_remove, &token_healing_n_removed);
-                    n_bytes_to_skip = token_healing_prefix.size();
+                    token_healing_out = llama_token_healing_rollback(ctx, sparams.token_healing_type, embd_inp, max_to_remove);
+                    n_bytes_to_skip = token_healing_out.prefix.size();
                 }
 
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
@@ -931,7 +930,7 @@ int main(int argc, char ** argv) {
                 // reset assistant message
                 assistant_ss.str("");
 
-                n_remain -= line_inp.size() + token_healing_n_removed;
+                n_remain -= line_inp.size() + token_healing_out.n_tokens_removed;
                 LOG("n_remain: %d\n", n_remain);
             } else {
                 LOG("empty line, passing control back\n");
@@ -943,9 +942,9 @@ int main(int argc, char ** argv) {
             if (n_past > 0) {
                 if (is_interacting) {
                     llama_sampling_reset(ctx_sampling);
-                    if (token_healing_n_removed > 0) {
+                    if (token_healing_out.n_tokens_removed > 0) {
                         // Set new prefix after an interaction
-                        llama_token_healing_set_prefix(ctx_sampling, token_healing_prefix);
+                        llama_token_healing_set_prefix(ctx_sampling, token_healing_out.prefix);
                     }
                 }
             }
             is_interacting = false;

0 commit comments

Comments
 (0)