
Commit 1b67731

BERT tokenizer fixes (#6498)
Key changes:

* BERT conversion: fix abuse of LlamaHfVocab, do not set BOS or EOS
* Nomic Embed conversion: pad vocab instead of slicing embedding tensor
* llama_tokenize: handle added special tokens like HF does
1 parent c4a3a4f commit 1b67731
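
The renamed helper is the center of the change: `add_special` asks the tokenizer to add the model's own leading/trailing special tokens (BOS, or CLS/SEP for BERT-style models) when the model is configured for them, and `parse_special` makes text that spells out an added special token map to its token id, the way HF tokenizers treat added tokens. A minimal sketch of a call site, assuming a loaded `ctx`; the prompt string is only illustrative:

    // sketch: calling the common helper with the flags introduced in this commit
    std::vector<llama_token> toks = ::llama_tokenize(
        ctx,                        // llama_context * for a loaded model
        "<|im_start|>user\nhi",     // hypothetical prompt containing an added special token
        /*add_special=*/   true,    // let model metadata decide on BOS / EOS / CLS / SEP
        /*parse_special=*/ true);   // map "<|im_start|>" to its single token id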

File tree

20 files changed, +221 −194 lines changed


common/common.cpp

Lines changed: 8 additions & 8 deletions
@@ -2212,23 +2212,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 std::vector<llama_token> llama_tokenize(
   const struct llama_context * ctx,
            const std::string & text,
-                        bool   add_bos,
-                        bool   special) {
-    return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
+                        bool   add_special,
+                        bool   parse_special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
 }
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
            const std::string & text,
-                        bool   add_bos,
-                        bool   special) {
+                        bool   add_special,
+                        bool   parse_special) {
     // upper limit for the number of tokens
-    int n_tokens = text.length() + add_bos;
+    int n_tokens = text.length() + 2 * add_special;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
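
The `2 * add_special` headroom reserves room for both a leading and a trailing special token (BOS plus EOS, or CLS plus SEP), and a negative return from the underlying C call reports the required token count, which the wrapper uses to resize and retry. A hedged sketch of driving the C API directly under the same contract (the input text is illustrative):

    // sketch: size-query/retry contract of llama_tokenize as used by the wrapper above
    std::string text = "hello world";                     // illustrative input
    std::vector<llama_token> buf(text.length() + 2);      // room for BOS and EOS/SEP
    int n = llama_tokenize(model, text.data(), text.length(),
                           buf.data(), buf.size(),
                           /*add_special=*/true, /*parse_special=*/false);
    if (n < 0) {                 // buffer too small: -n is the required count
        buf.resize(-n);
        n = llama_tokenize(model, text.data(), text.length(),
                           buf.data(), buf.size(), true, false);
    }
    buf.resize(n);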

common/common.h

Lines changed: 4 additions & 4 deletions
@@ -223,14 +223,14 @@ void llama_batch_add(
 std::vector<llama_token> llama_tokenize(
   const struct llama_context * ctx,
            const std::string & text,
-                        bool   add_bos,
-                        bool   special = false);
+                        bool   add_special,
+                        bool   parse_special = false);
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
            const std::string & text,
-                        bool   add_bos,
-                        bool   special = false);
+                        bool   add_special,
+                        bool   parse_special = false);
 
 // tokenizes a token into a piece
 // should work similar to Python's `tokenizer.id_to_piece`

convert-hf-to-gguf.py

Lines changed: 19 additions & 34 deletions
@@ -227,15 +227,14 @@ def _get_part_names(self):
             return ("pytorch_model.bin",)
         return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
 
-    def _set_vocab_gpt2(self):
-        dir_model = self.dir_model
-        hparams = self.hparams
+    # used for GPT-2 BPE and WordPiece vocabs
+    def get_basic_vocab(self) -> tuple[list[str], list[int]]:
         tokens: list[str] = []
         toktypes: list[int] = []
 
         from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(dir_model)
-        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
         assert max(tokenizer.vocab.values()) < vocab_size
 
         reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
@@ -255,11 +254,15 @@ def _set_vocab_gpt2(self):
                 tokens.append(reverse_vocab[i])
                 toktypes.append(gguf.TokenType.NORMAL)
 
+        return tokens, toktypes
+
+    def _set_vocab_gpt2(self) -> None:
+        tokens, toktypes = self.get_basic_vocab()
         self.gguf_writer.add_tokenizer_model("gpt2")
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
 
-        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
         special_vocab.add_to_gguf(self.gguf_writer)
 
     def _set_vocab_qwen(self):
@@ -2043,34 +2046,25 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_pooling_type(pooling_type)
 
     def set_vocab(self):
-        # use huggingface vocab to get all tokens
-        vocab = LlamaHfVocab(self.dir_model, ignore_nonllama=True)
-        tokens, scores, toktypes = zip(*vocab.all_tokens())
-        assert len(tokens) == vocab.vocab_size
-        self.vocab_size = vocab.vocab_size
+        tokens, toktypes = self.get_basic_vocab()
+        self.vocab_size = len(tokens)
 
         # we need this to validate the size of the token_type embeddings
         # though currently we are passing all zeros to the token_type embeddings
-        n_token_types = len(set(toktypes))
-        self.gguf_writer.add_token_type_count(n_token_types)
+        self.gguf_writer.add_token_type_count(2)  # "Sequence A" or "Sequence B"
 
         # convert to phantom space vocab
-        def phantom(tok, typ):
-            if tok.startswith(b"[") and tok.endswith(b"]"):
+        def phantom(tok):
+            if tok.startswith("[") and tok.endswith("]"):
                 return tok
-            if tok.startswith(b"##"):
+            if tok.startswith("##"):
                 return tok[2:]
-            return b"\xe2\x96\x81" + tok
-        tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes))
-
-        # set up bos and eos tokens (cls and sep)
-        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
-        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
+            return "\u2581" + tok
+        tokens = list(map(phantom, tokens))
 
         # add vocab to gguf
         self.gguf_writer.add_tokenizer_model("bert")
         self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_scores(scores)
         self.gguf_writer.add_token_types(toktypes)
 
         # handle special tokens
@@ -2142,16 +2136,6 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
 
-    def get_tensors(self):
-        assert self.vocab_size is not None
-        for name, data in super().get_tensors():
-            # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
-            if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
-                rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
-                assert data.shape == (rounded_vocab_size, self.hparams["n_embd"])
-                data = data[:self.vocab_size, :]
-            yield name, data
-
 
 @Model.register("GemmaForCausalLM")
 class GemmaModel(Model):
@@ -2327,7 +2311,8 @@ def write_tensors(self):
                 data = data.astype(np.float32)
 
             # if f16 desired, convert big float32 2-dim weight tensors to float16
-            if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
+            new_weight_name = new_name[:-len(".weight")] if new_name.endswith(".weight") else ""
+            if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
                 data = data.astype(np.float16)
 
             print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
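
The reworked `phantom` helper above rewrites the WordPiece vocab into the phantom-space convention: bracketed special tokens such as `[CLS]` and `[SEP]` pass through unchanged, continuation pieces lose their `##` prefix, and every other piece gains a leading U+2581 (`▁`). The same rule restated as a standalone C++ sketch, kept here only for illustration (the function name is made up):

    // sketch of the mapping performed by the Python phantom() helper above
    static std::string to_phantom_space(const std::string & tok) {
        if (tok.size() >= 2 && tok.front() == '[' && tok.back() == ']') {
            return tok;                  // e.g. [CLS], [SEP], [MASK] stay unchanged
        }
        if (tok.rfind("##", 0) == 0) {
            return tok.substr(2);        // continuation piece: "##ing" -> "ing"
        }
        return "\xe2\x96\x81" + tok;     // word-initial piece: prepend U+2581 "▁"
    }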

convert-persimmon-to-gguf.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+from __future__ import annotations
+
 import argparse
 import os
 import sys

convert.py

Lines changed: 10 additions & 11 deletions
@@ -33,7 +33,7 @@
 import gguf
 
 if TYPE_CHECKING:
-    from typing import TypeAlias
+    from typing_extensions import Self, TypeAlias
 
 if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
     faulthandler.register(signal.SIGUSR1)
@@ -517,17 +517,15 @@ class LlamaHfVocab(Vocab):
     tokenizer_model = "llama"
     name = "hfft"
 
-    def __init__(self, base_path: Path, ignore_nonllama: bool = False):
+    def __init__(self, base_path: Path):
         fname_tokenizer = base_path / FAST_TOKENIZER_FILE
         # if this fails, FileNotFoundError propagates to caller
         with open(fname_tokenizer, encoding='utf-8') as f:
             tokenizer_json = json.load(f)
 
         # pre-check so we know if we need transformers
         tokenizer_model: dict[str, Any] = tokenizer_json['model']
-        if ignore_nonllama:
-            pass # workaround incorrect use of this class for WordPiece
-        elif (
+        if (
             tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
             or tokenizer_json['decoder']['type'] != 'Sequence'
         ):
@@ -647,16 +645,17 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
 
 
 class Tensor(ABC):
+    ndarray: NDArray
     data_type: DataType
 
     @abstractmethod
-    def astype(self, data_type: DataType) -> Tensor: ...
+    def astype(self, data_type: DataType) -> Self: ...
     @abstractmethod
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor: ...
+    def permute(self, n_head: int, n_head_kv: int) -> Self: ...
     @abstractmethod
-    def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ...
+    def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ...
     @abstractmethod
-    def part(self, n_part: int) -> UnquantizedTensor: ...
+    def part(self, n_part: int) -> Self: ...
     @abstractmethod
     def to_ggml(self) -> GGMLCompatibleTensor: ...
 
@@ -673,13 +672,13 @@ def __init__(self, ndarray: NDArray):
         self.ndarray = ndarray
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
 
-    def astype(self, data_type: DataType) -> Tensor:
+    def astype(self, data_type: DataType) -> UnquantizedTensor:
         dtype = data_type.dtype
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
 
-    def to_ggml(self) -> UnquantizedTensor:
+    def to_ggml(self) -> Self:
         return self
 
     def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:

examples/embedding/embedding.cpp

Lines changed: 3 additions & 3 deletions
@@ -123,10 +123,10 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }
 
-    // add eos if not present
+    // add SEP if not present
     for (auto & inp : inputs) {
-        if (inp.empty() || inp.back() != llama_token_eos(model)) {
-            inp.push_back(llama_token_eos(model));
+        if (inp.empty() || inp.back() != llama_token_sep(model)) {
+            inp.push_back(llama_token_sep(model));
         }
     }
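
BERT-style embedding models terminate a sequence with SEP rather than EOS, which is why the guard above now appends `llama_token_sep(model)`. With the new flags, CLS and SEP normally come from `add_special=true` already; a hedged sketch of preparing one input this way (the prompt text is illustrative):

    // sketch: build one embedding input; the SEP append is only a fallback
    std::vector<llama_token> inp = ::llama_tokenize(ctx, "a sentence to embed", true, false);
    if (inp.empty() || inp.back() != llama_token_sep(model)) {
        inp.push_back(llama_token_sep(model));
    }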

examples/imatrix/imatrix.cpp

Lines changed: 2 additions & 1 deletion
@@ -349,12 +349,13 @@ static void process_logits(
 static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
     const int n_ctx = llama_n_ctx(ctx);
 
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
-    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
 
     auto tim2 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
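
This is the caller-side pattern repeated in the example programs below: `add_bos` is still queried where the program needs to know whether a BOS was inserted, but tokenization now passes `true` for `add_special` and lets the model's metadata decide which special tokens are added, while the new assert refuses models that would silently append EOS to the prompt. A hedged sketch of that pattern, assuming a loaded `ctx`:

    // sketch: caller pattern used by the examples touched in this commit
    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);  // no silent trailing EOS
    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, /*add_special=*/true);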

examples/infill/infill.cpp

Lines changed: 3 additions & 2 deletions
@@ -239,6 +239,7 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s\n", get_system_info(params).c_str());
     }
     const bool add_bos = llama_should_add_bos_token(model);
+    GGML_ASSERT(llama_add_eos_token(model) != 1);
     LOG("add_bos: %d\n", add_bos);
 
     bool suff_rm_leading_spc = params.escape;
@@ -279,10 +280,10 @@
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
         original_prompt_len = original_inp.size();

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 2 deletions
@@ -146,7 +146,6 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     int n_past = 0;
 
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
-    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama));
 
     std::string system_prompt, user_prompt;
     size_t image_pos = prompt.find("<image>");
@@ -180,7 +179,7 @@
         }
     }
 
-    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, add_bos);
+    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
     llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);

examples/lookahead/lookahead.cpp

Lines changed: 1 addition & 4 deletions
@@ -64,13 +64,10 @@ int main(int argc, char ** argv) {
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
     // Tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-    LOG("add_bos tgt: %d\n", add_bos);
-
     std::vector<llama_token> inp;
     std::vector<llama_token> all;
 
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
     all = inp;
 
     const int max_context_size = llama_n_ctx(ctx);

examples/lookup/lookup-create.cpp

Lines changed: 1 addition & 3 deletions
@@ -28,10 +28,8 @@ int main(int argc, char ** argv){
     GGML_ASSERT(model != nullptr);
 
     // tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
     fprintf(stderr, "%s: tokenization done\n", __func__);

examples/lookup/lookup-stats.cpp

Lines changed: 1 addition & 4 deletions
@@ -34,11 +34,8 @@ int main(int argc, char ** argv){
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-    LOG("add_bos tgt: %d\n", add_bos);
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
 
     llama_ngram_cache ngram_cache_context;
     llama_ngram_cache ngram_cache_dynamic;

examples/lookup/lookup.cpp

Lines changed: 1 addition & 4 deletions
@@ -42,11 +42,8 @@ int main(int argc, char ** argv){
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-    LOG("add_bos tgt: %d\n", add_bos);
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
 
     llama_ngram_cache ngram_cache_context;
     llama_ngram_cache ngram_cache_dynamic;

examples/main/main.cpp

Lines changed: 7 additions & 6 deletions
@@ -246,6 +246,7 @@ int main(int argc, char ** argv) {
     }
 
     const bool add_bos = llama_should_add_bos_token(model);
+    GGML_ASSERT(llama_add_eos_token(model) != 1);
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
@@ -255,7 +256,7 @@
         if (params.chatml) {
            params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
-        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
@@ -277,10 +278,10 @@
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
         original_prompt_len = original_inp.size();
@@ -339,14 +340,14 @@
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false,   true);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true,  true);
+    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false, true);
 
     LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
     LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
 
     // chatml prefix & suffix
-    const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true);
+    const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", true, true);
     const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true);
 
     LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str());
