Commit d5d30b2

llama : pre-tokenize non-special user-defined tokens first

1 parent ac0f33c
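With this change, special-token partitioning runs on every tokenization call: user-defined added tokens are always matched first, and only control tokens remain gated behind the caller's parse_special flag. A minimal sketch of the resulting behavior at the public API (assuming the llama_tokenize C API of this period and a model loaded elsewhere; the token texts are illustrative):

    #include <cstring>
    #include <vector>
    #include "llama.h"

    // Sketch: how parse_special now interacts with the two kinds of added tokens.
    static std::vector<llama_token> tokenize(llama_model * model, const char * text, bool parse_special) {
        std::vector<llama_token> tokens(std::strlen(text) + 8);
        const int32_t n = llama_tokenize(model, text, (int32_t) std::strlen(text),
                                         tokens.data(), (int32_t) tokens.size(),
                                         /*add_special  =*/ false,
                                         /*parse_special=*/ parse_special);
        tokens.resize(n > 0 ? n : 0);
        return tokens;
    }

    // With parse_special == false, a control token such as "<|endoftext|>" is
    // tokenized as plain text, while user-defined added tokens (e.g. the
    // multi-space tokens of neox-style vocabularies) still match as single tokens.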

File tree

2 files changed: +21 -37 lines

src/llama.cpp

Lines changed: 19 additions & 35 deletions
@@ -5495,28 +5495,6 @@ static void llm_load_vocab(
             vocab.token_to_id[word] = i;
             vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());

-            // TODO: properly handle pre-normalized added_tokens and remove this
-            // handle space tokens with dual tokens,
-            // like the pre-normalized added_tokens
-            // of neox-style tokenizers (mpt, olmo, stablelm, etc)
-            if (word.find(' ') != std::string::npos) {
-                // same as in the internal `unicode_byte_encoding_process`
-                // TODO: extract and expose this in some unicode_* function
-                std::string text_utf;
-                auto utf_word = unicode_cpts_from_utf8(word);
-                for (size_t i = 0; i < utf_word.size(); ++i) {
-                    text_utf += unicode_cpt_to_utf8(utf_word[i]);
-                }
-
-                std::string encoded_token;
-                for (char & c : text_utf) {
-                    encoded_token += unicode_byte_to_utf8(c);
-                }
-
-                // override token id
-                vocab.token_to_id[encoded_token] = i;
-            }
-
             auto & token_data = vocab.id_to_token[i];
             token_data.text = std::move(word);
             token_data.score = scores ? scores[i] : 0.0f;
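The block removed above was a workaround for added tokens containing raw spaces: it registered a second copy of the token, re-encoded through the byte-level-BPE byte-to-unicode mapping. For background (this is not code from the commit), a sketch of the GPT-2-style table that unicode_byte_to_utf8 draws from; under it a raw space 0x20 lands at U+0120 ("Ġ"), which is why the dual registration was needed:

    #include <map>

    // Background sketch: the GPT-2 byte-to-codepoint table used by byte-level
    // BPE vocabularies. "Printable" bytes map to themselves; every other byte
    // is shifted into the U+0100 range, so ' ' (0x20) maps to U+0120 ("Ġ").
    static std::map<unsigned char, unsigned int> gpt2_byte_to_cpt() {
        std::map<unsigned char, unsigned int> m;
        unsigned int shifted = 0;
        for (int b = 0; b < 256; ++b) {
            const bool printable = (b >= '!'  && b <= '~') ||
                                   (b >= 0xA1 && b <= 0xAC) ||
                                   (b >= 0xAE && b <= 0xFF);
            m[(unsigned char) b] = printable ? (unsigned int) b : 256 + shifted++;
        }
        return m;
    }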
@@ -5534,6 +5512,13 @@ static void llm_load_vocab(
                     default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break;
                 }
             }
+
+            if ((token_data.attr & LLAMA_TOKEN_ATTR_USER_DEFINED) && token_data.text.find('<') && token_data.text.rfind('>')) {
+                // Some models mark some added tokens which ought to be control tokens as not special.
+                // (e.g. command-r, command-r-plus, deepseek-coder)
+                // TODO: should this be fixed in the convert script instead?
+                token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;
+            }
         }
         GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
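The added block promotes user-defined tokens that look like control markers to LLAMA_TOKEN_ATTR_CONTROL, because some conversions (command-r, command-r-plus, deepseek-coder) export such tokens as non-special. As a standalone illustration of the shape test being aimed at, written with explicit front/back checks instead of the find/rfind test in the hunk (the helper name is hypothetical):

    #include <string>

    // Hypothetical helper: does this user-defined token look like a control
    // marker, i.e. is it wrapped in angle brackets like "<|END_OF_TURN_TOKEN|>"?
    static bool looks_like_control_token(const std::string & text) {
        return text.size() >= 2 && text.front() == '<' && text.back() == '>';
    }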

@@ -15426,13 +15411,6 @@ struct llm_tokenizer_bpe {
                     "[0-9][0-9][0-9]",
                 };
                 break;
-            case LLAMA_VOCAB_PRE_TYPE_MPT:
-            case LLAMA_VOCAB_PRE_TYPE_OLMO:
-                regex_exprs = {
-                    "[ ]{2,24}", // the spaces from the added_tokens are split separately
-                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                };
-                break;
             case LLAMA_VOCAB_PRE_TYPE_STARCODER:
             case LLAMA_VOCAB_PRE_TYPE_REFACT:
             case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
@@ -15442,6 +15420,8 @@ struct llm_tokenizer_bpe {
                 };
                 break;
             case LLAMA_VOCAB_PRE_TYPE_GPT2:
+            case LLAMA_VOCAB_PRE_TYPE_MPT:
+            case LLAMA_VOCAB_PRE_TYPE_OLMO:
             case LLAMA_VOCAB_PRE_TYPE_JAIS:
                 regex_exprs = {
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
@@ -15523,10 +15503,6 @@ struct llm_tokenizer_bpe {
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         int final_prev_index = -1;

-        // FIXME: pre-tokenize added_tokens (user-defined tokens) before other pre-tokenization
-        // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
-        // (useful for neox-style tokenizers)
-
         const auto word_collection = unicode_regex_split(text, regex_exprs);

         symbols_final.clear();
@@ -16192,12 +16168,20 @@ struct fragment_buffer_variant {

 // #define PRETOKENIZERDEBUG

-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
+static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
     // for each special token
     for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
         const auto & data = vocab.id_to_token[special_id];
         const auto & special_token = data.text;

+        if (!parse_special && (data.attr & LLAMA_TOKEN_ATTR_CONTROL)) {
+            // Only ignore control tokens when parse_special == false
+            continue;
+            // User-defined tokens are still pre-tokenized before everything else
+            // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
+            // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
+        }
+
         // for each text fragment
         std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
         while (it != buffer.end()) {
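tokenizer_st_partition walks the fragment buffer and carves each special token's text out of the raw-text fragments. A simplified, self-contained sketch of that partitioning for a single token (the real code mutates a std::forward_list of fragment_buffer_variant in place, using offsets instead of copies):

    #include <string>
    #include <vector>

    // Simplified sketch (not the llama.cpp implementation): split raw text into
    // alternating raw fragments and occurrences of one special token's text.
    struct fragment {
        bool        is_token;
        std::string text;
    };

    static std::vector<fragment> partition_on(const std::string & text, const std::string & special) {
        std::vector<fragment> out;
        size_t pos = 0;
        while (pos < text.size()) {
            const size_t hit = text.find(special, pos);
            if (hit == std::string::npos) {
                out.push_back({false, text.substr(pos)});
                break;
            }
            if (hit > pos) {
                out.push_back({false, text.substr(pos, hit - pos)});
            }
            out.push_back({true, special});
            pos = hit + special.size();
        }
        return out;
    }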
@@ -16310,7 +16294,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &

     if (!raw_text.empty()) {
         fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
+        tokenizer_st_partition(vocab, fragment_buffer, parse_special);
     }

     switch (vocab.type) {

tests/test-tokenizer-0.cpp

Lines changed: 2 additions & 2 deletions
@@ -195,7 +195,7 @@ int main(int argc, char **argv) {
     const bool add_special = false;

     for (const auto & test_kv : k_tests) {
-        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, true);
+        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special);

         printf("\n");
         printf("src: '%s'\n", test_kv.first.c_str());
@@ -253,7 +253,7 @@ int main(int argc, char **argv) {
     {
         const auto t_start = ggml_time_us();

-        res = llama_tokenize(ctx, text, add_special, true);
+        res = llama_tokenize(ctx, text, add_special);

         const auto t_end = ggml_time_us();
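Dropping the trailing true works because the common-library wrapper's parse_special parameter defaults to false (assumed declaration below, from common.h of this period); with this commit, user-defined tokens in the test inputs are matched even under that default, which is exactly the behavior change under test:

    // Assumed common.h declaration at the time of this commit; the tests now
    // use the default, exercising parse_special == false.
    std::vector<llama_token> llama_tokenize(
        const struct llama_context * ctx,
                 const std::string & text,
                              bool   add_special,
                              bool   parse_special = false);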
