@@ -1285,6 +1285,9 @@ static bool llama_eval_internal(
     // embd_w.resize(n_vocab*N);
     // memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

+    // update kv token count
+    lctx.model.kv_self.n = n_past + N;
+
     // extract logits
     {
         auto & logits_out = lctx.logits;
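Note: with kv_self.n updated at the end of each eval, llama_get_kv_cache_token_count() now reports how many cache slots are actually in use, which is what the serialization code further down relies on. A minimal usage sketch; tokens, N, n_past and n_threads are illustrative caller-side variables, not taken from this diff:

// sketch: after a successful eval of N tokens starting at position n_past,
// the reported KV token count tracks the used part of the cache
if (llama_eval(ctx, tokens.data(), N, n_past, n_threads) == 0) {
    const int kv_ntok = llama_get_kv_cache_token_count(ctx); // n_past + N after this change
}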
@@ -2401,7 +2404,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     ctx->rng.seed(seed);
 }

-// Returns the size of the state
+// Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
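Since the serialized size now depends on runtime state (the rng stream length, the number of logits, the number of KV tokens), the return value is a capacity rather than an exact byte count. A minimal sketch of the intended calling pattern, using only the APIs shown in this diff; serialize_state is an illustrative helper, not part of llama.h:

#include <cassert>
#include <cstdint>
#include <vector>
#include "llama.h"

// sketch: snapshot the full context state into a caller-owned buffer;
// llama_get_state_size() gives an upper bound, llama_copy_state_data() returns the exact size
static std::vector<uint8_t> serialize_state(struct llama_context * ctx) {
    std::vector<uint8_t> buf(llama_get_state_size(ctx));             // capacity, not exact size
    const size_t n_written = llama_copy_state_data(ctx, buf.data());
    assert(n_written <= buf.size());                                  // mirrors LLAMA_ASSERT(written <= max_size)
    buf.resize(n_written);                                            // keep only the bytes actually produced
    return buf;
}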
@@ -2480,21 +2483,50 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {

     // copy kv cache
     {
-        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
+        const size_t kv_size = kv_self.buf.size;
         const int    kv_ntok = llama_get_kv_cache_token_count(ctx);

         memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);

         if (kv_size) {
-            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+            {
+                // copy k: k layout is n_layer > n_ctx (tokens) > n_embd
+                const uint8_t * k_data   = (uint8_t *) kv_self.k->data;
+                const size_t    elt_size = ggml_element_size(kv_self.k);
+
+                for (int il = 0; il < n_layer; il++) {
+                    const size_t offset = il * n_ctx * n_embd * elt_size;
+                    const size_t size   = kv_ntok * n_embd * elt_size;
+                    memcpy(out, k_data + offset, size); out += size;
+                }
+            }
+
+            {
+                // copy v: v layout is n_layer > n_embd > n_ctx (tokens)
+                const uint8_t * v_data       = (uint8_t *) kv_self.v->data;
+                const size_t    elt_size     = ggml_element_size(kv_self.v);
+                const int       n_layer_embd = n_layer * n_embd;
+
+                for (int ile = 0; ile < n_layer_embd; ile++) {
+                    const size_t offset = ile * n_ctx * elt_size;
+                    const size_t size   = kv_ntok * elt_size;
+                    memcpy(out, v_data + offset, size); out += size;
+                }
+            }
         }
     }

     const size_t written  = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);

-    LLAMA_ASSERT(written == expected);
+    LLAMA_ASSERT(written <= max_size);

     return written;
 }
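To put the per-layer copies in perspective, a back-of-the-envelope check; the hparams below are assumptions for a 7B-class model with an f16 cache, not values taken from this diff:

// assumed for illustration: n_layer = 32, n_embd = 4096, n_ctx = 2048,
// f16 cache so elt_size = 2, and only kv_ntok = 64 tokens currently cached
const size_t full_k = 32ull * 2048 * 4096 * 2;   // 512 MiB: the whole k tensor (the old code memcpy'd k and v wholesale via buf)
const size_t used_k = 32ull *   64 * 4096 * 2;   //  16 MiB: what the k loop above copies for 64 tokens
// v holds the same number of elements in a transposed layout, so in this scenario the KV
// portion of the serialized state drops from roughly 1 GiB to roughly 32 MiB.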
@@ -2552,32 +2584,55 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {

     // set kv cache
     {
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int    n_layer = hparams.n_layer;
+        const int    n_embd  = hparams.n_embd;
+        const int    n_ctx   = hparams.n_ctx;
+
         size_t kv_size;
         int    kv_ntok;

         memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
         memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);

         if (kv_size) {
-            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
-
-            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+            LLAMA_ASSERT(kv_self.buf.size == kv_size);

-            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+            {
+                // set k data: k layout is n_layer > n_ctx (tokens) > n_embd
+                uint8_t * k_data = (uint8_t *) kv_self.k->data;
+                const size_t elt_size = ggml_element_size(kv_self.k);
+
+                for (int il = 0; il < n_layer; il++) {
+                    const size_t offset = il * n_ctx * n_embd * elt_size;
+                    const size_t size   = kv_ntok * n_embd * elt_size;
+                    memcpy(k_data + offset, in, size); in += size;
+                }
+            }

-            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-            ctx->model.kv_self.v->data = v_data;
+            {
+                // set v data: v layout is n_layer > n_embd > n_ctx (tokens)
+                uint8_t * v_data = (uint8_t *) kv_self.v->data;
+                const size_t elt_size = ggml_element_size(kv_self.v);
+                const int n_layer_embd = n_layer * n_embd;
+
+                for (int ile = 0; ile < n_layer_embd; ile++) {
+                    const size_t offset = ile * n_ctx * elt_size;
+                    const size_t size   = kv_ntok * elt_size;
+                    memcpy(v_data + offset, in, size); in += size;
+                }
+            }

         }

         ctx->model.kv_self.n = kv_ntok;
     }

     const size_t nread    = in - src;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);

-    LLAMA_ASSERT(nread == expected);
+    LLAMA_ASSERT(nread <= max_size);

     return nread;
 }
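The setter walks exactly the same k/v layout as llama_copy_state_data, so a state blob taken from one context can be replayed into another context created from the same model with the same context size (otherwise the buf.size assert fires). A minimal round-trip sketch; ctx_src, ctx_dst and the serialize_state helper sketched earlier are illustrative:

// sketch: snapshot ctx_src and replay it into ctx_dst (same model file, same n_ctx)
std::vector<uint8_t> state = serialize_state(ctx_src);             // max-size buffer shrunk to the bytes written
const size_t n_read = llama_set_state_data(ctx_dst, state.data());
assert(n_read == state.size());  // expected to match, since the setter reads the same fields the copy wrote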
@@ -2620,14 +2675,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_exp = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);

-        if (n_state_size_cur != n_state_size_exp) {
-            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+        if (n_state_size_cur > n_state_size_max) {
+            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
             return false;
         }

-        std::vector<uint8_t> state_data(n_state_size_cur);
+        std::vector<uint8_t> state_data(n_state_size_max);
         file.read_raw(state_data.data(), n_state_size_cur);

         llama_set_state_data(ctx, state_data.data());
@@ -2650,12 +2705,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi

     // save the context state
     {
-        const size_t n_state_size = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);

-        std::vector<uint8_t> state_data(n_state_size);
-        llama_copy_state_data(ctx, state_data.data());
+        std::vector<uint8_t> state_data(n_state_size_max);
+        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());

-        file.write_raw(state_data.data(), n_state_size);
+        file.write_raw(state_data.data(), n_state_size_cur);
     }

     return true;