@@ -1285,6 +1285,9 @@ static bool llama_eval_internal(
1285
1285
// embd_w.resize(n_vocab*N);
1286
1286
// memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
1287
1287
1288
+ // update kv token count
1289
+ lctx.model .kv_self .n = n_past + N;
1290
+
1288
1291
// extract logits
1289
1292
{
1290
1293
auto & logits_out = lctx.logits ;
@@ -2401,7 +2404,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
2401
2404
ctx->rng .seed (seed);
2402
2405
}
2403
2406
2404
- // Returns the size of the state
2407
+ // Returns the *maximum* size of the state
2405
2408
size_t llama_get_state_size (const struct llama_context * ctx) {
2406
2409
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
2407
2410
// for reference, std::mt19937(1337) serializes to 6701 bytes.
@@ -2480,21 +2483,51 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
2480
2483
2481
2484
// copy kv cache
2482
2485
{
2483
- const size_t kv_size = ctx->model .kv_self .buf .size ;
2486
+ const auto & kv_self = ctx->model .kv_self ;
2487
+ const auto & hparams = ctx->model .hparams ;
2488
+ const int n_layer = hparams.n_layer ;
2489
+ const int n_embd = hparams.n_embd ;
2490
+ const int n_ctx = hparams.n_ctx ;
2491
+
2492
+ const size_t kv_size = kv_self.buf .size ;
2484
2493
const int kv_ntok = llama_get_kv_cache_token_count (ctx);
2485
2494
2486
2495
memcpy (out, &kv_size, sizeof (kv_size)); out += sizeof (kv_size);
2487
2496
memcpy (out, &kv_ntok, sizeof (kv_ntok)); out += sizeof (kv_ntok);
2488
2497
2489
2498
if (kv_size) {
2490
- memcpy (out, ctx->model .kv_self .buf .addr , kv_size); out += kv_size;
2499
+ const size_t elt_size = ggml_element_size (kv_self.k );
2500
+ char buffer[4096 ];
2501
+ ggml_context * cpy_ctx = ggml_init ({ sizeof (buffer), buffer, /* no_alloc */ true });
2502
+ ggml_cgraph gf{};
2503
+ gf.n_threads = 1 ;
2504
+
2505
+ ggml_tensor * kout3d = ggml_new_tensor_3d (cpy_ctx, kv_self.k ->type , n_embd, kv_ntok, n_layer);
2506
+ kout3d->data = out;
2507
+ out += ggml_nbytes (kout3d);
2508
+
2509
+ ggml_tensor * vout3d = ggml_new_tensor_3d (cpy_ctx, kv_self.v ->type , kv_ntok, n_embd, n_layer);
2510
+ vout3d->data = out;
2511
+ out += ggml_nbytes (vout3d);
2512
+
2513
+ ggml_tensor * k3d = ggml_view_3d (cpy_ctx, kv_self.k ,
2514
+ n_embd, kv_ntok, n_layer,
2515
+ elt_size*n_embd, elt_size*n_embd*n_ctx, 0 );
2516
+
2517
+ ggml_tensor * v3d = ggml_view_3d (cpy_ctx, kv_self.v ,
2518
+ kv_ntok, n_embd, n_layer,
2519
+ elt_size*n_ctx, elt_size*n_ctx*n_embd, 0 );
2520
+
2521
+ ggml_build_forward_expand (&gf, ggml_cpy (cpy_ctx, k3d, kout3d));
2522
+ ggml_build_forward_expand (&gf, ggml_cpy (cpy_ctx, v3d, vout3d));
2523
+ ggml_graph_compute (cpy_ctx, &gf);
2491
2524
}
2492
2525
}
2493
2526
2494
2527
const size_t written = out - dest;
2495
- const size_t expected = llama_get_state_size (ctx);
2528
+ const size_t max_size = llama_get_state_size (ctx);
2496
2529
2497
- LLAMA_ASSERT (written == expected );
2530
+ LLAMA_ASSERT (written <= max_size );
2498
2531
2499
2532
return written;
2500
2533
}
@@ -2552,32 +2585,55 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
2552
2585
2553
2586
// set kv cache
2554
2587
{
2588
+ const auto & kv_self = ctx->model .kv_self ;
2589
+ const auto & hparams = ctx->model .hparams ;
2590
+ const int n_layer = hparams.n_layer ;
2591
+ const int n_embd = hparams.n_embd ;
2592
+ const int n_ctx = hparams.n_ctx ;
2593
+
2555
2594
size_t kv_size;
2556
2595
int kv_ntok;
2557
2596
2558
2597
memcpy (&kv_size, in, sizeof (kv_size)); in += sizeof (kv_size);
2559
2598
memcpy (&kv_ntok, in, sizeof (kv_ntok)); in += sizeof (kv_ntok);
2560
2599
2561
2600
if (kv_size) {
2562
- LLAMA_ASSERT (ctx->model .kv_self .buf .size == kv_size);
2601
+ LLAMA_ASSERT (kv_self.buf .size == kv_size);
2602
+
2603
+ const size_t elt_size = ggml_element_size (kv_self.k );
2604
+ char buffer[4096 ];
2605
+ ggml_context * cpy_ctx = ggml_init ({ sizeof (buffer), buffer, /* no_alloc */ true });
2606
+ ggml_cgraph gf{};
2607
+ gf.n_threads = 1 ;
2608
+
2609
+ ggml_tensor * kin3d = ggml_new_tensor_3d (cpy_ctx, kv_self.k ->type , n_embd, kv_ntok, n_layer);
2610
+ kin3d->data = (void *) in;
2611
+ in += ggml_nbytes (kin3d);
2563
2612
2564
- void * k_data = ctx->model .kv_self .k ->data ; // remember data pointers
2565
- void * v_data = ctx->model .kv_self .v ->data ; // because their value is stored in buf and overwritten by memcpy
2613
+ ggml_tensor * vin3d = ggml_new_tensor_3d (cpy_ctx, kv_self.v ->type , kv_ntok, n_embd, n_layer);
2614
+ vin3d->data = (void *) in;
2615
+ in += ggml_nbytes (vin3d);
2566
2616
2567
- memcpy (ctx->model .kv_self .buf .addr , in, kv_size); in += kv_size;
2617
+ ggml_tensor * k3d = ggml_view_3d (cpy_ctx, kv_self.k ,
2618
+ n_embd, kv_ntok, n_layer,
2619
+ elt_size*n_embd, elt_size*n_embd*n_ctx, 0 );
2568
2620
2569
- ctx->model .kv_self .k ->data = k_data; // restore correct data pointers
2570
- ctx->model .kv_self .v ->data = v_data;
2621
+ ggml_tensor * v3d = ggml_view_3d (cpy_ctx, kv_self.v ,
2622
+ kv_ntok, n_embd, n_layer,
2623
+ elt_size*n_ctx, elt_size*n_ctx*n_embd, 0 );
2571
2624
2625
+ ggml_build_forward_expand (&gf, ggml_cpy (cpy_ctx, kin3d, k3d));
2626
+ ggml_build_forward_expand (&gf, ggml_cpy (cpy_ctx, vin3d, v3d));
2627
+ ggml_graph_compute (cpy_ctx, &gf);
2572
2628
}
2573
2629
2574
2630
ctx->model .kv_self .n = kv_ntok;
2575
2631
}
2576
2632
2577
2633
const size_t nread = in - src;
2578
- const size_t expected = llama_get_state_size (ctx);
2634
+ const size_t max_size = llama_get_state_size (ctx);
2579
2635
2580
- LLAMA_ASSERT (nread == expected );
2636
+ LLAMA_ASSERT (nread <= max_size );
2581
2637
2582
2638
return nread;
2583
2639
}
@@ -2620,14 +2676,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
2620
2676
// restore the context state
2621
2677
{
2622
2678
const size_t n_state_size_cur = file.size - file.tell ();
2623
- const size_t n_state_size_exp = llama_get_state_size (ctx);
2679
+ const size_t n_state_size_max = llama_get_state_size (ctx);
2624
2680
2625
- if (n_state_size_cur != n_state_size_exp ) {
2626
- fprintf (stderr, " %s : the state size in session file didn't match! expected %zu, got %zu\n " , __func__, n_state_size_exp , n_state_size_cur);
2681
+ if (n_state_size_cur > n_state_size_max ) {
2682
+ fprintf (stderr, " %s : the state size in session file is too big! max %zu, got %zu\n " , __func__, n_state_size_max , n_state_size_cur);
2627
2683
return false ;
2628
2684
}
2629
2685
2630
- std::vector<uint8_t > state_data (n_state_size_cur );
2686
+ std::vector<uint8_t > state_data (n_state_size_max );
2631
2687
file.read_raw (state_data.data (), n_state_size_cur);
2632
2688
2633
2689
llama_set_state_data (ctx, state_data.data ());
@@ -2650,12 +2706,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
2650
2706
2651
2707
// save the context state
2652
2708
{
2653
- const size_t n_state_size = llama_get_state_size (ctx);
2709
+ const size_t n_state_size_max = llama_get_state_size (ctx);
2654
2710
2655
- std::vector<uint8_t > state_data (n_state_size );
2656
- llama_copy_state_data (ctx, state_data.data ());
2711
+ std::vector<uint8_t > state_data (n_state_size_max );
2712
+ const size_t n_state_size_cur = llama_copy_state_data (ctx, state_data.data ());
2657
2713
2658
- file.write_raw (state_data.data (), n_state_size );
2714
+ file.write_raw (state_data.data (), n_state_size_cur );
2659
2715
}
2660
2716
2661
2717
return true ;
0 commit comments