Commit 31ac583

llama : keep track of all EOG tokens in the vocab (#9609)
ggml-ci
1 parent cea1486 commit 31ac583

File tree

3 files changed: +61 -18 lines changed

src/llama-vocab.cpp
src/llama-vocab.h
src/llama.cpp

src/llama-vocab.cpp

Lines changed: 1 addition & 5 deletions
```diff
@@ -1570,11 +1570,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
 }
 
 bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-            token == llama_token_eos_impl(vocab) ||
-            token == llama_token_eot_impl(vocab) ||
-            token == llama_token_eom_impl(vocab)
-    );
+    return token != -1 && vocab.special_eog_ids.count(token) > 0;
 }
 
 bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
```
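To make the behavioural difference concrete, here is a small self-contained sketch (the struct, helper names, and token IDs below are invented for illustration and are not the real llama.cpp types): the old implementation recognised only the single EOS/EOT/EOM IDs, while the new one recognises anything registered in the special_eog_ids set.

```cpp
// Illustrative only: mirrors the shape of the change, not the actual llama.cpp types.
#include <cstdio>
#include <set>

using id = int; // stand-in for llama_token

struct mock_vocab {
    id special_eos_id = 2;
    id special_eot_id = 106;
    id special_eom_id = -1;        // not present in this mock model
    std::set<id> special_eog_ids;  // what the commit adds
};

// old behavior: only the EOS/EOT/EOM IDs are considered end-of-generation
static bool is_eog_old(const mock_vocab & v, id token) {
    return token != -1 && (token == v.special_eos_id ||
                           token == v.special_eot_id ||
                           token == v.special_eom_id);
}

// new behavior: any token registered in special_eog_ids stops generation
static bool is_eog_new(const mock_vocab & v, id token) {
    return token != -1 && v.special_eog_ids.count(token) > 0;
}

int main() {
    mock_vocab v;
    // e.g. an "<|endoftext|>"-style token that is neither EOS nor EOT
    const id extra_eog = 32007;
    v.special_eog_ids = { v.special_eos_id, v.special_eot_id, extra_eog };

    printf("old: %d, new: %d\n", is_eog_old(v, extra_eog), is_eog_new(v, extra_eog)); // old: 0, new: 1
    return 0;
}
```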

src/llama-vocab.h

Lines changed: 9 additions & 5 deletions
```diff
@@ -6,6 +6,7 @@
 #include <vector>
 #include <unordered_map>
 #include <map>
+#include <set>
 
 struct llama_vocab {
     using id = llama_token;
@@ -49,12 +50,15 @@ struct llama_vocab {
     id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
     id special_eom_id = -1;
 
+    // set of all tokens that cause "end of generation"
+    std::set<id> special_eog_ids;
+
     // tokenizer flags
-    bool tokenizer_add_space_prefix = false;
-    bool tokenizer_add_bos = false;
-    bool tokenizer_add_eos = false;
-    bool tokenizer_ignore_merges = false;
-    bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
+    bool tokenizer_add_space_prefix           = false;
+    bool tokenizer_add_bos                    = false;
+    bool tokenizer_add_eos                    = false;
+    bool tokenizer_ignore_merges              = false;
+    bool tokenizer_clean_spaces               = false; // clean_up_tokenization_spaces
     bool tokenizer_remove_extra_whitespaces   = false;
     bool tokenizer_escape_whitespaces         = true;
     bool tokenizer_treat_whitespace_as_suffix = false;
```
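One detail worth noting in this header change: the new member is a std::set rather than a std::unordered_set, so the EOG token IDs are kept sorted and iterating over them is deterministic (the logging loop added in src/llama.cpp below walks the set in ascending ID order). Whether that was the deciding factor is not stated in the commit; the snippet below, with invented IDs, only illustrates the ordering behaviour.

```cpp
#include <cstdio>
#include <set>

int main() {
    // IDs inserted in arbitrary discovery order, as they would be while
    // scanning token_to_id; std::set stores them sorted by ID.
    std::set<int> special_eog_ids;
    for (int id : {128009, 2, 107, 128001}) {
        special_eog_ids.insert(id);
    }

    for (int id : special_eog_ids) {
        printf("EOG token = %d\n", id); // prints 2, 107, 128001, 128009 in that order
    }
    return 0;
}
```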

src/llama.cpp

Lines changed: 51 additions & 8 deletions
```diff
@@ -6509,21 +6509,21 @@ static void llm_load_vocab(
         // for now, we apply this workaround to find the EOT token based on its text
         if (vocab.special_eot_id == -1) {
             for (const auto & t : vocab.token_to_id) {
-                if (
+                if (false
                         // TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
                         //       need to fix convert script
                         //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
-                        (t.first == "<|eot_id|>" ||
-                         t.first == "<|im_end|>" ||
-                         t.first == "<|end|>" ||
-                         t.first == "<end_of_turn>" ||
-                         t.first == "<|endoftext|>"
-                        )
+                        || t.first == "<|eot_id|>"
+                        || t.first == "<|im_end|>"
+                        || t.first == "<|end|>"
+                        || t.first == "<end_of_turn>"
+                        || t.first == "<|endoftext|>"
+                        || t.first == "<EOT>"
                    ) {
                     vocab.special_eot_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
                         LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                                __func__, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                     break;
```
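The rewrite of the condition from `if (cond1 || cond2 || ...)` to `if (false || cond1 || cond2 || ...)` does not change behaviour: false is the identity for `||`, and leading with it lets every real alternative start its own line with `||`, so adding an entry (here the new "<EOT>") touches exactly one line of the diff. A trivial standalone illustration of the idiom:

```cpp
#include <cstdio>
#include <string>

// Returns true when `text` matches one of the known end-of-turn spellings.
// The leading `false` contributes nothing to the result; it only lets every
// alternative start with `||`, keeping future additions to one line each.
static bool is_known_eot_text(const std::string & text) {
    return false
        || text == "<|eot_id|>"
        || text == "<|im_end|>"
        || text == "<EOT>";
}

int main() {
    printf("%d %d\n", is_known_eot_text("<|im_end|>"), is_known_eot_text("hello")); // prints: 1 0
    return 0;
}
```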
```diff
@@ -6546,6 +6546,44 @@ static void llm_load_vocab(
                 }
             }
         }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        vocab.special_eog_ids.clear();
+        for (const auto & t : vocab.token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == "<end_of_turn>"
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == "<EOT>"
+               ) {
+                vocab.special_eog_ids.insert(t.second);
+                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.first.c_str());
+                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            }
+        }
+
+        if (vocab.special_eos_id != -1 && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eot_id != -1 && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (vocab.special_eom_id != -1 && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
     }
 
     // build special tokens cache
```
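The block above does two things: a text-based scan over the vocabulary that collects every token whose literal text matches a known end-of-generation spelling, and a backstop that force-adds the metadata-declared EOS/EOT/EOM IDs, warning when the scan missed them (a hint that the tokenizer config is inconsistent). Below is a compressed, standalone sketch of the same two-pass pattern; the toy vocabulary, texts, and IDs are invented for the example, and only a single backstop is shown.

```cpp
#include <cstdio>
#include <map>
#include <set>
#include <string>

int main() {
    // toy stand-in for vocab.token_to_id; texts and IDs are invented
    std::map<std::string, int> token_to_id = {
        { "hello",      15  },
        { "<|im_end|>", 100 },
        { "</s>",       2   },  // declared as EOS in metadata, but its text is not in the list below
    };
    const int special_eos_id = 2;

    // pass 1: collect EOG candidates by token text
    std::set<int> special_eog_ids;
    for (const auto & t : token_to_id) {
        if (false
                || t.first == "<|eot_id|>"
                || t.first == "<|im_end|>"
                || t.first == "<|endoftext|>"
           ) {
            special_eog_ids.insert(t.second);
        }
    }

    // pass 2: backstop - the metadata-declared EOS must always count as EOG
    if (special_eos_id != -1 && special_eog_ids.count(special_eos_id) == 0) {
        special_eog_ids.insert(special_eos_id);
        printf("warning: EOS token was not matched by text - tokenizer config may be incorrect\n");
    }

    printf("EOG ids: %zu (EOS included: %d)\n",
           special_eog_ids.size(), (int) special_eog_ids.count(special_eos_id)); // EOG ids: 2 (EOS included: 1)
    return 0;
}
```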
```diff
@@ -6749,6 +6787,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id != -1)    { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
+    if (vocab.special_eom_id != -1)    { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, vocab.special_eom_id, vocab.id_to_token[vocab.special_eom_id].text.c_str() ); }
+
+    for (const auto & id : vocab.special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
+    }
 
     LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
```
