
Commit 9854a9c

Author: jaime-m-p

Symmetric params for llama_tokenize() and llama_detokenize()

1 parent: 4a28063

File tree: 4 files changed, +29 −12 lines


common/common.cpp

Lines changed: 2 additions & 2 deletions
@@ -2924,10 +2924,10 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
 std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), special);
+    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
     if (n_chars < 0) {
         text.resize(-n_chars);
-        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), special);
+        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
         GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
     }
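The wrapper keeps llama.cpp's negative-return sizing convention, now passing false for the new remove_special argument so the old behavior is preserved: if the output buffer is too small, the call returns the required size negated, the string is regrown once, and the call is retried. A minimal standalone sketch of that convention, with a hypothetical detok callable standing in for the bound llama_detokenize() call:

#include <cstdint>
#include <functional>
#include <string>

// Grow-once retry as in the wrapper above; a negative return value
// encodes the required buffer size. `detok` is a hypothetical stand-in
// for a partially-applied llama_detokenize() call.
static std::string detokenize_to_string(const std::function<int32_t(char *, int32_t)> & detok) {
    std::string text(64, '\0');                          // initial size guess
    int32_t n_chars = detok(&text[0], (int32_t)text.size());
    if (n_chars < 0) {
        text.resize(-n_chars);                           // grow to the reported requirement
        n_chars = detok(&text[0], (int32_t)text.size()); // second call should now fit
    }
    text.resize(n_chars);                                // trim to the bytes actually written
    return text;
}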

llama.cpp

Lines changed: 17 additions & 3 deletions
@@ -18503,16 +18503,30 @@ int32_t llama_detokenize(
                          int32_t   n_tokens,
                             char * text,
                          int32_t   text_len_max,
-                            bool   special) {
+                            bool   remove_special,
+                            bool   unparse_special) {
     // remove the leading space of the first non-control token
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
-    bool remove_space = !special && model->vocab.tokenizer_add_space_prefix;
+    bool remove_space = !unparse_special && model->vocab.tokenizer_add_space_prefix;
     int32_t avail = text_len_max;
     int32_t total = 0;

+    if (remove_special && model->vocab.tokenizer_add_bos) {
+        if (n_tokens > 0 && tokens[0] == model->vocab.special_bos_id) {
+            n_tokens--;
+            tokens++;
+        }
+    }
+
+    if (remove_special && model->vocab.tokenizer_add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens-1] == model->vocab.special_eos_id) {
+            n_tokens--;
+        }
+    }
+
     for (int32_t i = 0; i < n_tokens; ++i) {
         GGML_ASSERT(avail >= 0);
-        int32_t n_chars = llama_token_to_piece(model, tokens[i], text, avail, remove_space, special);
+        int32_t n_chars = llama_token_to_piece(model, tokens[i], text, avail, remove_space, unparse_special);
         const llama_token_attr attr = llama_token_get_attr(model, tokens[i]);
         remove_space = remove_space && (attr & attr_special); // until non-control token
         if (n_chars < 0) {
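The new remove_special branches strip at most one leading BOS token (by advancing the tokens pointer) and one trailing EOS token (by shortening n_tokens), and only when the vocab is configured to add them at tokenization time. The same trimming can be sketched on a std::vector, with hypothetical bos_id/eos_id parameters standing in for the model->vocab fields:

#include <cstdint>
#include <vector>

// Standalone illustration of the remove_special trimming above; the ids and
// flags are hypothetical stand-ins for model->vocab.special_bos_id/eos_id
// and tokenizer_add_bos/eos.
static void strip_bos_eos(std::vector<int32_t> & tokens,
                          int32_t bos_id, bool add_bos,
                          int32_t eos_id, bool add_eos) {
    if (add_bos && !tokens.empty() && tokens.front() == bos_id) {
        tokens.erase(tokens.begin()); // mirrors tokens++ / n_tokens-- on the raw pointer
    }
    if (add_eos && !tokens.empty() && tokens.back() == eos_id) {
        tokens.pop_back();            // mirrors n_tokens--
    }
}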

llama.h

Lines changed: 6 additions & 3 deletions
@@ -874,6 +874,7 @@ extern "C" {
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
     /// @return Returns the number of tokens on success, no more than n_tokens_max
     /// @return Returns a negative number on failure - the number of tokens that would have been returned
+    /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
     /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
     ///                      as plaintext. Does not insert a leading space.
     LLAMA_API int32_t llama_tokenize(

@@ -898,18 +899,20 @@ extern "C" {
                          int32_t   lstrip,
                             bool   special);

-    /// @details Convert the provided tokens into text.
+    /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
     /// @return Returns a negative number on failure - the number of chars/bytes that would have been returned.
-    /// @param special If true, special tokens are rendered in the output.
+    /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
+    /// @param unparse_special If true, special tokens are rendered in the output.
     LLAMA_API int32_t llama_detokenize(
         const struct llama_model * model,
                const llama_token * tokens,
                          int32_t   n_tokens,
                             char * text,
                          int32_t   text_len_max,
-                            bool   special);
+                            bool   remove_special,
+                            bool   unparse_special);

     /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
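With both declarations in view, the symmetry named in the commit title is explicit: add_special/parse_special on the tokenize side mirror remove_special/unparse_special on the detokenize side. A round-trip sketch using the updated signatures, assuming model is an already-loaded llama_model pointer; the buffer sizes are guesses and a negative return from llama_tokenize() is not handled here:

#include <string>
#include <vector>
#include "llama.h"

// Round-trip sketch: tokenization adds BOS/EOS (if the model is configured
// to do so), and detokenization removes them again via remove_special=true.
std::string roundtrip(const llama_model * model, const std::string & prompt) {
    std::vector<llama_token> tokens(prompt.size() + 16); // rough upper bound
    int32_t n_tokens = llama_tokenize(model, prompt.c_str(), (int32_t)prompt.size(),
                                      tokens.data(), (int32_t)tokens.size(),
                                      /*add_special=*/true, /*parse_special=*/true);
    tokens.resize(n_tokens);

    std::string text(prompt.size() + 16, '\0');
    int32_t n_chars = llama_detokenize(model, tokens.data(), (int32_t)tokens.size(),
                                       &text[0], (int32_t)text.size(),
                                       /*remove_special=*/true, /*unparse_special=*/false);
    if (n_chars < 0) { // buffer too small: grow to the reported size and retry
        text.resize(-n_chars);
        n_chars = llama_detokenize(model, tokens.data(), (int32_t)tokens.size(),
                                   &text[0], (int32_t)text.size(),
                                   /*remove_special=*/true, /*unparse_special=*/false);
    }
    text.resize(n_chars);
    return text;
}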

tests/test-tokenizer-random.py

Lines changed: 4 additions & 4 deletions
@@ -98,15 +98,15 @@ def tokenize(self, text: str, add_special: bool = False, parse_special: bool = F
         num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
         return list(self.token_ids[0:num])

-    def detokenize(self, ids: list[int], special: bool = False) -> str:
+    def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str:
         if len(self.token_ids) < len(ids):
             self.token_ids = self.ffi.new("llama_token[]", 2 * len(ids))
         for i, id in enumerate(ids):
             self.token_ids[i] = id
-        num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), special)
+        num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
         while num < 0 and len(self.text_buff) < (16 << 20):
             self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
-            num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), special)
+            num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
         return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'


@@ -160,7 +160,7 @@ def encode(self, text: str) -> list[int]:
         return self.model.tokenize(text, add_special=True, parse_special=True)

     def decode(self, ids: list[int]) -> str:
-        return self.model.detokenize(ids, special=True)
+        return self.model.detokenize(ids, remove_special=False, unparse_special=True)


 def generator_custom_text() -> Iterator[str]:
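The test binding grows its buffer differently from the common/common.cpp wrapper: on failure it reallocates to twice the reported requirement and loops, capped at 16 MiB (16 << 20). The same policy expressed in C++ (a sketch; detok is again a hypothetical stand-in for the bound call):

#include <cstdint>
#include <functional>
#include <vector>

// Capped doubling retry mirroring the Python harness: on a negative return,
// regrow to twice the reported requirement, up to a 16 MiB ceiling.
static int32_t detokenize_with_retry(std::vector<char> & buf,
                                     const std::function<int32_t(char *, int32_t)> & detok) {
    int32_t num = detok(buf.data(), (int32_t)buf.size());
    while (num < 0 && buf.size() < (16u << 20)) {
        buf.resize(2 * (size_t)(-num)); // -num is the required size; allocate 2x
        num = detok(buf.data(), (int32_t)buf.size());
    }
    return num;
}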
