Commit 6a94ae6

Add test for MPT tokenization

1 parent 22c69a2

5 files changed: +35 additions, −24 deletions

convert-mpt-hf-to-gguf.py

Lines changed: 10 additions & 4 deletions
@@ -128,15 +128,21 @@ def parse_args() -> argparse.Namespace:
 # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
 tokenizer = AutoTokenizer.from_pretrained(dir_model)
 
+added_vocab = tokenizer.get_added_vocab()
 reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
 
 for i in range(vocab_size):
-    tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
-    scores.append(0.0) # dummy
-    toktypes.append(gguf.TokenType.NORMAL)
+    if i in reverse_vocab:
+        tokens.append(reverse_vocab[i])
+        if reverse_vocab[i] not in added_vocab:
+            toktypes.append(gguf.TokenType.NORMAL)
+        else:
+            toktypes.append(gguf.TokenType.USER_DEFINED)
+    else:
+        tokens.append(f"[PAD{i}]")
+        toktypes.append(gguf.TokenType.USER_DEFINED)
 
 gguf_writer.add_token_list(tokens)
-gguf_writer.add_token_scores(scores)
 gguf_writer.add_token_types(toktypes)
 
 special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
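The converter now distinguishes base-vocabulary tokens from tokens reported by tokenizer.get_added_vocab(): added tokens and [PADi] filler slots are written as USER_DEFINED, and the unused dummy scores are dropped. A minimal sketch of that classification logic, using a made-up toy vocabulary instead of a downloaded MPT tokenizer:

# Toy illustration of the classification above; the vocab contents are
# made up, only the logic mirrors the converter.
base_vocab  = {"hello": 0, "world": 1}   # stands in for tokenizer.vocab
added_vocab = {"<|im_start|>": 2}        # stands in for tokenizer.get_added_vocab()
vocab_size  = 5                          # padded size from the model hparams

reverse_vocab = {tok_id: tok for tok, tok_id in {**base_vocab, **added_vocab}.items()}

tokens, toktypes = [], []
for i in range(vocab_size):
    if i in reverse_vocab:
        tokens.append(reverse_vocab[i])
        toktypes.append("NORMAL" if reverse_vocab[i] not in added_vocab else "USER_DEFINED")
    else:
        tokens.append(f"[PAD{i}]")       # filler for unused padded slots
        toktypes.append("USER_DEFINED")

print(list(zip(tokens, toktypes)))
# [('hello', 'NORMAL'), ('world', 'NORMAL'), ('<|im_start|>', 'USER_DEFINED'),
#  ('[PAD3]', 'USER_DEFINED'), ('[PAD4]', 'USER_DEFINED')]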

llama.cpp

Lines changed: 21 additions & 20 deletions
@@ -975,20 +975,6 @@ static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
     (void) tensor;
 }
 
-static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
-    std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-
-    return std::string(result.data(), result.size());
-}
-
 //
 // globals
 //
@@ -1202,10 +1188,10 @@ struct llama_vocab {
     id special_eot_id = 32010;
 
     int find_bpe_rank(std::string token_left, std::string token_right) const {
-        replace_all(token_left, " ", "\u0120");
-        replace_all(token_left, "\n", "\u010A");
-        replace_all(token_right, " ", "\u0120");
-        replace_all(token_right, "\n", "\u010A");
+        GGML_ASSERT(token_left.find(" ") == std::string::npos);
+        GGML_ASSERT(token_left.find("\n") == std::string::npos);
+        GGML_ASSERT(token_right.find(" ") == std::string::npos);
+        GGML_ASSERT(token_right.find("\n") == std::string::npos);
 
         auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
         if (it == bpe_ranks.end()) {
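Rationale for the asserts: in a byte-level BPE vocabulary the merge table already stores spaces and newlines as the printable stand-ins Ġ (U+0120) and Ċ (U+010A), so rewriting them on every lookup was redundant; callers are now expected to pass tokens already in byte-level form. A small reimplementation of GPT-2's byte-to-unicode table (an illustration, not the llama.cpp code) shows where those two code points come from:

# Reimplementation (for illustration) of GPT-2's bytes_to_unicode table,
# the mapping byte-level BPE vocabularies such as MPT's are built on.
def bytes_to_unicode():
    # printable bytes keep their own code point ...
    printable = set(range(ord("!"), ord("~") + 1)) \
              | set(range(ord("¡"), ord("¬") + 1)) \
              | set(range(ord("®"), ord("ÿ") + 1))
    mapping, shift = {}, 0
    for b in range(256):
        if b in printable:
            mapping[b] = chr(b)
        else:
            mapping[b] = chr(256 + shift)  # ... the rest are shifted upward
            shift += 1
    return mapping

table = bytes_to_unicode()
print(table[ord(" ")])    # Ġ -- U+0120, a space inside a BPE token
print(table[ord("\n")])   # Ċ -- U+010A, a newline inside a BPE token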
@@ -7461,6 +7447,21 @@ void llama_sample_repetition_penalties(
     }
 }
 
+static std::string llama_token_to_piece(const struct llama_context* ctx, llama_token token) {
+    std::vector<char> result(8, 0);
+    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
+        GGML_ASSERT(check == -n_tokens);
+    }
+    else {
+        result.resize(n_tokens);
+    }
+
+    return std::string(result.data(), result.size());
+}
+
 void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
@@ -7480,7 +7481,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
-        const std::string piece = llama_token_to_str(ctx, id);
+        const std::string piece = llama_token_to_piece(ctx, id);
         if (id == eos) {
             if (!allow_eos) {
                 candidates->data[i].logit = -INFINITY;
@@ -7692,7 +7693,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_str(ctx, token);
+    const std::string piece = llama_token_to_piece(ctx, token);
 
     // Note terminating 0 in decoded string
     const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
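llama_token_to_str is also renamed to llama_token_to_piece and moved below the sampling helpers. It keeps the two-call buffer protocol of the public C API it wraps: a call with a too-small buffer returns the negative of the required size, and the caller resizes and retries. A toy Python model of that protocol (the names here are illustrative, not llama.cpp API):

# Toy model of the two-call buffer protocol; token_to_piece is a stand-in.
def token_to_piece(piece: bytes, buf: bytearray) -> int:
    if len(buf) < len(piece):
        return -len(piece)          # buffer too small: report required size
    buf[:len(piece)] = piece
    return len(piece)

piece = b"a token piece longer than eight bytes"
buf   = bytearray(8)                # first attempt with a small buffer
n     = token_to_piece(piece, buf)
if n < 0:
    buf = bytearray(-n)             # grow to exactly the required size
    n   = token_to_piece(piece, buf)
    assert n == len(buf)            # mirrors GGML_ASSERT(check == -n_tokens)
print(bytes(buf[:n]).decode())      # "a token piece longer than eight bytes"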

models/ggml-vocab-mpt.gguf

1.69 MB
Binary file not shown.

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ llama_test_executable (test-tokenizer-1-llama test-tokenizer-1-llama.cpp ${CMAKE
 llama_build_executable(test-tokenizer-1-bpe.cpp)
 llama_test_executable (test-tokenizer-1-falcon test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
 llama_test_executable(test-tokenizer-1-aquila test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
+llama_test_executable(test-tokenizer-1-mpt test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
 llama_build_and_test_executable(test-grammar-parser.cpp)
 llama_build_and_test_executable(test-llama-grammar.cpp)
 llama_build_and_test_executable(test-grad0.cpp) # SLOW

tests/test-tokenizer-1-bpe.cpp

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ int main(int argc, char **argv) {
     const int n_vocab = llama_n_vocab(model);
 
     for (int i = 0; i < n_vocab; ++i) {
+        if (llama_token_get_type(ctx, i) == LLAMA_TOKEN_TYPE_USER_DEFINED) {
+            continue;
+        }
         std::string str = llama_detokenize_bpe(ctx, std::vector<int>(1, i));
         try {
             auto cps = codepoints_from_utf8(str);
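The brute-force round-trip test now skips USER_DEFINED tokens: added special tokens and [PAD] fillers are not produced by the plain byte-level BPE tokenizer, so detokenize-then-retokenize cannot be expected to return the original id for them. A sketch of the property being checked, with a stub tokenizer standing in for the llama.cpp C API:

# Stub round-trip check; vocab contents and type names are illustrative.
NORMAL, USER_DEFINED = "normal", "user_defined"

vocab = {0: ("a", NORMAL), 1: ("b", NORMAL), 2: ("<pad>", USER_DEFINED)}

def detokenize(i):              # id -> text
    return vocab[i][0]

def tokenize(s):                # text -> ids (exact-match stub)
    return [i for i, (tok, _) in vocab.items() if tok == s]

for i, (tok, typ) in vocab.items():
    if typ == USER_DEFINED:
        continue                # added/pad tokens need not round-trip
    assert tokenize(detokenize(i)) == [i], f"token {i} failed round-trip"
print("round-trip holds for all NORMAL tokens")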
