llama : optimize long word tokenization with WPM #8034

Merged · 1 commit · Jun 21, 2024

17 changes: 12 additions & 5 deletions llama.cpp

@@ -2293,6 +2293,8 @@ struct llama_vocab {
     enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 
+    int max_token_len = 0; // used for optimizing longest token search
+
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data> id_to_token;
 
@@ -4939,6 +4941,7 @@ static void llm_load_vocab(
         GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
 
         vocab.token_to_id[word] = i;
+        vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
 
         auto & token_data = vocab.id_to_token[i];
         token_data.text = std::move(word);
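
The two hunks above are all the bookkeeping the optimization needs: while llm_load_vocab fills token_to_id, it also records the byte length of the longest token (word.size() counts bytes, which matches the byte-indexed search loop in the tokenizer below). A minimal standalone sketch of the same pattern, with illustrative names rather than the llama.cpp API:

```cpp
#include <algorithm>
#include <string>
#include <unordered_map>

// Illustrative only: track the longest token while building the vocab map.
struct toy_vocab {
    std::unordered_map<std::string, int> token_to_id;
    int max_token_len = 0;

    void add_token(const std::string & word, int id) {
        max_token_len = std::max(max_token_len, (int) word.size());
        token_to_id[word] = id;
    }
};
```
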
@@ -5249,6 +5252,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
 
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
@@ -13448,7 +13453,7 @@ struct llm_tokenizer_bpe {
 struct llm_tokenizer_wpm {
     llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
         const auto & token_map = vocab.token_to_id;
 
         // normalize and split by whitespace
@@ -13457,7 +13462,7 @@ struct llm_tokenizer_wpm {
         // bos token prepended already
 
         // find the longest tokens that form the words
-        for (const std::string &word : words) {
+        for (const std::string & word : words) {
             // skip empty words
             if (word.size() == 0) {
                 continue;
@@ -13474,7 +13479,7 @@
         for (int i = 0; i < n; ++i) {
             // loop through possible match length
             bool match = false;
-            for (int j = n; j > i; j--) {
+            for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
                 auto it = token_map.find(word1.substr(i, j - i));
                 if (it != token_map.end()) {
                     output.push_back(it->second);
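
This is the core of the change. Previously the inner loop tried every candidate length from the end of the word down (for (int j = n; j > i; j--)), so an unbroken word of n bytes with few matches cost on the order of n^2 substring lookups, each of which also copies up to n bytes for the hash probe. Since no match can be longer than the longest token in the vocabulary, capping the scan start at i + vocab.max_token_len + 1 bounds the work to roughly n * max_token_len lookups; the + 1 in the diff leaves a one-byte margin, which is harmless because over-long candidates simply miss in the map. For a rough sense of scale, a 10 kB unbroken word drops from around 50 million lookups to a few hundred thousand when the longest token is a few dozen bytes. A self-contained sketch of the bounded greedy longest-match, using a plain std::unordered_map and the tight bound (names are illustrative, not the llama.cpp API):

```cpp
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

// Greedy longest-match against a vocabulary, bounded by the longest token.
std::vector<int> greedy_match(const std::string & word,
                              const std::unordered_map<std::string, int> & vocab,
                              int max_token_len) {
    std::vector<int> out;
    const int n = (int) word.size();
    for (int i = 0; i < n; ++i) {
        bool match = false;
        // try the longest candidate first, but never one longer than the
        // longest token present in the vocabulary
        for (int j = std::min(n, i + max_token_len); j > i; j--) {
            auto it = vocab.find(word.substr(i, j - i));
            if (it != vocab.end()) {
                out.push_back(it->second);
                i = j - 1; // resume right after the matched token
                match = true;
                break;
            }
        }
        if (!match) {
            return {}; // the real tokenizer emits an unknown-token id here
        }
    }
    return out;
}
```
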
@@ -13497,7 +13502,8 @@
         }
     }
 
-    std::vector<std::string> preprocess(const std::string & text) {
+    // TODO: reduce string copies by using cpts_offs array
+    std::vector<std::string> preprocess(const std::string & text) const {
         const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
         std::vector<std::string> words(1, "");
 
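
The new TODO notes that preprocess still materializes every word as its own std::string after NFD normalization; tracking offset/length spans into the code-point array would avoid those copies. A hypothetical sketch of that spans approach (word_span and split_spans do not exist in llama.cpp, and only ASCII space is treated as a separator here for brevity):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical: represent words as spans into the code-point array instead
// of copying each word into a separate string.
struct word_span { size_t off; size_t len; };

std::vector<word_span> split_spans(const std::vector<uint32_t> & cpts) {
    std::vector<word_span> spans;
    size_t start = 0;
    for (size_t i = 0; i <= cpts.size(); ++i) {
        if (i == cpts.size() || cpts[i] == ' ') {
            if (i > start) {
                spans.push_back({start, i - start});
            }
            start = i + 1;
        }
    }
    return spans;
}
```
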
@@ -13792,14 +13798,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                 output.push_back(vocab.special_cls_id);
             }
 
+            llm_tokenizer_wpm tokenizer(vocab);
+
             for (const auto & fragment : fragment_buffer) {
                 if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                     auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                    llm_tokenizer_wpm tokenizer(vocab);
                     tokenizer.tokenize(raw_text, output);
                 } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                     output.push_back(fragment.token);
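
This last llama.cpp hunk pairs with the const-qualification of tokenize and preprocess above: since the tokenizer no longer mutates any state, a single instance can be built once before the fragment loop instead of once per raw-text fragment. The same construct-once pattern in miniature, with illustrative names:

```cpp
#include <string>
#include <vector>

// Illustrative stand-in for llm_tokenizer_wpm: state set once, const methods.
struct toy_tokenizer {
    explicit toy_tokenizer(int vocab_size) : vocab_size(vocab_size) {}
    void tokenize(const std::string & text, std::vector<int> & out) const {
        for (char c : text) {
            out.push_back((unsigned char) c % vocab_size);
        }
    }
    int vocab_size;
};

void tokenize_fragments(const std::vector<std::string> & fragments,
                        std::vector<int> & output) {
    const toy_tokenizer tokenizer(1000); // constructed once, outside the loop
    for (const auto & fragment : fragments) {
        tokenizer.tokenize(fragment, output);
    }
}
```
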
1 change: 1 addition & 0 deletions unicode.cpp

@@ -596,6 +596,7 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c

 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector<uint32_t> result;
+    result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
         result.push_back(unicode_cpt_from_utf8(utf8, offset));
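
The single unicode.cpp change is a classic reserve: every UTF-8 code point occupies at least one byte, so a string of N bytes decodes to at most N code points, and utf8.size() is a safe capacity upper bound that replaces the vector's repeated geometric reallocations with one allocation up front. A minimal sketch of the same idea, where decode_one is a hypothetical stand-in for unicode_cpt_from_utf8 (decoding ASCII only, for brevity):

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in: decodes one code point at offset and advances it.
static uint32_t decode_one(const std::string & utf8, size_t & offset) {
    return (unsigned char) utf8[offset++];
}

std::vector<uint32_t> cpts_from_utf8(const std::string & utf8) {
    std::vector<uint32_t> result;
    result.reserve(utf8.size()); // upper bound: >= 1 byte per code point
    size_t offset = 0;
    while (offset < utf8.size()) {
        result.push_back(decode_one(utf8, offset));
    }
    return result;
}
```
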