Commit e1e7210

llama : fix memory leak in llama_batch_free (#5252)
`llama_batch_init` allocates memory for a fixed number of tokens, but `llama_batch_free` only frees memory for the number of tokens that were actually added to the batch. This change uses a null-terminated array for the batch `seq_id` field and frees all elements until the `nullptr` sentinel is reached. It also renames the first parameter from `n_tokens` to `n_tokens_alloc` to indicate more clearly that this value is the number of tokens allocated for the batch, not the number of tokens in the batch.
1 parent 128dcbd commit e1e7210
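
As background, here is a minimal standalone sketch of the sentinel pattern this commit adopts (the names below are illustrative, not part of the llama.cpp API): the allocator reserves one extra pointer slot and sets it to nullptr, so the matching free routine can release every row without being told how many rows were allocated.

// Sketch of the null-terminated pointer-array pattern (illustrative names,
// not the llama.cpp API). Compiles as standalone C++.
#include <cstdlib>

// Allocate n_alloc rows of row_len ints, plus one extra slot holding nullptr
// that marks the end of the array for the free routine.
static int ** rows_alloc(int n_alloc, int row_len) {
    int ** rows = (int **) malloc(sizeof(int *) * (n_alloc + 1));
    for (int i = 0; i < n_alloc; ++i) {
        rows[i] = (int *) malloc(sizeof(int) * row_len);
    }
    rows[n_alloc] = nullptr; // sentinel: the free loop stops here
    return rows;
}

// Free every row up to the nullptr sentinel, then the array itself.
// The caller does not need to remember the allocated count.
static void rows_free(int ** rows) {
    if (!rows) return;
    for (int i = 0; rows[i] != nullptr; ++i) {
        free(rows[i]);
    }
    free(rows);
}

int main() {
    int ** rows = rows_alloc(8, 4); // capacity for 8 rows
    rows_free(rows);                // releases all 8 rows and the array
    return 0;
}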

File tree: 1 file changed

llama.cpp

Lines changed: 11 additions & 9 deletions
@@ -11377,22 +11377,24 @@ struct llama_batch llama_batch_get_one(
     };
 }
 
-struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
+struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
 
     if (embd) {
-        batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
+        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
     } else {
-        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
+        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }
 
-    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
-    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
-    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
-    for (int i = 0; i < n_tokens; ++i) {
+    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
+    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
+    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
+    for (int i = 0; i < n_tokens_alloc; ++i) {
         batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
+    batch.seq_id[n_tokens_alloc] = nullptr;
+
+    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
 
     return batch;
 }
@@ -11403,7 +11405,7 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.pos)      free(batch.pos);
     if (batch.n_seq_id) free(batch.n_seq_id);
     if (batch.seq_id) {
-        for (int i = 0; i < batch.n_tokens; ++i) {
+        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
            free(batch.seq_id[i]);
        }
        free(batch.seq_id);
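
For context, a usage-level sketch against the llama.h API as of this commit (the token values and counts below are placeholders): the leak occurred whenever a batch was allocated with more capacity than the number of tokens eventually stored in it, because the old free loop only walked batch.n_tokens entries of seq_id.

// Hypothetical caller: allocate spare capacity, use part of it, free it all.
// Before this commit, llama_batch_free freed only seq_id[0..n_tokens), leaking
// the rows allocated for the unused capacity.
#include "llama.h"

int main() {
    // Room for 512 token IDs (no embeddings), one sequence id per token.
    llama_batch batch = llama_batch_init(/*n_tokens_alloc =*/ 512, /*embd =*/ 0, /*n_seq_max =*/ 1);

    // Fill only three positions; the token id 1 is a placeholder value.
    for (int i = 0; i < 3; ++i) {
        batch.token   [i] = 1;
        batch.pos     [i] = i;
        batch.n_seq_id[i] = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i] = 0;
    }
    batch.n_tokens = 3;

    // With the nullptr sentinel in place, this frees all 512 seq_id rows,
    // not just the three that were populated.
    llama_batch_free(batch);
    return 0;
}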
