
Commit 02c1eca

Tokenizer WPM fixes (#7500)

* Update random test: add_bos_token.
* Update random test: add WPM models for testing.
* Build vocab.special_tokens_cache using vocab token types.
* Fix and improve WPM preprocessing:
  - Fix unicode edge case combinations.
  - Split by whitespace in the same pass.
* Discard all tokens when no match is found.
1 parent 6bd12ce commit 02c1eca

2 files changed (+75, -167 lines)


llama.cpp

Lines changed: 62 additions & 160 deletions
@@ -2162,7 +2162,7 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data> id_to_token;
 
-    std::unordered_map<token, id> special_tokens_cache;
+    std::vector<id> special_tokens_cache;
 
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
@@ -4831,97 +4831,19 @@ static void llm_load_vocab(
 
     // build special tokens cache
     {
-        // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type,
-        // and will always be correctly labeled in 'added_tokens.json' etc.
-        // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed
-        // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer
-        // are special tokens.
-        // From testing, this appears to correlate 1:1 with special tokens.
-        //
-
-        // Counting special tokens and verifying in only one direction
-        // is sufficient to detect difference in those two sets.
-        //
-        uint32_t special_tokens_count_by_type = 0;
-        uint32_t special_tokens_count_from_verification = 0;
-
-        bool special_tokens_definition_mismatch = false;
-
-        for (const auto & t : vocab.token_to_id) {
-            const auto & token = t.first;
-            const auto & id = t.second;
-
-            // Count all non-normal tokens in the vocab while iterating
+        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
             if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
-                special_tokens_count_by_type++;
+                vocab.special_tokens_cache.push_back(id);
             }
+        }
 
-            // Skip single character tokens
-            if (token.length() > 1) {
-                bool is_tokenizable = false;
-
-                // Split token string representation in two, in all possible ways
-                // and check if both halves can be matched to a valid token
-                for (unsigned i = 1; i < token.length();) {
-                    const auto left = token.substr(0, i);
-                    const auto right = token.substr(i);
-
-                    // check if we didnt partition in the middle of a utf sequence
-                    auto utf = utf8_len(left.at(left.length() - 1));
-
-                    if (utf == 1) {
-                        if (vocab.token_to_id.find(left) != vocab.token_to_id.end() &&
-                            vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
-                            is_tokenizable = true;
-                            break;
-                        }
-                        i++;
-                    } else {
-                        // skip over the rest of multibyte utf sequence
-                        i += utf - 1;
-                    }
-                }
-
-                if (!is_tokenizable) {
-                    // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
-                    // it's faster to re-filter them here, since there are way less candidates now
-
-                    // Calculate a total "utf" length of a token string representation
-                    size_t utf8_str_len = 0;
-                    for (unsigned i = 0; i < token.length();) {
-                        utf8_str_len++;
-                        i += utf8_len(token.at(i));
-                    }
-
-                    // And skip the ones which are one character
-                    if (utf8_str_len > 1) {
-                        // At this point what we have left are special tokens only
-                        vocab.special_tokens_cache[token] = id;
-
-                        // Count manually found special tokens
-                        special_tokens_count_from_verification++;
-
-                        // If this manually found special token is not marked as such, flag a mismatch
-                        if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
-                            special_tokens_definition_mismatch = true;
-                        }
-                    }
-                }
+        std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
+            [&] (const llama_vocab::id a, const llama_vocab::id b) {
+                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
             }
-        }
+        );
 
-        if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
-            LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
-                __func__,
-                special_tokens_count_from_verification, vocab.id_to_token.size(),
-                special_tokens_count_by_type, vocab.id_to_token.size()
-            );
-        } else {
-            LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n",
-                __func__,
-                special_tokens_count_from_verification, vocab.id_to_token.size()
-            );
-        }
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
     }
 }
 
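The old heuristic tried to detect special tokens by checking whether a token's text could be re-assembled from smaller vocab entries; the new code simply trusts the token type stored in the vocab and sorts the collected ids so the longest special-token texts come first. A self-contained sketch of that idea, using simplified stand-in types (SimpleVocab, TokenData, TokenType and build_special_tokens_cache are illustrative names, not llama.cpp's API):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    // Illustrative stand-ins for the vocab structures, not llama.cpp types.
    enum class TokenType { NORMAL, UNKNOWN, CONTROL, USER_DEFINED };

    struct TokenData {
        std::string text;
        TokenType   type;
    };

    struct SimpleVocab {
        std::vector<TokenData> id_to_token;          // index == token id
        std::vector<int32_t>   special_tokens_cache; // non-NORMAL ids, longest text first
    };

    // Collect every non-NORMAL token id, then sort by descending text length so
    // that longer special tokens are tried before tokens contained within them.
    static void build_special_tokens_cache(SimpleVocab & vocab) {
        vocab.special_tokens_cache.clear();
        for (int32_t id = 0; id < (int32_t) vocab.id_to_token.size(); ++id) {
            if (vocab.id_to_token[id].type != TokenType::NORMAL) {
                vocab.special_tokens_cache.push_back(id);
            }
        }
        std::sort(vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
            [&](int32_t a, int32_t b) {
                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
            });
    }

    int main() {
        SimpleVocab vocab;
        vocab.id_to_token = {
            {"hello",         TokenType::NORMAL},
            {"<|eot|>",       TokenType::CONTROL},
            {"<|endoftext|>", TokenType::CONTROL},
            {"[UNK]",         TokenType::UNKNOWN},
        };
        build_special_tokens_cache(vocab);
        for (int32_t id : vocab.special_tokens_cache) {
            std::printf("%d: %s\n", id, vocab.id_to_token[id].text.c_str());
        }
        return 0; // prints ids 2, 1, 3: longest special token first
    }

The longest-first ordering matters for the partitioning pass further down: a special token whose text is a substring of another special token must not be matched before the longer one.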
@@ -13146,7 +13068,7 @@ struct llm_tokenizer_wpm {
     llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-        auto * token_map = &vocab.token_to_id;
+        const auto & token_map = vocab.token_to_id;
 
         // normalize and split by whitespace
         std::vector<std::string> words = preprocess(text);
@@ -13161,108 +13083,89 @@ struct llm_tokenizer_wpm {
             }
 
             // prepend phantom space
-            std::string word1 = "\xe2\x96\x81" + word;
-            int n = word1.size();
+            const std::string word1 = "\xe2\x96\x81" + word;
+            const int n = word1.size();
 
-            // we're at the start of a new word
-            int i = 0;
-            bool match_any = false;
+            const size_t current_tokens = output.size();
 
+            // we're at the start of a new word
             // move through character position in word
-            while (i < n) {
+            for (int i = 0; i < n; ++i) {
                 // loop through possible match length
                 bool match = false;
                 for (int j = n; j > i; j--) {
-                    auto it = token_map->find(word1.substr(i, j - i));
-                    if (it != token_map->end()) {
+                    auto it = token_map.find(word1.substr(i, j - i));
+                    if (it != token_map.end()) {
                         output.push_back(it->second);
                         match = true;
-                        match_any = true;
-                        i = j;
+                        i = j - 1;
                         break;
                     }
                 }
 
-                // must be an unknown character
-                if (!match) {
-                    i++;
+                if (!match) { // discard all
+                    output.resize(current_tokens);
+                    break;    // and discard next tokens
                 }
             }
 
             // we didn't find any matches for this word
-            if (!match_any) {
+            if (current_tokens == output.size()) {
                 output.push_back(vocab.special_unk_id);
             }
         }
     }
 
     std::vector<std::string> preprocess(const std::string & text) {
-        std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
-
-        // strip accents, strip control, uniformize whitespace,
-        // to lowercase, pad chinese characters, pad punctuation
-        std::string new_str = "";
-        for (uint32_t code : cpts_nfd) {
-            const codepoint_flags flags = unicode_cpt_flags(code);
-            if (flags.is_accent_mark || flags.is_control) {
+        const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
+        std::vector<std::string> words(1, "");
+
+        for (const char32_t cpt : cpts_nfd) {
+            const auto flags = unicode_cpt_flags(cpt);
+
+            if (flags.is_whitespace) {
+                if (words.back().size()) { // finish previous word if any
+                    words.emplace_back();
+                }
                 continue;
             }
-            code = unicode_tolower(code);
-            if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
-                code = ' ';
-            }
-            std::string s = unicode_cpt_to_utf8(code);
-            if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
-                new_str += " ";
-                new_str += s;
-                new_str += " ";
-            } else {
-                new_str += s;
+
+            assert (!flags.is_separator);
+            if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
+                continue;
             }
-        }
 
-        // split by whitespace
-        uint64_t l = 0;
-        uint64_t r = 0;
-        std::vector<std::string> words;
-        while (r < new_str.size()) {
-            // if is whitespace
-            if (isspace(new_str[r], std::locale::classic())) {
-                if (r > l) words.push_back(new_str.substr(l, (r - l)));
-                l = r + 1;
-                r = l;
+            const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
+            if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
+                if (words.back().size()) { // finish previous word if any
+                    words.emplace_back();
+                }
+                words.back() = s;     // single char word
+                words.emplace_back(); // start a new word
             } else {
-                r += 1;
+                words.back() += s; // append char to word
            }
         }
-        if (r > l) {
-            words.push_back(new_str.substr(l, (r - l)));
-        }
-        return words;
-    }
 
-    bool is_ascii_punct(uint32_t code) {
-        if (code > 0xFF) {
-            return false;
+        if (!words.back().size()) {
+            words.pop_back();
         }
-        auto c = char(static_cast<unsigned char>(code));
-        return ispunct(c, std::locale::classic());
+
+        return words;
     }
 
-    bool is_chinese_char(uint32_t cpt) {
-        if ((cpt >= 0x4E00 && cpt <= 0x9FFF) ||
-            (cpt >= 0x3400 && cpt <= 0x4DBF) ||
+    static bool is_chinese_char(uint32_t cpt) {
+        return
+            (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
+            (cpt >= 0x03400 && cpt <= 0x04DBF) ||
             (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
             (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
             (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
             (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
-            (cpt >= 0xF900 && cpt <= 0xFAFF) ||
-            (cpt >= 0x2F800 && cpt <= 0x2FA1F) ||
-            (cpt >= 0x3000 && cpt <= 0x303F) ||
-            (cpt >= 0xFF00 && cpt <= 0xFFEF)) {
-            return true; // NOLINT
-        }
-        return false;
+            (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
+            (cpt >= 0x2F800 && cpt <= 0x2FA1F);
+            //(cpt >= 0x3000 && cpt <= 0x303F) ||
+            //(cpt >= 0xFF00 && cpt <= 0xFFEF);
    }
 
     const llama_vocab & vocab;
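The behavioral change in tokenize() is the failure path: previously an unmatched character just advanced the position by one, which could leave partial pieces of a word in the output; now every piece emitted for that word is discarded and the whole word falls back to a single UNK token. A simplified sketch of this greedy longest-match-with-discard loop, assuming a plain string-to-id map and ignoring the phantom-space prefix and unicode normalization (wpm_word_to_ids is a hypothetical helper, not llama.cpp's API):

    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Hypothetical helper mirroring the per-word loop above; not llama.cpp's API.
    static std::vector<int> wpm_word_to_ids(const std::string & word,
                                            const std::unordered_map<std::string, int> & token_to_id,
                                            int unk_id) {
        std::vector<int> out;
        const int n = (int) word.size();
        for (int i = 0; i < n; ++i) {
            bool match = false;
            // try the longest remaining piece first
            for (int j = n; j > i; j--) {
                auto it = token_to_id.find(word.substr(i, j - i));
                if (it != token_to_id.end()) {
                    out.push_back(it->second);
                    match = true;
                    i = j - 1; // resume right after the matched piece
                    break;
                }
            }
            if (!match) {  // any unmatched position discards the whole word
                out.clear();
                break;
            }
        }
        if (out.empty()) { // the word is then represented by a single UNK
            out.push_back(unk_id);
        }
        return out;
    }

    int main() {
        const std::unordered_map<std::string, int> vocab = {
            {"play", 1}, {"ing", 2}, {"er", 3},
        };
        for (int id : wpm_word_to_ids("playing", vocab, 0)) {
            std::printf("%d ", id); // prints: 1 2
        }
        std::printf("\n");
        for (int id : wpm_word_to_ids("playx", vocab, 0)) {
            std::printf("%d ", id); // "x" never matches, so the word collapses to UNK: prints 0
        }
        std::printf("\n");
        return 0;
    }

The rewritten preprocess() above also folds the old two-pass "normalize, then split on whitespace" logic into a single pass: whitespace finishes the current word, punctuation, ASCII symbols, and CJK characters become one-character words, and everything else is appended to the word being built.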
@@ -13306,9 +13209,8 @@ struct fragment_buffer_variant {
 
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
-    for (const auto & st: vocab.special_tokens_cache) {
-        const auto & special_token = st.first;
-        const auto & special_id = st.second;
+    for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
+        const auto & special_token = vocab.id_to_token[special_id].text;
 
         // for each text fragment
         std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
@@ -13317,7 +13219,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
 
             // if a fragment is text ( not yet processed )
             if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                auto * raw_text = &(fragment.raw_text);
+                auto & raw_text = fragment.raw_text;
 
                 auto raw_text_base_offset = fragment.offset;
                 auto raw_text_base_length = fragment.length;
@@ -13327,7 +13229,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                 // find the first occurrence of a given special token in this fragment
                 // passing offset argument only limit the "search area" but match coordinates
                 // are still relative to the source full raw_text
-                auto match = raw_text->find(special_token, raw_text_base_offset);
+                auto match = raw_text.find(special_token, raw_text_base_offset);
 
                 // no occurrences found, stop processing this fragment for a given special token
                 if (match == std::string::npos) break;
@@ -13346,7 +13248,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     // left
                     const int64_t left_reminder_offset = raw_text_base_offset + 0;
                     const int64_t left_reminder_length = match - raw_text_base_offset;
-                    buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);
+                    buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
 
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
@@ -13362,7 +13264,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                 if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
                     const int64_t right_reminder_offset = match + special_token.length();
                     const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
-                    buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);
+                    buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
 
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
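The remaining llama.cpp hunks adapt tokenizer_st_partition() to the new cache layout: it now iterates token ids and reads each special token's text from id_to_token, and the fragment's raw text is accessed through a reference instead of a pointer. Conceptually, the pass splits every raw-text fragment around occurrences of each special token, longest token first. A self-contained sketch of that partitioning idea (Fragment and split_on_specials are illustrative, not llama.cpp's types):

    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    // Illustrative types, not llama.cpp's fragment_buffer_variant machinery.
    struct Fragment {
        bool        is_special; // true: matched special token, false: raw text
        std::string text;
    };

    // Split raw text around every occurrence of each special token, trying the
    // longest special tokens first so that a shorter token which is a substring
    // of a longer one cannot break the longer one apart.
    static std::vector<Fragment> split_on_specials(const std::string & text,
                                                   const std::vector<std::string> & specials_longest_first) {
        std::vector<Fragment> frags = {{false, text}};
        for (const std::string & st : specials_longest_first) {
            std::vector<Fragment> next;
            for (const Fragment & f : frags) {
                if (f.is_special) {
                    next.push_back(f);
                    continue;
                }
                size_t pos = 0;
                while (true) {
                    const size_t match = f.text.find(st, pos);
                    if (match == std::string::npos) {
                        break;
                    }
                    if (match > pos) {
                        next.push_back({false, f.text.substr(pos, match - pos)});
                    }
                    next.push_back({true, st});
                    pos = match + st.size();
                }
                if (pos < f.text.size()) {
                    next.push_back({false, f.text.substr(pos)});
                }
            }
            frags = std::move(next);
        }
        return frags;
    }

    int main() {
        const std::vector<Fragment> frags =
            split_on_specials("<s>hello<sep>world</s>", {"<sep>", "</s>", "<s>"});
        for (const Fragment & f : frags) {
            std::printf("%s '%s'\n", f.is_special ? "special" : "text   ", f.text.c_str());
        }
        return 0;
    }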

tests/test-tokenizer-random.py

Lines changed: 13 additions & 7 deletions
@@ -167,8 +167,10 @@ def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
     for m in range(iterations):
         rand.seed(m)
         words = rand.choices(special_tokens, k=500)
-        if tokenizer.add_bos_token: # skip spam warning of double BOS
-            while words and words[0] == tokenizer.bos_token:
+        if words[0] == tokenizer.bos_token: # skip spam warning of double BOS
+            while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS
+                words.pop(0)
+            if tokenizer.add_bos_token: # drop all starting BOS
                 words.pop(0)
         yield "".join(words)
 
@@ -293,15 +295,17 @@ def main(argv: list[str] = None):
     model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
 
-    tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", True)
-    tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", False)
-
     def func_tokenize1(text: str):
         return model.tokenize(text, add_special=True, parse_special=True)
 
     def func_tokenize2(text: str):
         return tokenizer.encode(text, add_special_tokens=True)
 
+    ids = func_tokenize2("a")
+    assert 1 <= len(ids) <= 3
+    add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
+    tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
+
     vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
@@ -324,8 +328,10 @@ def func_tokenize2(text: str):
     # import os
     # tokenizers = os.listdir(path_tokenizers)
     tokenizers = [
-        "llama-spm",   # SPM
-        "phi-3",       # SPM
+        # "llama-spm", # SPM
+        # "phi-3",     # SPM
+        "jina-v2-en",  # WPM
+        "bert-bge",    # WPM
     ]
 
     for tokenizer in tokenizers:
