Commit c88362d

Author: ochafik
llama2.c: support copying vocab from a llama gguf model file
1 parent: 63174b8
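
For orientation: this change lets the convert-llama2c-to-ggml example pull its tokenizer vocabulary from an existing llama GGUF model, falling back to the raw llama2.c tokenizer format otherwise; `load_vocab` dispatches on `is_ggml_file`. A minimal standalone sketch of what that dispatch test amounts to, assuming plain stdio in place of the repo's `llama_file` wrapper (only the final comparison is visible in the hunk context below):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the check behind is_ggml_file: read the file's first 4 bytes
// and compare against the GGUF magic, 0x46554747, i.e. the bytes "GGUF"
// read as a little-endian uint32_t. The real helper uses llama_file.
static bool looks_like_gguf(const char * filename) {
    std::FILE * f = std::fopen(filename, "rb");
    if (f == NULL) {
        return false;
    }
    uint32_t magic = 0;
    const size_t n = std::fread(&magic, sizeof(magic), 1, f);
    std::fclose(f);
    return n == 1 && magic == 0x46554747u; // GGUF_MAGIC
}
```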

File tree: 1 file changed

examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp (+82 additions, −24 deletions)
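
A design note on the first hunk below: the GGUF branch casts the pointers returned by `gguf_get_arr_data` directly to `const float *` and `const int *`, trusting that the scores and token_type arrays were written with the expected element types. A more defensive variant, sketched here under the assumption that this era's gguf API exposes `gguf_get_arr_type` and the `GGUF_TYPE_*` enum, would assert the type before casting:

```cpp
#include "ggml.h" // gguf_* API lived in ggml.h at the time of this commit

// Sketch only: the same "tokenizer.ggml.scores" lookup as the hunk below,
// with an element-type assertion before the raw pointer cast.
// gguf_get_arr_type and GGUF_TYPE_FLOAT32 are assumed from the gguf API
// of this period; verify against the ggml.h actually in tree.
static const float * get_scores_checked(struct gguf_context * ctx) {
    const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores");
    GGML_ASSERT(score_idx >= 0);
    GGML_ASSERT(gguf_get_arr_type(ctx, score_idx) == GGUF_TYPE_FLOAT32);
    return (const float *) gguf_get_arr_data(ctx, score_idx);
}
```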
```diff
@@ -531,31 +531,87 @@ bool is_ggml_file(const char *filename) {
     return magic == GGUF_MAGIC;
 }
 
-llama_vocab::ttype get_token_type(llama_vocab::id id, const llama_vocab::token &text) {
-    if (id == UNKNOWN_TOKEN_ID) return LLAMA_TOKEN_TYPE_UNKNOWN;
-    if (id == BOS_TOKEN_ID || id == EOS_TOKEN_ID) return LLAMA_TOKEN_TYPE_CONTROL;
-    unsigned char byte_val;
-    if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
-        return LLAMA_TOKEN_TYPE_BYTE;
-    }
-    return LLAMA_TOKEN_TYPE_NORMAL;
-}
-
 void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) {
-    // assume llama2.c vocabulary
-    printf("Assuming llama2.c vocabulary since %s is not a ggml file\n", filename);
-    llama_file file(filename, "rb");
-    const int n_vocab = config->vocab_size;
-    /* uint32_t max_token_length = */ file.read_u32(); // unused
-    vocab->id_to_token.resize(n_vocab);
-    for (int i=0; i<n_vocab; ++i) {
-        float_t score = file.read_f32();
-        uint32_t len = file.read_u32();
-        std::string text = file.read_string(len);
-        vocab->id_to_token[i].text = text;
-        vocab->id_to_token[i].score = score;
-        vocab->id_to_token[i].type = get_token_type(i, text);
-        vocab->token_to_id.emplace(text, i);
+    if (is_ggml_file(filename)) {
+        struct ggml_context * ctx_data = NULL;
+
+        struct gguf_init_params params = {
+            /*.no_alloc = */ false,
+            /*.ctx      = */ &ctx_data,
+        };
+
+        struct gguf_context * ctx = gguf_init_from_file(filename, params);
+        GGML_ASSERT(ctx != NULL);
+
+        const int model_idx = gguf_find_key(ctx, "tokenizer.ggml.model");
+        GGML_ASSERT(model_idx >= 0);
+        std::string tokenizer_name = gguf_get_val_str(ctx, model_idx);
+        GGML_ASSERT(tokenizer_name == "llama");
+
+        const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens");
+        GGML_ASSERT(token_idx >= 0);
+
+        const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores");
+        GGML_ASSERT(score_idx >= 0);
+        const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+
+        const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type");
+        GGML_ASSERT(toktype_idx >= 0);
+        const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+
+        const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
+
+        vocab->id_to_token.resize(n_vocab);
+
+        for (uint32_t i = 0; i < n_vocab; i++) {
+            std::string word = gguf_get_arr_str(ctx, token_idx, i);
+
+            vocab->token_to_id[word] = i;
+
+            auto & token_data = vocab->id_to_token[i];
+            token_data.text  = std::move(word);
+            token_data.score = scores[i];
+            token_data.type  = (llama_token_type) toktypes[i];
+        }
+        ggml_free(ctx_data);
+        gguf_free(ctx);
+    } else {
+        // assume llama2.c vocabulary
+        printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
+        llama_file file(filename, "rb");
+        const int n_vocab = config->vocab_size;
+        /* uint32_t max_token_length = */ file.read_u32(); // unused
+        vocab->id_to_token.resize(n_vocab);
+        for (llama_vocab::id id=0; id<n_vocab; ++id) {
+            float_t score = file.read_f32();
+            uint32_t len = file.read_u32();
+            std::string text = file.read_string(len);
+
+            unsigned char byte_val;
+            llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
+            if (id == UNKNOWN_TOKEN_ID) {
+                text = "<unk>";
+                type = LLAMA_TOKEN_TYPE_UNKNOWN;
+            } else if (id == BOS_TOKEN_ID) {
+                text = "<s>";
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (id == EOS_TOKEN_ID) {
+                text = "</s>";
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (text.empty()) {
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
+                // Text of byte tokens is already in the expected format.
+                type = LLAMA_TOKEN_TYPE_BYTE;
+            } else {
+                type = LLAMA_TOKEN_TYPE_NORMAL;
+            }
+
+            vocab->id_to_token[id].text = text;
+            vocab->id_to_token[id].score = score;
+            vocab->id_to_token[id].type = type;
+            vocab->token_to_id.emplace(text, id);
+        }
     }
 }
 
```
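The else branch above keeps the llama2.c reader but folds the old `get_token_type` logic inline, adding canonical `<unk>`/`<s>`/`</s>` spellings for the special ids. The one non-obvious case is the sscanf test, which recognizes SentencePiece byte tokens spelled `<0xNN>`. A self-contained demo of that exact pattern:

```cpp
#include <cstdio>

// Standalone demo of the byte-token test used in load_vocab above.
// SentencePiece writes raw-byte tokens as "<0xNN>"; a single sscanf with
// the "%02hhX" conversion both detects the shape and extracts the byte.
int main() {
    const char * samples[] = { "<0x0A>", "<s>", "hello", "<0xE2>" };
    for (const char * s : samples) {
        unsigned char byte_val;
        if (std::sscanf(s, "<0x%02hhX>", &byte_val) == 1) {
            std::printf("%-8s -> LLAMA_TOKEN_TYPE_BYTE (value 0x%02X)\n", s, byte_val);
        } else {
            std::printf("%-8s -> not a byte token\n", s);
        }
    }
    return 0;
}
```

Note that sscanf's return value only counts completed conversions, so the trailing `>` is not strictly enforced; that is harmless here because the strings come from a fixed vocabulary.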
```diff
@@ -651,6 +707,8 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
     gguf_set_val_u32(ctx, "tokenizer.ggml.unknown_token_id", UNKNOWN_TOKEN_ID);
     gguf_set_val_u32(ctx, "tokenizer.ggml.bos_token_id", BOS_TOKEN_ID);
     gguf_set_val_u32(ctx, "tokenizer.ggml.eos_token_id", EOS_TOKEN_ID);
+    gguf_set_val_u32(ctx, "tokenizer.ggml.sep_token_id", -1);
+    gguf_set_val_u32(ctx, "tokenizer.ggml.pad_token_id", -1);
 
     gguf_set_val_u32(ctx, "llama.context_length", model->hparams.n_ctx);
     gguf_set_val_u32(ctx, "llama.embedding_length", model->hparams.n_embd);
```
