Commit 738ace3

llama : free ggml context in set / copy state data (close #1425)
1 parent 699b1ad commit 738ace3
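
The change itself is small: both llama_copy_state_data and llama_set_state_data build a short-lived ggml context on a 4 KB stack buffer to run a tensor-copy graph over the KV cache, and this commit adds the ggml_free that was missing once that graph has been computed. Below is a minimal sketch of the pattern, not part of the diff; the ggml_init_params layout (mem_size, mem_buffer, no_alloc) is assumed from ggml.h, and the comment on why the free matters is my reading rather than text from the commit.

#include "ggml.h"

static void copy_kv_via_scratch_ctx(void) {
    // Temporary context backed by a caller-provided stack buffer; with
    // no_alloc = true, ggml is expected to place only tensor metadata here,
    // while tensor data pointers are set manually (as the diff below does).
    char buffer[4096];
    ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });

    // ... build a small graph with ggml_cpy() / ggml_build_forward_expand()
    //     and run it with ggml_graph_compute(), as in the hunks below ...

    // The fix in this commit: release the temporary context once the copy is
    // done, so repeated state save/load calls do not leave contexts registered.
    ggml_free(cpy_ctx);
}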

File tree

2 files changed: +29, -21 lines


llama.cpp

Lines changed: 28 additions & 20 deletions
@@ -2450,8 +2450,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
 }
 
 // Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
-    uint8_t * out = dest;
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+    uint8_t * out = dst;
 
     // copy rng
     {
@@ -2511,7 +2511,9 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
         if (kv_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
+
             char buffer[4096];
+
             ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
             ggml_cgraph gf{};
             gf.n_threads = 1;
@@ -2535,10 +2537,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
             ggml_graph_compute(cpy_ctx, &gf);
+
+            ggml_free(cpy_ctx);
         }
     }
 
-    const size_t written  = out - dest;
+    const size_t written  = out - dst;
     const size_t max_size = llama_get_state_size(ctx);
 
     LLAMA_ASSERT(written <= max_size);
@@ -2548,15 +2552,15 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
 // Sets the state reading from the specified source address
 size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    const uint8_t * in = src;
+    const uint8_t * inp = src;
 
     // set rng
     {
         size_t rng_size;
         char   rng_buf[LLAMA_MAX_RNG_STATE];
 
-        memcpy(&rng_size,   in, sizeof(rng_size));    in += sizeof(rng_size);
-        memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;
+        memcpy(&rng_size,   inp, sizeof(rng_size));    inp += sizeof(rng_size);
+        memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE;
 
         std::stringstream rng_ss;
         rng_ss.str(std::string(&rng_buf[0], rng_size));
@@ -2570,30 +2574,30 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         size_t logits_cap;
         size_t logits_size;
 
-        memcpy(&logits_cap,  in, sizeof(logits_cap));  in += sizeof(logits_cap);
-        memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
+        memcpy(&logits_cap,  inp, sizeof(logits_cap));  inp += sizeof(logits_cap);
+        memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
 
         LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
 
         if (logits_size) {
             ctx->logits.resize(logits_size);
-            memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+            memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
         }
 
-        in += logits_cap * sizeof(float);
+        inp += logits_cap * sizeof(float);
     }
 
     // set embeddings
     {
         size_t embedding_size;
 
-        memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
+        memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size);
 
         LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
 
         if (embedding_size) {
-            memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
-            in += embedding_size * sizeof(float);
+            memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float));
+            inp += embedding_size * sizeof(float);
         }
     }
 
@@ -2608,25 +2612,27 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         size_t kv_size;
         int    kv_ntok;
 
-        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
-        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
+        memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
+        memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
 
         if (kv_size) {
             LLAMA_ASSERT(kv_self.buf.size == kv_size);
 
             const size_t elt_size = ggml_element_size(kv_self.k);
+
             char buffer[4096];
+
             ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
             ggml_cgraph gf{};
             gf.n_threads = 1;
 
             ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
-            kin3d->data = (void *) in;
-            in += ggml_nbytes(kin3d);
+            kin3d->data = (void *) inp;
+            inp += ggml_nbytes(kin3d);
 
             ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
-            vin3d->data = (void *) in;
-            in += ggml_nbytes(vin3d);
+            vin3d->data = (void *) inp;
+            inp += ggml_nbytes(vin3d);
 
             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
                 n_embd, kv_ntok, n_layer,
@@ -2639,12 +2645,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
             ggml_graph_compute(cpy_ctx, &gf);
+
+            ggml_free(cpy_ctx);
         }
 
         ctx->model.kv_self.n = kv_ntok;
     }
 
-    const size_t nread    = in - src;
+    const size_t nread    = inp - src;
     const size_t max_size = llama_get_state_size(ctx);
 
     LLAMA_ASSERT(nread <= max_size);
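
Since both functions now pair every ggml_init with a ggml_free, the temporary context is released on the one exit path these blocks have. Purely as an illustration, and not part of this commit or of llama.cpp, a small RAII guard is one way C++ callers make that pairing automatic so the free also happens on early returns:

#include "ggml.h"

// Hypothetical helper: ties ggml_free() to scope exit so a temporary context
// cannot be leaked by an early return. Not present in llama.cpp.
struct ggml_ctx_guard {
    ggml_context * ctx;

    explicit ggml_ctx_guard(ggml_context * c) : ctx(c) {}
    ~ggml_ctx_guard() { if (ctx) { ggml_free(ctx); } }

    ggml_ctx_guard(const ggml_ctx_guard &) = delete;
    ggml_ctx_guard & operator=(const ggml_ctx_guard &) = delete;
};

With such a guard, declaring ggml_ctx_guard guard(cpy_ctx); right after ggml_init would free the context at scope exit, and the explicit ggml_free calls added here would not be needed.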

llama.h

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ extern "C" {
     // Copies the state to the specified destination address.
     // Destination needs to have allocated enough memory.
     // Returns the number of bytes copied
-    LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
+    LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
 
     // Set the state reading from the specified address
     // Returns the number of bytes read
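
For reference, a minimal sketch of how a caller might use these declarations together; llama_get_state_size and llama_set_state_data are already part of llama.h, while the helper names and the valid ctx are assumptions made for the example:

#include "llama.h"

#include <cstdint>
#include <vector>

// Hypothetical helpers: snapshot the context state into a caller-owned buffer
// and restore it later. Assumes `ctx` is a valid llama_context * created elsewhere.
static std::vector<uint8_t> save_state(llama_context * ctx) {
    std::vector<uint8_t> state(llama_get_state_size(ctx));
    const size_t written = llama_copy_state_data(ctx, state.data());
    state.resize(written); // llama.cpp asserts written <= max_size, so trim to what was used
    return state;
}

static void restore_state(llama_context * ctx, const std::vector<uint8_t> & state) {
    llama_set_state_data(ctx, state.data());
}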
