
Commit 0f7cb95

ochafik committed
Fix import of llama2.c models that don't share weights between embedding layers
1 parent 930523c · commit 0f7cb95

File tree

1 file changed: +34 −6


examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp

Lines changed: 34 additions & 6 deletions
@@ -49,10 +49,10 @@ typedef struct {
     // float* freq_cis_real; // (seq_len, dim/2)
     // float* freq_cis_imag; // (seq_len, dim/2)
     // (optional) classifier weights for the logits, on the last layer
-    //float* wcls;
+    float* wcls;
 } TransformerWeights;
 
-void malloc_weights(TransformerWeights* w, Config* p) {
+void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
     // we calloc instead of malloc to keep valgrind happy
     w->token_embedding_table = new float[p->vocab_size * p->dim]();
     printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
@@ -86,9 +86,16 @@ void malloc_weights(TransformerWeights* w, Config* p) {
 
     w->rms_final_weight = new float[p->dim]();
     printf("[%s:AK] Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
+
+    if (shared_weights) {
+        w->wcls = NULL;
+    } else {
+        w->wcls = new float[p->vocab_size * p->dim]();
+        printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
+    }
 }
 
-int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
+int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shared_weights) {
     if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
     if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim)) return 1;
     if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->dim)) return 1;
@@ -100,6 +107,22 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
     if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != static_cast<size_t>(p->n_layers * p->hidden_dim * p->dim)) return 1;
     if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast<size_t>(p->n_layers * p->dim * p->hidden_dim)) return 1;
     if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != static_cast<size_t>(p->dim)) return 1;
+
+    // Skip freq_cis_real & freq_cis_imag
+    int head_size = p->dim / p->n_heads;
+    fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR);
+
+    if (!shared_weights && fread(w->wcls, sizeof(float), p->vocab_size * p->dim, f) != static_cast<size_t>(p->vocab_size * p->dim)) return 1;
+
+    // Check we didn't forget to read anything
+    auto curr = ftell(f);
+    fseek(f, 0, SEEK_END);
+    auto end = ftell(f);
+    if (curr != end) {
+        printf("Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", curr, end);
+        return 1;
+    }
+
     return 0;
 }

@@ -115,6 +138,7 @@ void free_weights(TransformerWeights* w) {
     delete w->w2;
     delete w->w3;
     delete w->rms_final_weight;
+    if (w->wcls) delete w->wcls;
 }
 
 void print_sample_weights(TransformerWeights *w){
@@ -131,6 +155,7 @@ void print_sample_weights(TransformerWeights *w){
     printf("%f\n", w->w2[0]);
     printf("%f\n", w->w3[0]);
     printf("%f\n", w->rms_att_weight[0]);
+    if (w->wcls) printf("%f\n", w->wcls[0]);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -617,7 +642,7 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
 //    // w->token_embedding_table -> model->tok_embeddings
 //    // float* -> struct ggml_tensor
 //    stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table);
-//    stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table);
+//    stuff_karpathy_weights_into_gg(model->output, w->wcls ? w->wcls : w->token_embedding_table);
 //
 //    stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight);
 //    //print_row(model->norm, 0);
@@ -791,9 +816,12 @@ int main(int argc, char ** argv) {
         if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
         // read in the config header
         if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
+        auto shared_weights = config.vocab_size > 0;
+        config.vocab_size = abs(config.vocab_size);
+
         // read in the Transformer weights
-        malloc_weights(&weights, &config);
-        if(checkpoint_init_weights(&weights, &config, file)) { return 1; }
+        malloc_weights(&weights, &config, shared_weights);
+        if(checkpoint_init_weights(&weights, &config, file, shared_weights)) { return 1; }
         fclose(file);
     }
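
For reference, here is a minimal standalone sketch (not part of this commit, and written under the assumption that the checkpoint header matches llama2.c's Config layout) of the detection logic added in main(): llama2.c encodes whether the output classifier shares the token embedding table in the sign of vocab_size, so a negative value means a separate wcls tensor follows the other weights.

    // Sketch only: read a llama2.c checkpoint header and derive the shared-weights flag.
    // The Config struct below is assumed to mirror llama2.c's exporter.
    #include <cstdio>
    #include <cstdlib>

    typedef struct {
        int dim;        // transformer dimension
        int hidden_dim; // FFN hidden dimension
        int n_layers;   // number of layers
        int n_heads;    // number of attention heads
        int n_kv_heads; // number of key/value heads
        int vocab_size; // negative value => unshared classifier weights (wcls)
        int seq_len;    // maximum sequence length
    } Config;

    int main(int argc, char ** argv) {
        if (argc < 2) { printf("usage: %s model.bin\n", argv[0]); return 1; }
        FILE * file = fopen(argv[1], "rb");
        if (!file) { printf("Unable to open %s!\n", argv[1]); return 1; }
        Config config;
        if (fread(&config, sizeof(Config), 1, file) != 1) { fclose(file); return 1; }
        // Positive vocab_size: the output layer reuses token_embedding_table.
        // Negative vocab_size: a separate wcls tensor is stored in the checkpoint.
        bool shared_weights = config.vocab_size > 0;
        config.vocab_size   = abs(config.vocab_size);
        printf("vocab_size = %d, shared_weights = %s\n",
               config.vocab_size, shared_weights ? "true" : "false");
        fclose(file);
        return 0;
    }

When shared_weights is false, the converter allocates and reads the extra wcls tensor and uses it for the model's output projection instead of the embedding table, as the diff above shows.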
