
Commit 677bf2e

llama : optimize long word tokenization with WPM
ggml-ci
1 parent 2075a66 commit 677bf2e
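
The core of the change is the WPM (WordPiece) greedy longest-match loop: the vocab loader now records the longest token length in bytes (vocab.max_token_len), and the tokenizer starts its match search at i + max_token_len + 1 (clamped to the word length) instead of at the end of the word. A very long word then costs roughly n * max_token_len substring lookups in the worst case instead of about n^2 / 2; for a 10 000-byte word and a 30-byte longest token that is on the order of 3e5 lookups rather than 5e7. Below is a minimal standalone sketch of the bounded search, with illustrative names and simplified fallback handling, not the llama.cpp implementation:

    #include <algorithm>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Greedy longest-match over one word, capped by the longest token in the
    // vocabulary: candidates longer than max_token_len can never be in the map,
    // so the inner loop does at most max_token_len lookups per position.
    std::vector<int> wpm_match_sketch(const std::string & word,
                                      const std::unordered_map<std::string, int> & vocab,
                                      int max_token_len, int unk_id) {
        std::vector<int> out;
        const int n = (int) word.size();
        for (int i = 0; i < n; ) {
            bool match = false;
            // try the longest admissible candidate first
            for (int j = std::min(n, i + max_token_len); j > i; j--) {
                auto it = vocab.find(word.substr(i, j - i));
                if (it != vocab.end()) {
                    out.push_back(it->second);
                    i = j;
                    match = true;
                    break;
                }
            }
            if (!match) {
                out.push_back(unk_id); // simplified fallback; the real tokenizer handles unmatched words differently
                ++i;
            }
        }
        return out;
    }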

File tree

2 files changed: +13 −5 lines changed

  llama.cpp
  unicode.cpp

llama.cpp

Lines changed: 12 additions & 5 deletions
@@ -2293,6 +2293,8 @@ struct llama_vocab {
     enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 
+    int max_token_len = 0; // used for optimizing longest token search
+
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data> id_to_token;
 
@@ -4939,6 +4941,7 @@ static void llm_load_vocab(
         GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
 
         vocab.token_to_id[word] = i;
+        vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
 
         auto & token_data = vocab.id_to_token[i];
         token_data.text = std::move(word);
@@ -5249,6 +5252,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
 
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
@@ -13448,7 +13453,7 @@ struct llm_tokenizer_bpe {
 struct llm_tokenizer_wpm {
     llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
         const auto & token_map = vocab.token_to_id;
 
         // normalize and split by whitespace
@@ -13457,7 +13462,7 @@ struct llm_tokenizer_wpm {
         // bos token prepended already
 
         // find the longest tokens that form the words
-        for (const std::string &word : words) {
+        for (const std::string & word : words) {
             // skip empty words
             if (word.size() == 0) {
                 continue;
@@ -13474,7 +13479,7 @@ struct llm_tokenizer_wpm {
             for (int i = 0; i < n; ++i) {
                 // loop through possible match length
                 bool match = false;
-                for (int j = n; j > i; j--) {
+                for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
                     auto it = token_map.find(word1.substr(i, j - i));
                     if (it != token_map.end()) {
                         output.push_back(it->second);
@@ -13497,7 +13502,8 @@ struct llm_tokenizer_wpm {
         }
     }
 
-    std::vector<std::string> preprocess(const std::string & text) {
+    // TODO: reduce string copies by using cpts_offs array
+    std::vector<std::string> preprocess(const std::string & text) const {
         const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
         std::vector<std::string> words(1, "");
 
@@ -13792,14 +13798,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     output.push_back(vocab.special_cls_id);
                 }
 
+                llm_tokenizer_wpm tokenizer(vocab);
+
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        llm_tokenizer_wpm tokenizer(vocab);
                         tokenizer.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
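
The remaining llama.cpp edits support that change: tokenize() and preprocess() become const, which allows a single llm_tokenizer_wpm to be constructed once before the fragment loop rather than once per raw-text fragment. A tiny self-contained sketch of the pattern, with toy types standing in for the llama.cpp structs:

    #include <string>
    #include <vector>

    // stand-in for llm_tokenizer_wpm: because tokenize() is const, one instance
    // can be reused for every fragment, so construction is hoisted out of the loop
    struct toy_tokenizer {
        explicit toy_tokenizer(const std::vector<std::string> & vocab) : vocab(vocab) {}
        void tokenize(const std::string & text, std::vector<int> & out) const {
            out.push_back((int) text.size()); // placeholder for real tokenization work
        }
        const std::vector<std::string> & vocab;
    };

    int main() {
        const std::vector<std::string> vocab     = {"he", "##llo"};
        const std::vector<std::string> fragments = {"hello", "world"};
        std::vector<int> output;

        toy_tokenizer tokenizer(vocab);    // built once, outside the loop
        for (const auto & f : fragments) {
            tokenizer.tokenize(f, output); // const call, no per-fragment construction
        }
        return (int) output.size() == 2 ? 0 : 1;
    }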

unicode.cpp

Lines changed: 1 addition & 0 deletions
@@ -596,6 +596,7 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c
 
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector<uint32_t> result;
+    result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
         result.push_back(unicode_cpt_from_utf8(utf8, offset));
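
The unicode.cpp change is a capacity hint: a UTF-8 string never decodes to more codepoints than it has bytes, so result.reserve(utf8.size()) guarantees a single allocation instead of repeated regrowth while the decode loop appends. A simplified, self-contained sketch of the same pattern (the decoder below is deliberately minimal; unicode_cpt_from_utf8 does the real decoding in llama.cpp):

    #include <cstdint>
    #include <string>
    #include <vector>

    // reserve an upper bound (byte count >= codepoint count) before the append
    // loop, so push_back never reallocates
    std::vector<uint32_t> cpts_from_utf8_sketch(const std::string & utf8) {
        std::vector<uint32_t> result;
        result.reserve(utf8.size());
        size_t offset = 0;
        while (offset < utf8.size()) {
            const unsigned char c = (unsigned char) utf8[offset];
            if (c < 0x80) {
                result.push_back(c);      // ASCII byte is its own codepoint
                offset += 1;
            } else {
                result.push_back(0xFFFD); // simplification: U+FFFD for any multi-byte sequence
                offset += 1;
                while (offset < utf8.size() && ((unsigned char) utf8[offset] & 0xC0) == 0x80) {
                    offset += 1;          // skip continuation bytes
                }
            }
        }
        return result;
    }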

0 commit comments