@@ -75,7 +75,7 @@ typedef struct {
     int seq_len; // max sequence length
 } Config;
 
-typedef struct {
+struct TransformerWeights {
     // token embedding table
     float* token_embedding_table; // (vocab_size, dim)
     // weights for rmsnorms
@@ -97,7 +97,22 @@ typedef struct {
     // float* freq_cis_imag; // (seq_len, dim/2)
     // (optional) classifier weights for the logits, on the last layer
     float* wcls;
-} TransformerWeights;
+
+    ~TransformerWeights() {
+        delete[] token_embedding_table;
+        delete[] rms_att_weight;
+        delete[] rms_ffn_weight;
+        delete[] wq;
+        delete[] wk;
+        delete[] wv;
+        delete[] wo;
+        delete[] w1;
+        delete[] w2;
+        delete[] w3;
+        delete[] rms_final_weight;
+        delete[] wcls;
+    }
+};
 
 void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
     // we calloc instead of malloc to keep valgrind happy
@@ -173,21 +188,6 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shar
     return 0;
 }
 
-void free_weights(TransformerWeights* w) {
-    delete w->token_embedding_table;
-    delete w->rms_att_weight;
-    delete w->rms_ffn_weight;
-    delete w->wq;
-    delete w->wk;
-    delete w->wv;
-    delete w->wo;
-    delete w->w1;
-    delete w->w2;
-    delete w->w3;
-    delete w->rms_final_weight;
-    if (w->wcls) delete w->wcls;
-}
-
 void print_sample_weights(TransformerWeights *w){
     printf("----- Quick print of first of the weight vales of all the variables\n");
     printf("%f\n", w->token_embedding_table[0]);
@@ -596,6 +596,10 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab)
         // assume llama2.c vocabulary
         printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
         llama_file file(filename, "rb");
+        if (!file.fp) {
+            fprintf(stderr, "error: %s: %s\n", strerror(errno), filename);
+            exit(1);
+        }
         const int n_vocab = config->vocab_size;
         /* uint32_t max_token_length = */ file.read_u32(); // unused
         vocab->id_to_token.resize(n_vocab);
@@ -898,7 +902,7 @@ bool params_parse(int argc, char ** argv, struct train_params * params) {
 }
 
 std::string basename(const std::string &path) {
-    size_t pos = path.find_last_of("/");
+    size_t pos = path.find_last_of("/\\");
     if (pos == std::string::npos) {
         return path;
     }
@@ -911,7 +915,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
     Config config;
-    TransformerWeights weights;
+    TransformerWeights weights = {};
     {
        FILE *file = fopen(params.fn_llama2c_model, "rb");
        if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
@@ -953,6 +957,5 @@ int main(int argc, char ** argv) {
     printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model);
 
     ggml_free(model.ctx);
-    free_weights(&weights);
     return 0;
 }