
Commit 18705a3

llama2c : fix segfault and alloc-dealloc-mismatch (#2913)
* llama2c : fix segfault if vocab is not found
* llama2c : fix mismatch between new[] and delete
* llama2c : fix basename on Windows
* llama2c : use a destructor to prevent memory leaks
1 parent e8d9158 commit 18705a3

File tree

1 file changed: +23, -20 lines changed


examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp

Lines changed: 23 additions & 20 deletions
@@ -75,7 +75,7 @@ typedef struct {
     int seq_len; // max sequence length
 } Config;
 
-typedef struct {
+struct TransformerWeights {
     // token embedding table
     float* token_embedding_table; // (vocab_size, dim)
     // weights for rmsnorms
@@ -97,7 +97,22 @@ typedef struct {
     // float* freq_cis_imag; // (seq_len, dim/2)
     // (optional) classifier weights for the logits, on the last layer
     float* wcls;
-} TransformerWeights;
+
+    ~TransformerWeights() {
+        delete[] token_embedding_table;
+        delete[] rms_att_weight;
+        delete[] rms_ffn_weight;
+        delete[] wq;
+        delete[] wk;
+        delete[] wv;
+        delete[] wo;
+        delete[] w1;
+        delete[] w2;
+        delete[] w3;
+        delete[] rms_final_weight;
+        delete[] wcls;
+    }
+};
 
 void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
     // we calloc instead of malloc to keep valgrind happy
@@ -173,21 +188,6 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shar
     return 0;
 }
 
-void free_weights(TransformerWeights* w) {
-    delete w->token_embedding_table;
-    delete w->rms_att_weight;
-    delete w->rms_ffn_weight;
-    delete w->wq;
-    delete w->wk;
-    delete w->wv;
-    delete w->wo;
-    delete w->w1;
-    delete w->w2;
-    delete w->w3;
-    delete w->rms_final_weight;
-    if (w->wcls) delete w->wcls;
-}
-
 void print_sample_weights(TransformerWeights *w){
     printf("----- Quick print of first of the weight vales of all the variables\n");
     printf("%f\n", w->token_embedding_table[0]);
@@ -596,6 +596,10 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab)
         // assume llama2.c vocabulary
         printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
         llama_file file(filename, "rb");
+        if (!file.fp) {
+            fprintf(stderr, "error: %s: %s\n", strerror(errno), filename);
+            exit(1);
+        }
         const int n_vocab = config->vocab_size;
         /* uint32_t max_token_length = */ file.read_u32(); // unused
         vocab->id_to_token.resize(n_vocab);
@@ -898,7 +902,7 @@ bool params_parse(int argc, char ** argv, struct train_params * params) {
 }
 
 std::string basename(const std::string &path) {
-    size_t pos = path.find_last_of("/");
+    size_t pos = path.find_last_of("/\\");
     if (pos == std::string::npos) {
         return path;
     }
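
find_last_of("/") only recognizes the POSIX separator, so on Windows a path written with backslashes would come back unchanged and the directory part would never be stripped. Searching for either separator with find_last_of("/\\") handles both styles. A small self-contained sketch of the idea; the substr(pos + 1) tail and the example paths are assumptions, only the find_last_of change appears in the diff:

    #include <cassert>
    #include <cstddef>
    #include <string>

    // Same shape as the patched helper; the substr tail is assumed.
    static std::string basename(const std::string & path) {
        size_t pos = path.find_last_of("/\\"); // last '/' or '\'
        if (pos == std::string::npos) {
            return path;
        }
        return path.substr(pos + 1);
    }

    int main() {
        assert(basename("models/stories15M.bin")      == "stories15M.bin");
        assert(basename("C:\\models\\stories15M.bin") == "stories15M.bin"); // backslashes now handled
        assert(basename("stories15M.bin")             == "stories15M.bin"); // no separator at all
        return 0;
    }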
@@ -911,7 +915,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
     Config config;
-    TransformerWeights weights;
+    TransformerWeights weights = {};
     {
         FILE *file = fopen(params.fn_llama2c_model, "rb");
         if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
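
With the new destructor in place, a default-initialized TransformerWeights on the stack would hold indeterminate pointer values, and calling delete[] on them is undefined behavior if main() bails out before the buffers are allocated (for example when the checkpoint file cannot be opened). Writing TransformerWeights weights = {}; value-initializes every member, so the pointers start as nullptr and delete[] on them is a harmless no-op. A reduced sketch with a hypothetical Holder type, not from the commit:

    // Hypothetical type, not from the commit: shows why "= {}" matters.
    struct Holder {
        float * data;                 // no initializer, like the raw pointers
                                      // in TransformerWeights
        ~Holder() { delete[] data; }  // delete[] nullptr is a harmless no-op
    };

    int main(int argc, char **) {
        // Holder bad;                // members left indeterminate: the destructor
        //                            // would delete[] a garbage pointer
        Holder ok = {};               // value-initialization zeroes data to nullptr
        if (argc > 1) {
            ok.data = new float[8](); // allocated only on some code paths
        }
        return 0;                     // safe either way
    }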
@@ -953,6 +957,5 @@
     printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model);
 
     ggml_free(model.ctx);
-    free_weights(&weights);
     return 0;
 }
