
Commit a4e5579

llama : only copy used KV cache in get / set state
1 parent f4cef87 commit a4e5579

File tree

2 files changed, +80 -24 lines changed


llama.cpp

Lines changed: 77 additions & 22 deletions
@@ -1269,6 +1269,9 @@ static bool llama_eval_internal(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
+    // update kv token count
+    lctx.model.kv_self.n = n_past + N;
+
     // extract logits
     {
         auto & logits_out = lctx.logits;
@@ -2385,7 +2388,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     ctx->rng.seed(seed);
 }
 
-// Returns the size of the state
+// Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
@@ -2464,21 +2467,50 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
     // copy kv cache
     {
-        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int n_layer = hparams.n_layer;
+        const int n_embd = hparams.n_embd;
+        const int n_ctx = hparams.n_ctx;
+
+        const size_t kv_size = kv_self.buf.size;
         const int kv_ntok = llama_get_kv_cache_token_count(ctx);
 
         memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
 
         if (kv_size) {
-            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+            {
+                // copy k: k layout is n_layer > n_ctx (tokens) > n_embd
+                const uint8_t * k_data = (uint8_t *) kv_self.k->data;
+                const size_t elt_size = ggml_element_size(kv_self.k);
+
+                for (int il = 0; il < n_layer; il++) {
+                    const size_t offset = il * n_ctx * n_embd * elt_size;
+                    const size_t size = kv_ntok * n_embd * elt_size;
+                    memcpy(out, k_data + offset, size); out += size;
+                }
+            }
+
+            {
+                // copy v: v layout is n_layer > n_embd > n_ctx (tokens)
+                const uint8_t * v_data = (uint8_t *) kv_self.v->data;
+                const size_t elt_size = ggml_element_size(kv_self.v);
+                const int n_layer_embd = n_layer * n_embd;
+
+                for (int ile = 0; ile < n_layer_embd; ile++) {
+                    const size_t offset = ile * n_ctx * elt_size;
+                    const size_t size = kv_ntok * elt_size;
+                    memcpy(out, v_data + offset, size); out += size;
+                }
+            }
         }
     }
 
     const size_t written = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);
 
-    LLAMA_ASSERT(written == expected);
+    LLAMA_ASSERT(written <= max_size);
 
     return written;
 }
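
The per-layer copies above depend on the cache layout spelled out in the comments: k stores n_layer blocks of n_ctx x n_embd values, and v stores n_layer * n_embd rows of n_ctx values, so only the first kv_ntok entries of each block or row belong to evaluated tokens. A rough, self-contained sketch of the resulting size arithmetic, using illustrative hyperparameters that are assumptions, not values from this commit:

#include <cstddef>
#include <cstdio>

int main() {
    // Illustrative values only; in llama.cpp they come from hparams and the cache type.
    const size_t n_layer  = 32;
    const size_t n_embd   = 4096;
    const size_t n_ctx    = 512;
    const size_t elt_size = 2;   // e.g. an f16 cache element
    const size_t kv_ntok  = 64;  // tokens evaluated so far

    // Old behaviour serialized the whole kv_self buffer (k and v); its size is roughly this.
    const size_t full_kv = 2 * n_layer * n_ctx * n_embd * elt_size;
    // New behaviour copies only kv_ntok tokens per k block and per v row.
    const size_t used_kv = 2 * n_layer * kv_ntok * n_embd * elt_size;

    printf("full kv buffer: %zu bytes, used portion: %zu bytes\n", full_kv, used_kv);
    return 0;
}

The saving scales with kv_ntok / n_ctx, so a short evaluated prefix in a large context yields a much smaller state blob.
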
@@ -2536,32 +2568,55 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
     // set kv cache
     {
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int n_layer = hparams.n_layer;
+        const int n_embd = hparams.n_embd;
+        const int n_ctx = hparams.n_ctx;
+
         size_t kv_size;
         int kv_ntok;
 
         memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
         memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
 
         if (kv_size) {
-            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
-
-            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+            LLAMA_ASSERT(kv_self.buf.size == kv_size);
 
-            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+            {
+                // set k data: k layout is n_layer > n_ctx (tokens) > n_embd
+                uint8_t * k_data = (uint8_t *) kv_self.k->data;
+                const size_t elt_size = ggml_element_size(kv_self.k);
+
+                for (int il = 0; il < n_layer; il++) {
+                    const size_t offset = il * n_ctx * n_embd * elt_size;
+                    const size_t size = kv_ntok * n_embd * elt_size;
+                    memcpy(k_data + offset, in, size); in += size;
+                }
+            }
 
-            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-            ctx->model.kv_self.v->data = v_data;
+            {
+                // set v data: v layout is n_layer > n_embd > n_ctx (tokens)
+                uint8_t * v_data = (uint8_t *) kv_self.v->data;
+                const size_t elt_size = ggml_element_size(kv_self.v);
+                const int n_layer_embd = n_layer * n_embd;
+
+                for (int ile = 0; ile < n_layer_embd; ile++) {
+                    const size_t offset = ile * n_ctx * elt_size;
+                    const size_t size = kv_ntok * elt_size;
+                    memcpy(v_data + offset, in, size); in += size;
+                }
+            }
 
         }
 
         ctx->model.kv_self.n = kv_ntok;
     }
 
     const size_t nread = in - src;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);
 
-    LLAMA_ASSERT(nread == expected);
+    LLAMA_ASSERT(nread <= max_size);
 
     return nread;
 }
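
With the asserts relaxed to written <= max_size and nread <= max_size, a caller can treat llama_get_state_size() as an upper bound and keep only the bytes actually produced; llama_set_state_data() then reads just the used portion recorded in the blob. A minimal in-memory round trip, assuming ctx is an already-initialized llama_context (the helper names are illustrative, not part of the API):

#include <cstdint>
#include <vector>
#include "llama.h"

std::vector<uint8_t> snapshot_state(struct llama_context * ctx) {
    std::vector<uint8_t> buf(llama_get_state_size(ctx));      // upper bound
    const size_t n = llama_copy_state_data(ctx, buf.data());  // actual size, n <= buf.size()
    buf.resize(n);                                             // keep only the used portion
    return buf;
}

void restore_state(struct llama_context * ctx, const std::vector<uint8_t> & buf) {
    // Accepts any blob no larger than llama_get_state_size(ctx).
    llama_set_state_data(ctx, buf.data());
}
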
@@ -2604,14 +2659,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_exp = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        if (n_state_size_cur != n_state_size_exp) {
-            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+        if (n_state_size_cur > n_state_size_max) {
+            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
             return false;
         }
 
-        std::vector<uint8_t> state_data(n_state_size_cur);
+        std::vector<uint8_t> state_data(n_state_size_max);
         file.read_raw(state_data.data(), n_state_size_cur);
 
         llama_set_state_data(ctx, state_data.data());
@@ -2634,12 +2689,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
 
     // save the context state
     {
-        const size_t n_state_size = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        std::vector<uint8_t> state_data(n_state_size);
-        llama_copy_state_data(ctx, state_data.data());
+        std::vector<uint8_t> state_data(n_state_size_max);
+        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
 
-        file.write_raw(state_data.data(), n_state_size);
+        file.write_raw(state_data.data(), n_state_size_cur);
     }
 
     return true;
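
The session-file changes above follow the same pattern: allocate the maximum size, write only the bytes that llama_copy_state_data() reports, and on load accept any blob no larger than the maximum. A simplified standalone sketch of that pattern using plain file I/O; the real session helpers additionally store a magic/version header and the prompt tokens, which are omitted here, and the function names below are illustrative:

#include <cstdint>
#include <cstdio>
#include <fstream>
#include <vector>
#include "llama.h"

bool save_state_file(struct llama_context * ctx, const char * path) {
    std::vector<uint8_t> state_data(llama_get_state_size(ctx));               // max size
    const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());

    std::ofstream out(path, std::ios::binary);
    out.write((const char *) state_data.data(), n_state_size_cur);            // used portion only
    return (bool) out;
}

bool load_state_file(struct llama_context * ctx, const char * path) {
    std::ifstream in(path, std::ios::binary | std::ios::ate);
    if (!in) {
        return false;
    }
    const size_t n_state_size_cur = (size_t) in.tellg();
    const size_t n_state_size_max = llama_get_state_size(ctx);
    if (n_state_size_cur > n_state_size_max) {
        fprintf(stderr, "state blob too big: max %zu, got %zu\n", n_state_size_max, n_state_size_cur);
        return false;
    }

    std::vector<uint8_t> state_data(n_state_size_max);       // allocate the max,
    in.seekg(0);
    in.read((char *) state_data.data(), n_state_size_cur);   // read only what is on disk
    llama_set_state_data(ctx, state_data.data());
    return true;
}
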

llama.h

Lines changed: 3 additions & 2 deletions
@@ -23,7 +23,7 @@
 #define LLAMA_FILE_MAGIC 'ggjt'
 #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
 #define LLAMA_SESSION_MAGIC 'ggsn'
-#define LLAMA_SESSION_VERSION 0
+#define LLAMA_SESSION_VERSION 1
 
 #ifdef __cplusplus
 extern "C" {
@@ -127,7 +127,8 @@ extern "C" {
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
 
-    // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+    // Returns the maximum size in bytes of the state (rng, logits, embedding
+    // and kv_cache) - will often be smaller after compacting tokens
     LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
 
     // Copies the state to the specified destination address.
