
llama2c : fix segfault and alloc-dealloc-mismatch #2913


Merged
merged 4 commits on Sep 1, 2023
43 changes: 23 additions & 20 deletions examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
@@ -75,7 +75,7 @@ typedef struct {
     int seq_len; // max sequence length
 } Config;
 
-typedef struct {
+struct TransformerWeights {
     // token embedding table
     float* token_embedding_table; // (vocab_size, dim)
     // weights for rmsnorms
@@ -97,7 +97,22 @@ typedef struct {
     // float* freq_cis_imag; // (seq_len, dim/2)
     // (optional) classifier weights for the logits, on the last layer
     float* wcls;
-} TransformerWeights;
+
+    ~TransformerWeights() {
+        delete[] token_embedding_table;
+        delete[] rms_att_weight;
+        delete[] rms_ffn_weight;
+        delete[] wq;
+        delete[] wk;
+        delete[] wv;
+        delete[] wo;
+        delete[] w1;
+        delete[] w2;
+        delete[] w3;
+        delete[] rms_final_weight;
+        delete[] wcls;
+    }
+};
 
 void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
     // we calloc instead of malloc to keep valgrind happy
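Note on the approach: the weight buffers are evidently allocated with `new float[...]` (despite the `calloc` comment above, otherwise `delete[]` would be the wrong pairing), so moving cleanup into a destructor pairs `new[]` with `delete[]` and ties the buffers' lifetime to the struct itself. A minimal sketch of the pattern, using a hypothetical `Buffer` type in place of the real `TransformerWeights`:

```cpp
#include <cstddef>

// Hypothetical stand-in for TransformerWeights: one owned array instead of twelve.
struct Buffer {
    float * data = nullptr;

    void alloc(std::size_t n) { data = new float[n](); } // zero-initialized, like calloc
    ~Buffer() { delete[] data; }                         // delete[] pairs with new[]
};

int main() {
    Buffer b;
    b.alloc(1024);
    // ... use b.data ...
}   // b.data is released here automatically, on every exit path
```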
@@ -173,21 +188,6 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shared_weights) {
     return 0;
 }
 
-void free_weights(TransformerWeights* w) {
-    delete w->token_embedding_table;
-    delete w->rms_att_weight;
-    delete w->rms_ffn_weight;
-    delete w->wq;
-    delete w->wk;
-    delete w->wv;
-    delete w->wo;
-    delete w->w1;
-    delete w->w2;
-    delete w->w3;
-    delete w->rms_final_weight;
-    if (w->wcls) delete w->wcls;
-}
-
 void print_sample_weights(TransformerWeights *w){
     printf("----- Quick print of first of the weight vales of all the variables\n");
     printf("%f\n", w->token_embedding_table[0]);
@@ -596,6 +596,10 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) {
     // assume llama2.c vocabulary
     printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
     llama_file file(filename, "rb");
+    if (!file.fp) {
+        fprintf(stderr, "error: %s: %s\n", strerror(errno), filename);
+        exit(1);
+    }
     const int n_vocab = config->vocab_size;
     /* uint32_t max_token_length = */ file.read_u32(); // unused
     vocab->id_to_token.resize(n_vocab);
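This check closes the segfault from the PR title: the `llama_file` constructor does not abort on failure, so a bad vocab path used to leave `file.fp` null and the very next read crashed. The same fail-fast pattern in isolation, with a hypothetical `open_or_die()` helper (not a llama.cpp API):

```cpp
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Hypothetical helper: fail fast with a readable message instead of
// letting a later fread() dereference a null FILE*.
static FILE * open_or_die(const char * path) {
    FILE * fp = std::fopen(path, "rb");
    if (!fp) {
        std::fprintf(stderr, "error: %s: %s\n", std::strerror(errno), path);
        std::exit(1);
    }
    return fp;
}

int main() {
    FILE * fp = open_or_die("tokenizer.bin"); // hypothetical path
    std::fclose(fp);
}
```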
@@ -898,7 +902,7 @@ bool params_parse(int argc, char ** argv, struct train_params * params) {
 }
 
 std::string basename(const std::string &path) {
-    size_t pos = path.find_last_of("/");
+    size_t pos = path.find_last_of("/\\");
     if (pos == std::string::npos) {
         return path;
     }
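The one-character change makes `basename()` handle Windows-style separators as well, since `find_last_of` searches for any character of its argument. A quick standalone check of the behavior (my test; I am assuming the function tail after this hunk returns `path.substr(pos + 1)`):

```cpp
#include <cassert>
#include <string>

std::string basename(const std::string &path) {
    size_t pos = path.find_last_of("/\\");
    if (pos == std::string::npos) {
        return path;
    }
    return path.substr(pos + 1); // assumed tail, cut off by the hunk above
}

int main() {
    assert(basename("models/stories15M.bin")      == "stories15M.bin"); // hypothetical paths
    assert(basename("C:\\models\\stories15M.bin") == "stories15M.bin");
    assert(basename("stories15M.bin")             == "stories15M.bin");
}
```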
@@ -911,7 +915,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
     Config config;
-    TransformerWeights weights;
+    TransformerWeights weights = {};
     {
         FILE *file = fopen(params.fn_llama2c_model, "rb");
         if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
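The `= {}` is the other half of the segfault fix: value-initialization zeroes every pointer member, and `delete[]` on a null pointer is a no-op, so the new destructor stays safe even when the program bails out before `malloc_weights()` ever runs. The same idea in isolation (hypothetical type):

```cpp
struct Weights {                     // stand-in for TransformerWeights
    float * data;
    ~Weights() { delete[] data; }    // a no-op when data == nullptr
};

int main() {
    Weights ok = {};                 // data == nullptr: destructor is safe
    // Weights bad;                  // data indeterminate: destructor would be UB
    return 1;                        // early exit, as when the checkpoint fails to open
}   // ok is destroyed here without ever having been allocated
```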
@@ -953,6 +957,5 @@ int main(int argc, char ** argv) {
     printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model);
 
     ggml_free(model.ctx);
-    free_weights(&weights);
     return 0;
 }
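With the destructor in place, `weights` now cleans up after itself when `main()` returns, so the explicit `free_weights()` call, and the mismatched scalar `delete` inside it, can be dropped entirely.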