@@ -49,10 +49,10 @@ typedef struct {
49
49
// float* freq_cis_real; // (seq_len, dim/2)
50
50
// float* freq_cis_imag; // (seq_len, dim/2)
51
51
// (optional) classifier weights for the logits, on the last layer
52
- // float* wcls;
52
+ float * wcls;
53
53
} TransformerWeights;
54
54
55
- void malloc_weights (TransformerWeights* w, Config* p) {
55
+ void malloc_weights (TransformerWeights* w, Config* p, bool shared_weights ) {
56
56
// we value-initialize with new[]() (zero-filled, like calloc) instead of plain new/malloc to keep valgrind happy
57
57
w->token_embedding_table = new float [p->vocab_size * p->dim ]();
58
58
printf (" [%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n " ,__func__,p->vocab_size , p->dim , p->vocab_size * p->dim );
@@ -86,9 +86,16 @@ void malloc_weights(TransformerWeights* w, Config* p) {
86
86
87
87
w->rms_final_weight = new float [p->dim ]();
88
88
printf (" [%s:AK] Allocating [%d] float space for w->rms_final_weight\n " ,__func__,p->dim );
89
+
90
+ if (shared_weights) {
91
+ w->wcls = NULL ;
92
+ } else {
93
+ w->wcls = new float [p->vocab_size * p->dim ]();
94
+ printf (" [%s:AK] Allocating [%d] x [%d] = [%d] float space for w->wcls\n " ,__func__,p->vocab_size , p->dim , p->vocab_size * p->dim );
95
+ }
89
96
}
90
97
91
- int checkpoint_init_weights (TransformerWeights *w, Config* p, FILE* f) {
98
+ int checkpoint_init_weights (TransformerWeights *w, Config* p, FILE* f, bool shared_weights ) {
92
99
if (fread (w->token_embedding_table , sizeof (float ), p->vocab_size * p->dim , f) != static_cast <size_t >(p->vocab_size * p->dim )) return 1 ;
93
100
if (fread (w->rms_att_weight , sizeof (float ), p->n_layers * p->dim , f) != static_cast <size_t >(p->n_layers * p->dim )) return 1 ;
94
101
if (fread (w->wq , sizeof (float ), p->n_layers * p->dim * p->dim , f) != static_cast <size_t >(p->n_layers * p->dim * p->dim )) return 1 ;
@@ -100,6 +107,22 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) {
100
107
if (fread (w->w2 , sizeof (float ), p->n_layers * p->hidden_dim * p->dim , f) != static_cast <size_t >(p->n_layers * p->hidden_dim * p->dim )) return 1 ;
101
108
if (fread (w->w3 , sizeof (float ), p->n_layers * p->dim * p->hidden_dim , f) != static_cast <size_t >(p->n_layers * p->dim * p->hidden_dim )) return 1 ;
102
109
if (fread (w->rms_final_weight , sizeof (float ), p->dim , f) != static_cast <size_t >(p->dim )) return 1 ;
110
+
111
+ // Skip freq_cis_real & freq_cis_imag (seq_len * head_size floats in total); fail if the seek fails
112
+ int head_size = p->dim / p->n_heads ;
113
+ if (fseek (f, p->seq_len * head_size * sizeof (float ), SEEK_CUR) != 0 ) return 1 ;
114
+
115
+ if (!shared_weights && fread (w->wcls , sizeof (float ), p->vocab_size * p->dim , f) != static_cast <size_t >(p->vocab_size * p->dim )) return 1 ;
116
+
117
+ // Check we didn't forget to read anything
118
+ auto curr = ftell (f);
119
+ if (curr < 0 || fseek (f, 0 , SEEK_END) != 0 ) return 1 ; // position unknown: cannot validate the file length
120
+ auto end = ftell (f);
121
+ if (curr != end) {
122
+ printf (" Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n " , curr, end);
123
+ return 1 ;
124
+ }
125
+
103
126
return 0 ;
104
127
}
105
128
@@ -115,6 +138,7 @@ void free_weights(TransformerWeights* w) {
115
138
delete w->w2 ;
116
139
delete w->w3 ;
117
140
delete[] w->rms_final_weight ; // allocated with new float[...] in malloc_weights, so it needs delete[]
141
+ delete[] w->wcls ; // allocated with new float[...] or set to NULL; delete[] on NULL is a no-op
118
142
}
119
143
120
144
void print_sample_weights (TransformerWeights *w){
@@ -131,6 +155,7 @@ void print_sample_weights(TransformerWeights *w){
131
155
printf (" %f\n " , w->w2 [0 ]);
132
156
printf (" %f\n " , w->w3 [0 ]);
133
157
printf (" %f\n " , w->rms_att_weight [0 ]);
158
+ if (w->wcls ) printf (" %f\n " , w->wcls [0 ]);
134
159
}
135
160
// //////////////////////////////////////////////////////////////////////////////////////////////////////////
136
161
@@ -617,7 +642,7 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
617
642
// // w->token_embedding_table -> model->tok_embeddings
618
643
// // float* -> struct ggml_tensor
619
644
// stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table);
620
- // stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table);
645
+ // stuff_karpathy_weights_into_gg(model->output, w->wcls ? w->wcls : w-> token_embedding_table);
621
646
//
622
647
// stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight);
623
648
// //print_row(model->norm, 0);
@@ -791,9 +816,12 @@ int main(int argc, char ** argv) {
791
816
if (!file) { printf (" Unable to open the checkpoint file %s!\n " , params.fn_llama2c_model ); return 1 ; }
792
817
// read in the config header
793
818
if (fread (&config, sizeof (Config), 1 , file) != 1 ) { return 1 ; }
819
+ auto shared_weights = config.vocab_size > 0 ;
820
+ config.vocab_size = abs (config.vocab_size );
821
+
794
822
// read in the Transformer weights
795
- malloc_weights (&weights, &config);
796
- if (checkpoint_init_weights (&weights, &config, file)) { return 1 ; }
823
+ malloc_weights (&weights, &config, shared_weights );
824
+ if (checkpoint_init_weights (&weights, &config, file, shared_weights )) { return 1 ; }
797
825
fclose (file);
798
826
}
799
827
0 commit comments