Commit e216aa0

llama : only copy used KV cache in get / set state (#1272)
* llama : only copy used KV cache in get / set state
* switch to ggml for copying k, v
* avoid designated initializers
1 parent 2485d7a commit e216aa0
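
The last bullet refers to how struct ggml_init_params is filled in the new copy code. A minimal illustration of the difference, assuming the field names and order from ggml.h (mem_size, mem_buffer, no_alloc) and using made-up variable names; the positional form is what the diff below uses:

    // Hedged illustration (not part of the commit) of "avoid designated initializers":
    // both forms fill the same ggml_init_params, but the positional aggregate form
    // compiles as plain C++11, while designated initializers need C++20 or a
    // compiler extension.
    #include "ggml.h"

    static char buffer[4096];

    // positional form, as used in llama_copy_state_data / llama_set_state_data below
    static ggml_init_params params_positional = { sizeof(buffer), buffer, /* no_alloc */ true };

    // designated form, avoided by the commit:
    // static ggml_init_params params_designated = {
    //     .mem_size   = sizeof(buffer),
    //     .mem_buffer = buffer,
    //     .no_alloc   = true,
    // };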

File tree: 2 files changed, +80 -23 lines

  llama.cpp
  llama.h

llama.cpp

Lines changed: 77 additions & 21 deletions
@@ -1285,6 +1285,9 @@ static bool llama_eval_internal(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
+    // update kv token count
+    lctx.model.kv_self.n = n_past + N;
+
     // extract logits
     {
         auto & logits_out = lctx.logits;
@@ -2401,7 +2404,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     ctx->rng.seed(seed);
 }
 
-// Returns the size of the state
+// Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
@@ -2480,21 +2483,51 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
     // copy kv cache
     {
-        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
+        const size_t kv_size = kv_self.buf.size;
         const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
 
         memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
 
         if (kv_size) {
-            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
+
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kout3d->data = out;
+            out += ggml_nbytes(kout3d);
+
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vout3d->data = out;
+            out += ggml_nbytes(vout3d);
+
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
+            ggml_graph_compute(cpy_ctx, &gf);
         }
     }
 
     const size_t written  = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);
 
-    LLAMA_ASSERT(written == expected);
+    LLAMA_ASSERT(written <= max_size);
 
     return written;
 }
@@ -2552,32 +2585,55 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
     // set kv cache
     {
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
        size_t kv_size;
        int kv_ntok;
 
        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
 
        if (kv_size) {
-            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+            LLAMA_ASSERT(kv_self.buf.size == kv_size);
+
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
+
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kin3d->data = (void *) in;
+            in += ggml_nbytes(kin3d);
 
-            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vin3d->data = (void *) in;
+            in += ggml_nbytes(vin3d);
 
-            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
 
-            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-            ctx->model.kv_self.v->data = v_data;
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
+            ggml_graph_compute(cpy_ctx, &gf);
        }
 
        ctx->model.kv_self.n = kv_ntok;
    }
 
    const size_t nread    = in - src;
-   const size_t expected = llama_get_state_size(ctx);
+   const size_t max_size = llama_get_state_size(ctx);
 
-   LLAMA_ASSERT(nread == expected);
+   LLAMA_ASSERT(nread <= max_size);
 
    return nread;
 }
@@ -2620,14 +2676,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_exp = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        if (n_state_size_cur != n_state_size_exp) {
-            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+        if (n_state_size_cur > n_state_size_max) {
+            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
             return false;
         }
 
-        std::vector<uint8_t> state_data(n_state_size_cur);
+        std::vector<uint8_t> state_data(n_state_size_max);
         file.read_raw(state_data.data(), n_state_size_cur);
 
         llama_set_state_data(ctx, state_data.data());
@@ -2650,12 +2706,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
 
     // save the context state
     {
-        const size_t n_state_size = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        std::vector<uint8_t> state_data(n_state_size);
-        llama_copy_state_data(ctx, state_data.data());
+        std::vector<uint8_t> state_data(n_state_size_max);
+        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
 
-        file.write_raw(state_data.data(), n_state_size);
+        file.write_raw(state_data.data(), n_state_size_cur);
     }
 
     return true;
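
The heart of the change is the strided copy in llama_copy_state_data / llama_set_state_data above: a 3D view selects only the first kv_ntok token columns of k and v, and a one-operation ggml graph copies them between the cache and the serialized buffer. Below is a stripped-down, self-contained sketch of the same technique, assuming the ggml API of this era (a ggml_cgraph with an n_threads field and the two-argument ggml_graph_compute, as in the diff); the dimensions and the names cache, out, and used are illustrative, not taken from the commit:

    #include "ggml.h"

    #include <cstdint>
    #include <vector>

    int main() {
        const int n_embd  = 8;   // values per token (illustrative)
        const int n_ctx   = 16;  // allocated token slots per layer
        const int n_layer = 2;   // layers sharing one cache tensor
        const int n_used  = 5;   // tokens actually present in the cache

        // scratch arena for the tensors and the copy graph
        std::vector<uint8_t> mem(16u*1024*1024);
        ggml_context * ctx = ggml_init({ mem.size(), mem.data(), /* no_alloc */ false });

        // cache laid out like kv_self.k: n_embd x n_ctx x n_layer
        ggml_tensor * cache = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_ctx, n_layer);

        // contiguous destination holding only the used columns
        ggml_tensor * out = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_used, n_layer);

        // view of the first n_used columns of every layer; nb2 keeps the full
        // n_ctx stride, so each layer is read from its original offset
        const size_t elt_size = ggml_element_size(cache);
        ggml_tensor * used = ggml_view_3d(ctx, cache,
            n_embd, n_used, n_layer,
            elt_size*n_embd, elt_size*n_embd*n_ctx, 0);

        // one-op graph: strided view -> contiguous tensor
        ggml_cgraph gf{};
        gf.n_threads = 1;
        ggml_build_forward_expand(&gf, ggml_cpy(ctx, used, out));
        ggml_graph_compute(ctx, &gf);

        ggml_free(ctx);
        return 0;
    }

Routing the copy through ggml_cpy lets ggml handle the element type and the strides, rather than hand-rolling a per-layer memcpy loop over the cache buffer.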

llama.h

Lines changed: 3 additions & 2 deletions
@@ -23,7 +23,7 @@
 #define LLAMA_FILE_MAGIC 'ggjt'
 #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
 #define LLAMA_SESSION_MAGIC 'ggsn'
-#define LLAMA_SESSION_VERSION 0
+#define LLAMA_SESSION_VERSION 1
 
 #ifdef __cplusplus
 extern "C" {
@@ -127,7 +127,8 @@ extern "C" {
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
 
-    // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+    // Returns the maximum size in bytes of the state (rng, logits, embedding
+    // and kv_cache) - will often be smaller after compacting tokens
     LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
 
     // Copies the state to the specified destination address.
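
Together with the llama.cpp changes, the updated comment spells out the new contract: llama_get_state_size() is an upper bound, and llama_copy_state_data() returns how many bytes it actually wrote. A hedged usage sketch of that contract (the snapshot_state / restore_state helpers are made up for the example, not part of the API):

    #include "llama.h"

    #include <cstdint>
    #include <vector>

    // Snapshot a context: reserve the maximum, keep only what was written.
    static std::vector<uint8_t> snapshot_state(struct llama_context * ctx) {
        std::vector<uint8_t> buf(llama_get_state_size(ctx));
        const size_t n_written = llama_copy_state_data(ctx, buf.data());
        buf.resize(n_written);  // usually smaller: only the used KV cache is serialized
        return buf;
    }

    // Restore a previously captured snapshot into a compatible context.
    static void restore_state(struct llama_context * ctx, const std::vector<uint8_t> & buf) {
        llama_set_state_data(ctx, buf.data());
    }

llama_save_session_file and llama_load_session_file in the diff follow the same pattern, and the resulting change in the on-disk state layout is presumably why LLAMA_SESSION_VERSION moves from 0 to 1.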
