@@ -1269,6 +1269,9 @@ static bool llama_eval_internal(
     // embd_w.resize(n_vocab*N);
     // memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
+    // update kv token count
+    lctx.model.kv_self.n = n_past + N;
+
     // extract logits
     {
         auto & logits_out = lctx.logits;
@@ -2385,7 +2388,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     ctx->rng.seed(seed);
 }
 
-// Returns the size of the state
+// Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
@@ -2464,21 +2467,50 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
 
     // copy kv cache
     {
-        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
+        const size_t kv_size = kv_self.buf.size;
         const int    kv_ntok = llama_get_kv_cache_token_count(ctx);
 
         memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
 
         if (kv_size) {
-            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+            {
+                // copy k: k layout is n_layer > n_ctx (tokens) > n_embd
+                const uint8_t * k_data = (uint8_t *) kv_self.k->data;
+                const size_t elt_size = ggml_element_size(kv_self.k);
+
+                for (int il = 0; il < n_layer; il++) {
+                    const size_t offset = il * n_ctx * n_embd * elt_size;
+                    const size_t size   = kv_ntok * n_embd * elt_size;
+                    memcpy(out, k_data + offset, size); out += size;
+                }
+            }
+
+            {
+                // copy v: v layout is n_layer > n_embd > n_ctx (tokens)
+                const uint8_t * v_data = (uint8_t *) kv_self.v->data;
+                const size_t elt_size = ggml_element_size(kv_self.v);
+                const int n_layer_embd = n_layer * n_embd;
+
+                for (int ile = 0; ile < n_layer_embd; ile++) {
+                    const size_t offset = ile * n_ctx * elt_size;
+                    const size_t size   = kv_ntok * elt_size;
+                    memcpy(out, v_data + offset, size); out += size;
+                }
+            }
         }
     }
 
     const size_t written  = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);
 
-    LLAMA_ASSERT(written == expected);
+    LLAMA_ASSERT(written <= max_size);
 
     return written;
 }
@@ -2536,32 +2568,55 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
     // set kv cache
     {
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
         size_t kv_size;
         int kv_ntok;
 
         memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
         memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
 
         if (kv_size) {
-            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
-
-            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+            LLAMA_ASSERT(kv_self.buf.size == kv_size);
 
-            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+            {
+                // set k data: k layout is n_layer > n_ctx (tokens) > n_embd
+                uint8_t * k_data = (uint8_t *) kv_self.k->data;
+                const size_t elt_size = ggml_element_size(kv_self.k);
+
+                for (int il = 0; il < n_layer; il++) {
+                    const size_t offset = il * n_ctx * n_embd * elt_size;
+                    const size_t size   = kv_ntok * n_embd * elt_size;
+                    memcpy(k_data + offset, in, size); in += size;
+                }
+            }
 
-            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-            ctx->model.kv_self.v->data = v_data;
+            {
+                // set v data: v layout is n_layer > n_embd > n_ctx (tokens)
+                uint8_t * v_data = (uint8_t *) kv_self.v->data;
+                const size_t elt_size = ggml_element_size(kv_self.v);
+                const int n_layer_embd = n_layer * n_embd;
+
+                for (int ile = 0; ile < n_layer_embd; ile++) {
+                    const size_t offset = ile * n_ctx * elt_size;
+                    const size_t size   = kv_ntok * elt_size;
+                    memcpy(v_data + offset, in, size); in += size;
+                }
+            }
         }
 
         ctx->model.kv_self.n = kv_ntok;
     }
 
     const size_t nread    = in - src;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);
 
-    LLAMA_ASSERT(nread == expected);
+    LLAMA_ASSERT(nread <= max_size);
 
     return nread;
 }
@@ -2604,14 +2659,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_exp = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        if (n_state_size_cur != n_state_size_exp) {
-            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+        if (n_state_size_cur > n_state_size_max) {
+            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
             return false;
         }
 
-        std::vector<uint8_t> state_data(n_state_size_cur);
+        std::vector<uint8_t> state_data(n_state_size_max);
         file.read_raw(state_data.data(), n_state_size_cur);
 
         llama_set_state_data(ctx, state_data.data());
@@ -2634,12 +2689,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
 
     // save the context state
     {
-        const size_t n_state_size = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);
 
-        std::vector<uint8_t> state_data(n_state_size);
-        llama_copy_state_data(ctx, state_data.data());
+        std::vector<uint8_t> state_data(n_state_size_max);
+        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());
 
-        file.write_raw(state_data.data(), n_state_size);
+        file.write_raw(state_data.data(), n_state_size_cur);
     }
 
     return true;
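
A minimal caller-side sketch of the contract this diff establishes: llama_get_state_size() now returns an upper bound, and the actual number of bytes serialized is the return value of llama_copy_state_data(). The helper name and the file handling below are illustrative assumptions, not code from the repository; only the three llama_* state functions come from the API shown in the diff.

// sketch: persist context state under the "maximum size" contract
#include <cstdio>
#include <vector>
#include "llama.h"

static bool save_state_to_file(struct llama_context * ctx, const char * path) {
    const size_t max_size = llama_get_state_size(ctx);                    // upper bound only
    std::vector<uint8_t> state_data(max_size);
    const size_t cur_size = llama_copy_state_data(ctx, state_data.data()); // actual bytes written

    FILE * fp = std::fopen(path, "wb");
    if (fp == NULL) {
        return false;
    }
    const bool ok = std::fwrite(state_data.data(), 1, cur_size, fp) == cur_size;
    std::fclose(fp);
    return ok;
}

Restoring mirrors this: read the file into a buffer and pass it to llama_set_state_data(ctx, buf.data()), which likewise reads at most the maximum size.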