
Commit 88fc854

llama : improve sep token handling (#14272)
1 parent e28c1b9 commit 88fc854

15 files changed: +161, -29 lines

ci/run.sh

Lines changed: 1 addition & 1 deletion
@@ -779,7 +779,7 @@ function gg_run_rerank_tiny {
     model_f16="${path_models}/ggml-model-f16.gguf"
 
     # for this model, the SEP token is "</s>"
-    (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s></s>hi\nwhat is panda?</s></s>it's a bear\nwhat is panda?</s></s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
+    (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
 
     # sample output
     # rerank score 0: 0.029
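
Note: the test prompt now separates each query/document pair with `\t`, the new default classification separator, so `llama-embedding` inserts the model's actual EOS/SEP tokens itself instead of relying on a hardcoded `</s></s>`.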

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -2706,6 +2706,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.embd_sep = value;
         }
     ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    add_opt(common_arg(
+        {"--cls-separator"}, "STRING",
+        "separator of classification sequences (default \\t) for example \"<#seq#>\"",
+        [](common_params & params, const std::string & value) {
+            params.cls_sep = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
     add_opt(common_arg(
         {"--host"}, "HOST",
         string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -358,6 +358,7 @@ struct common_params {
     int32_t embd_normalize = 2;    // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out   = "";   // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep   = "\n"; // separator of embeddings
+    std::string cls_sep    = "\t"; // separator of classification sequences
 
     // server params
     int32_t port = 8080; // server listens on this network port

convert_hf_to_gguf.py

Lines changed: 0 additions & 14 deletions
@@ -2145,7 +2145,6 @@ def __init__(self, *args, **kwargs):
 
     def set_vocab(self):
         self._set_vocab_gpt2()
-        self.gguf_writer.add_add_bos_token(True)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -3918,9 +3917,6 @@ def _xlmroberta_set_vocab(self) -> None:
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)
-
 
 @ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
 class DistilBertModel(BertModel):
@@ -3962,8 +3958,6 @@ def set_vocab(self):
         bpe_tok_path = self.dir_model / "tokenizer.json"
         if bpe_tok_path.exists():
             self._set_vocab_gpt2()
-            self.gguf_writer.add_add_bos_token(True)
-            self.gguf_writer.add_add_eos_token(True)
 
             # we need this to validate the size of the token_type embeddings
             # though currently we are passing all zeros to the token_type embeddings
@@ -4848,8 +4842,6 @@ def set_vocab(self):
             self.gguf_writer.add_token_type_count(2)
         else:
             raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)
 
 
 @ModelBase.register("OpenELMForCausalLM")
@@ -5451,9 +5443,6 @@ def set_vocab(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-
     def set_gguf_parameters(self):
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -5591,9 +5580,6 @@ def set_vocab(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-
     def set_gguf_parameters(self):
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")

examples/embedding/embedding.cpp

Lines changed: 30 additions & 4 deletions
@@ -133,10 +133,36 @@ int main(int argc, char ** argv) {
     // max batch size
     const uint64_t n_batch = params.n_batch;
 
+    // get added sep and eos token, if any
+    const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
+    const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
     for (const auto & prompt : prompts) {
-        auto inp = common_tokenize(ctx, prompt, true, true);
+        std::vector<llama_token> inp;
+
+        // split classification pairs and insert expected separator tokens
+        if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
+            std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
+            std::string final_prompt;
+
+            for (size_t i = 0; i < pairs.size(); i++) {
+                final_prompt += pairs[i];
+                if (i != pairs.size() - 1) {
+                    if (!added_eos_token.empty()) {
+                        final_prompt += added_eos_token;
+                    }
+                    if (!added_sep_token.empty()) {
+                        final_prompt += added_sep_token;
+                    }
+                }
+            }
+
+            inp = common_tokenize(ctx, final_prompt, true, true);
+        } else {
+            inp = common_tokenize(ctx, prompt, true, true);
+        }
         if (inp.size() > n_batch) {
             LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
                     __func__, (long long int) inp.size(), (long long int) n_batch);
@@ -145,11 +171,11 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }
 
-    // check if the last token is SEP
+    // check if the last token is SEP/EOS
     // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
     for (auto & inp : inputs) {
-        if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
-            LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
+        if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
+            LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
             LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
         }
     }
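
As a usage sketch (not part of the commit), the pair handling above boils down to a small helper. The following is a minimal illustration; it uses only the vocab calls added in this commit and assumes `vocab` comes from an already-loaded reranker model:

```cpp
#include <string>
#include <vector>
#include "llama.h"

// Join rerank "query, document" pairs with the tokens the model itself
// expects between sequences: EOS then SEP, each only if the vocab says
// it is normally added ("tokenizer.ggml.add_eos_token" / "add_sep_token").
static std::string join_rank_pairs(const llama_vocab * vocab,
                                   const std::vector<std::string> & pairs) {
    const std::string eos = llama_vocab_get_add_eos(vocab)
        ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
    const std::string sep = llama_vocab_get_add_sep(vocab)
        ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";

    std::string out;
    for (size_t i = 0; i < pairs.size(); i++) {
        out += pairs[i];
        if (i + 1 < pairs.size()) {
            out += eos; // e.g. "</s>" for XLM-RoBERTa-based rerankers
            out += sep; // often "</s>" as well, yielding "...</s></s>..."
        }
    }
    return out; // tokenize afterwards, e.g. common_tokenize(ctx, out, true, true)
}
```

For "what is panda?" paired with "it's a bear" on an XLM-RoBERTa-style model this produces the same `what is panda?</s></s>it's a bear` string the CI test previously hardcoded.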

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ class Tokenizer:
         MASK_ID              = "tokenizer.ggml.mask_token_id"
         ADD_BOS              = "tokenizer.ggml.add_bos_token"
         ADD_EOS              = "tokenizer.ggml.add_eos_token"
+        ADD_SEP              = "tokenizer.ggml.add_sep_token"
         ADD_PREFIX           = "tokenizer.ggml.add_space_prefix"
         REMOVE_EXTRA_WS      = "tokenizer.ggml.remove_extra_whitespaces"
         PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -891,6 +891,9 @@ def add_add_bos_token(self, value: bool) -> None:
     def add_add_eos_token(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_EOS, value)
 
+    def add_add_sep_token(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_SEP, value)
+
     def add_add_space_prefix(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)

gguf-py/gguf/vocab.py

Lines changed: 80 additions & 3 deletions
@@ -119,6 +119,7 @@ def _set_special_token(self, typ: str, tid: Any) -> None:
             logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
 
     def _try_load_from_tokenizer_json(self, path: Path) -> bool:
+        tokenizer = None
         tokenizer_file = path / 'tokenizer.json'
         if tokenizer_file.is_file():
             with open(tokenizer_file, encoding = 'utf-8') as f:
@@ -152,11 +153,87 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
             added_tokens = tokenizer.get('added_tokens', {})
         else:
             added_tokens = {}
+        tokenizer_config = None
         tokenizer_config_file = path / 'tokenizer_config.json'
-        if not tokenizer_config_file.is_file():
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, encoding = 'utf-8') as f:
+                tokenizer_config = json.load(f)
+        if tokenizer:
+            special_bos = (tokenizer_config or {}).get('bos_token')
+            special_cls = (tokenizer_config or {}).get('cls_token')
+            special_eos = (tokenizer_config or {}).get('eos_token')
+            special_sep = (tokenizer_config or {}).get('sep_token')
+            if not special_bos and special_cls and tokenizer_config:
+                tokenizer_config['bos_token'] = special_bos = special_cls
+            if not special_eos and special_sep and tokenizer_config:
+                tokenizer_config['eos_token'] = special_eos = special_sep
+            post_processor = tokenizer.get('post_processor', {})
+            for processor in post_processor.get('processors', [post_processor]):
+                if processor.get('type') == 'RobertaProcessing':
+                    self.add_special_token['bos'] = True
+                    self.add_special_token['eos'] = True
+                    self.add_special_token['sep'] = True
+                    if not special_cls and tokenizer_config:
+                        special_cls = processor.get('cls', [special_bos])[0]
+                        tokenizer_config['cls_token'] = special_cls
+                    if not special_sep and tokenizer_config:
+                        special_sep = processor.get('sep', [special_eos])[0]
+                        tokenizer_config['sep_token'] = special_sep
+                    continue
+                # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
+                # Only works with simple templates, **will** get it wrong on unusual sequences
+                if processor.get('type') == 'TemplateProcessing':
+                    tmpl_single = processor.get('single', [])
+                    tmpl_pair = processor.get('pair', [])
+                    special_first = None
+                    special_last = None
+                    if len(tmpl_single) > 1:
+                        if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
+                            if not tokenizer_config:
+                                special_bos = special_first
+                            self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
+                            if special_first not in (special_bos, special_cls):
+                                logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
+                        if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
+                            if not tokenizer_config:
+                                special_eos = special_last
+                            self.add_special_token['eos'] = True if special_last == special_eos else False
+                            if special_last != special_eos:
+                                logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
+                    if tmpl_pair:
+                        seq_start = 1 if tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
+                        seq_stop  = -1 if tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
+                        if seq_start == 0 or seq_stop is None:
+                            logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
+                        if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
+                            tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
+                            tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
+                            if tmpl_a != 'A' or tmpl_b != 'B':
+                                logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
+                            # A [sep] [eos] B
+                            if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
+                                add_sep = False
+                                if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
+                                    if special_entry in (special_sep, special_eos) and not special_last:
+                                        add_sep = True
+                                    if special_entry not in (special_sep, special_eos):
+                                        logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
+                                else:
+                                    logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
+                                if len(tmpl_pair) == 2:
+                                    if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
+                                        if special_entry in (special_sep, special_eos):
+                                            add_sep = True
+                                        if special_entry not in (special_sep, special_eos):
+                                            logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+                                    else:
+                                        logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
+                                self.add_special_token['sep'] = add_sep
+                                if add_sep and not special_sep and tokenizer_config:
+                                    tokenizer_config['sep_token'] = special_eos
+                    continue
+        if not tokenizer_config:
             return True
-        with open(tokenizer_config_file, encoding = 'utf-8') as f:
-            tokenizer_config = json.load(f)
         chat_template_alt = None
         chat_template_file = path / 'chat_template.json'
         if chat_template_file.is_file():
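
To make the TemplateProcessing heuristic concrete: for a pair template the converter strips the leading/trailing specials it matched in the single template, then requires every token left between sequences A and B to be the configured SEP or EOS before setting `add_sep`. Below is a distilled sketch of that core decision (illustrative only, not from the commit; the real code also tracks `special_last` and back-fills `tokenizer_config`):

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Illustrative reduction of the TemplateProcessing heuristic added to
// gguf-py/gguf/vocab.py: "middle" holds the special tokens that appear
// between sequences A and B in the pair template, with the template's
// leading/trailing specials already stripped.
static bool pair_template_adds_sep(const std::vector<std::string> & middle,
                                   const std::string & sep_token,
                                   const std::string & eos_token) {
    if (middle.empty()) {
        return false; // "A B" style pair: nothing inserted between sequences
    }
    for (const auto & tok : middle) {
        if (tok != sep_token && tok != eos_token) {
            return false; // unknown separator token: leave add_sep unset
        }
    }
    return true;
}

int main() {
    // XLM-RoBERTa pair template "<s> A </s> </s> B </s>":
    // the middle specials are two "</s>" entries -> add_sep = true
    printf("add_sep = %d\n", (int) pair_template_adds_sep({"</s>", "</s>"}, "</s>", "</s>"));
}
```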

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -1044,6 +1044,7 @@ extern "C" {
 
     LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
     LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
+    LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
 
     LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
     LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_MASK_ID,              "tokenizer.ggml.mask_token_id"            },
     { LLM_KV_TOKENIZER_ADD_BOS,              "tokenizer.ggml.add_bos_token"            },
     { LLM_KV_TOKENIZER_ADD_EOS,              "tokenizer.ggml.add_eos_token"            },
+    { LLM_KV_TOKENIZER_ADD_SEP,              "tokenizer.ggml.add_sep_token"            },
     { LLM_KV_TOKENIZER_ADD_PREFIX,           "tokenizer.ggml.add_space_prefix"         },
     { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,      "tokenizer.ggml.remove_extra_whitespaces" },
     { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap"     },

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -194,6 +194,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_MASK_ID,
     LLM_KV_TOKENIZER_ADD_BOS,
     LLM_KV_TOKENIZER_ADD_EOS,
+    LLM_KV_TOKENIZER_ADD_SEP,
     LLM_KV_TOKENIZER_ADD_PREFIX,
     LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
     LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,

src/llama-model-saver.cpp

Lines changed: 1 addition & 0 deletions
@@ -228,6 +228,7 @@ void llama_model_saver::add_kv_from_model() {
     // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
     add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
     add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
+    add_kv(LLM_KV_TOKENIZER_ADD_SEP, vocab.get_add_sep());
     add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
     add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
     add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
