@@ -13,14 +13,15 @@ static bool startswith(const std::string & str, const std::string & prefix) {
 static bool token_healing_prefix_exists(const llama_context * ctx_main, const std::string & prefix) {
     const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
     for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
-        if (startswith(llama_token_to_piece(ctx_main, token_id), prefix)) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        if (startswith(token, prefix)) {
             return true;
         }
     }
     return false;
 }
 
-static std::vector<llama_token> token_healing_find_prefix(
+static std::vector<llama_token> token_healing_get_candidates(
         const llama_context * ctx_main,
         const std::string & prefix,
         const bool include_partial_prefix) {
@@ -38,6 +39,85 @@ static std::vector<llama_token> token_healing_find_prefix(
     return candidates;
 }
 
+static size_t get_max_token_length(const llama_context * ctx_main) {
+    const int32_t n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    size_t len = 0;
+    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
+        std::string token = llama_token_to_piece(ctx_main, token_id);
+        len = std::max(len, token.size());
+    }
+    return len;
+}
+
+struct token_healing_info {
+    std::string prefix;
+    int n_tokens_removed;
+};
+
+token_healing_info llama_token_healing_get_prefix(
+        const llama_context * ctx_main,
+        const llama_token_healing_type th_type,
+        const std::vector<llama_token> & tokens,
+        int max_to_remove) {
+    if (tokens.size() <= 1) {
+        return {"", 0};
+    }
+
+    const int n_ctx = tokens.size();
+    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
+    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
+
+    int removed = 0;
+    std::string prefix;
+
+    const llama_model * model = llama_get_model(ctx_main);
+    auto is_special_token = [&](const llama_token token_id) {
+        return llama_token_is_control(model, token_id) || llama_token_is_eog(model, token_id);
+    };
+
+    if (th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI) {
+        // The number of bytes to roll back cannot exceed the length of the longest token.
+        const size_t n_longest_token = get_max_token_length(ctx_main);
+        size_t len = 0;
+        while (removed < max_to_remove) {
+            const llama_token next_token_id = tokens[n_ctx - removed - 1];
+            if (is_special_token(next_token_id)) {
+                break;
+            }
+            const size_t next_token_size = llama_token_to_piece(ctx_main, next_token_id).size();
+            if (len + next_token_size > n_longest_token) {
+                break;
+            }
+            len += next_token_size;
+            removed += 1;
+        }
+
+        while (removed > 0) {
+            prefix.clear();
+            for (int i = n_ctx - removed; i < n_ctx; i++) {
+                prefix += llama_token_to_piece(ctx_main, tokens[i]);
+            }
+            if (token_healing_prefix_exists(ctx_main, prefix)) {
+                break; // Stop on longest valid prefix
+            }
+            removed -= 1;
+        }
+    } else {
+        // Roll back tokens a fixed amount and stop early if a special token is encountered.
+        while (removed < max_to_remove) {
+            const llama_token next_token_id = tokens[n_ctx - removed - 1];
+            if (is_special_token(next_token_id)) {
+                break;
+            }
+            removed += 1;
+        }
+        for (int i = n_ctx - removed; i < n_ctx; i++) {
+            prefix += llama_token_to_piece(ctx_main, tokens[i]);
+        }
+    }
+    return {prefix, removed};
+}
+
 //
 // Token healing (external)
 //
@@ -48,56 +128,28 @@ std::string llama_token_healing_rollback(
         std::vector<llama_token> & tokens,
         int max_to_remove,
         int * n_removed) {
-    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
-    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
     if (n_removed != nullptr) {
         *n_removed = 0;
     }
-    if (tokens.size() <= 1) {
-        return "";
-    }
-    const llama_model * model = llama_get_model(ctx_main);
-    const bool is_dynamic = th_type == llama_token_healing_type::DYNAMIC_ONCE || th_type == llama_token_healing_type::DYNAMIC_MULTI;
-    const int n_ctx = tokens.size();
-    max_to_remove = th_type == llama_token_healing_type::ROLLBACK_LAST ? 1 : max_to_remove;
-    max_to_remove = max_to_remove < 0 ? n_ctx - 1 : std::min(max_to_remove, n_ctx - 1); // 1 token must remain
-    int removed = 0;
-    std::string prefix;
-    // Roll back tokens a fixed amount or until there does not exist a token that can cover the prompt
-    // and stop early if a special token is encountered.
-    // NB. This doesn't handle cases where a long token is split many times,
-    // e.g. if "abc" is tokenized into ["a", "b", "c"] but "bc" is not a token (hypothetically),
-    // then "abc" will not be returned even if "abcd" exists in the vocab.
-    while (removed < max_to_remove) {
-        const llama_token next_token_id = tokens[n_ctx - removed - 1];
-        if (llama_token_is_control(model, next_token_id) || llama_token_is_eog(model, next_token_id)) {
-            break; // Don't roll back e.g. <|endoftext|>
-        }
-        std::string new_prefix = llama_token_to_piece(ctx_main, next_token_id) + prefix;
-        if (is_dynamic && !token_healing_prefix_exists(ctx_main, new_prefix)) {
-            break;
-        }
-        removed += 1;
-        prefix = new_prefix;
-    }
-    if (removed == 0) { // E.g. if the last token is a special token
-        return "";
-    }
-    // If constrained decoding would give back the original prompt, there is no need to modify the context
+    // NB. To avoid returning empty `tokens`, at least 1 token will remain in `tokens` after rolling back.
+    // It is the caller's responsibility to add BOS to the start of the prompt if they want to roll back the whole prompt.
+    token_healing_info info = llama_token_healing_get_prefix(ctx_main, th_type, tokens, max_to_remove);
+
+    // If constrained decoding would give back the original prompt, there is no need to modify the prompt.
     const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                                th_type == llama_token_healing_type::DYNAMIC_MULTI;
-    const std::vector<llama_token> candidates = token_healing_find_prefix(ctx_main, prefix, is_multi_step);
-    LOG("token_healing: prefix = '%s' (%d tokens)\n", prefix.c_str(), removed);
-    if (removed == 1 && candidates.size() == 1) {
+    const std::vector<llama_token> candidates = token_healing_get_candidates(ctx_main, info.prefix, is_multi_step);
+    LOG("token_healing: prefix = '%s' (%d tokens)\n", info.prefix.c_str(), info.n_tokens_removed);
+    if (info.n_tokens_removed == 1 && candidates.size() == 1) {
         LOG("token_healing: nothing to heal\n");
         return "";
     }
     // Finalize outputs
     if (n_removed != nullptr) {
-        *n_removed = removed;
+        *n_removed = info.n_tokens_removed;
     }
-    tokens.resize(n_ctx - removed);
-    return prefix;
+    tokens.resize(tokens.size() - info.n_tokens_removed);
+    return info.prefix;
 }
 
 void llama_token_healing_set_prefix(llama_sampling_context * ctx_sampling, const std::string & prefix) {
@@ -506,7 +558,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (params.token_healing_enabled && !th_prefix.empty()) {
         const bool is_multi_step = th_type == llama_token_healing_type::ROLLBACK_MULTI ||
                                    th_type == llama_token_healing_type::DYNAMIC_MULTI;
-        std::vector<llama_token> th_candidates = token_healing_find_prefix(ctx_main, th_prefix, is_multi_step);
+        std::vector<llama_token> th_candidates = token_healing_get_candidates(ctx_main, th_prefix, is_multi_step);
 
         LOG("token_healing: prefix = '%s'\n", th_prefix.c_str());
         for (const llama_token token_id : th_candidates) {
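
For reference, a minimal sketch (not part of this diff) of how a caller might drive the refactored flow: roll back the tail of the prompt, then seed the sampler with the removed text so constrained decoding can regenerate it from the healed candidates. The names prompt_tokens and th_params below are hypothetical placeholders, not identifiers from this change.

    // Sketch only, assuming hypothetical variable names.
    int n_removed = 0;
    std::string healing_prefix = llama_token_healing_rollback(
            ctx_main, th_params.type, prompt_tokens, th_params.n_rollback, &n_removed);
    if (!healing_prefix.empty()) {
        // Constrain the first sampled tokens to continuations of the rolled-back prefix.
        llama_token_healing_set_prefix(ctx_sampling, healing_prefix);
    }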