@@ -179,24 +179,37 @@ llama_context::llama_context(
     // init the memory module
     // TODO: for now, always create a unified KV cache
     if (!hparams.vocab_only) {
-        kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
+        uint32_t kv_size = 0;
+        ggml_type type_k = params.type_k;
+        ggml_type type_v = params.type_v;

-        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+        if (!llama_model_is_recurrent(&model)) {
+            // kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
+            auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory());

-        cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
+            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

-        LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+            cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams));

-        uint32_t kv_size = cparams.n_ctx;
-        ggml_type type_k = params.type_k;
-        ggml_type type_v = params.type_v;
+            LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+            kv_size = cparams.n_ctx;
+            type_k = params.type_k;
+            type_v = params.type_v;
+
+            kv_self.reset(kv);
+        } else {
+            auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory());
+
+            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

-        if (llama_model_is_recurrent(&model)) {
             // Mamba needs at least as many KV cells as there are sequences kept at any time
             kv_size = std::max((uint32_t) 1, params.n_seq_max);
             // it's probably best to keep as much precision as possible for the states
             type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
             type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+
+            kv_self.reset(kv);
         }

         GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
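The constructor now keeps `kv_size`/`type_k`/`type_v` as locals and branches on `llama_model_is_recurrent()`, downcasting the pointer returned by `model.create_memory()` to the concrete cache type before storing it in `kv_self`. A minimal sketch of that dispatch pattern, with simplified stand-in types (the real cache classes live in the llama.cpp KV-cache headers and carry far more state):

```cpp
#include <cstdint>
#include <memory>

// Simplified stand-ins for the cache hierarchy implied by this hunk.
struct llama_kv_cache_sketch {
    virtual ~llama_kv_cache_sketch() = default;
};

struct unified_cache_sketch : llama_kv_cache_sketch {
    uint32_t get_padding() const { return 32; }  // illustrative value
};

struct recurrent_cache_sketch : llama_kv_cache_sketch {
    // recurrent (e.g. Mamba) state cache: no padding concept
};

// Stand-in for model.create_memory(): returns the base pointer, and the caller
// knows which concrete type it received from the model architecture.
static llama_kv_cache_sketch * create_memory(bool recurrent) {
    if (recurrent) {
        return new recurrent_cache_sketch();
    }
    return new unified_cache_sketch();
}

int main() {
    const bool recurrent = false;  // would be true for Mamba-style models
    std::unique_ptr<llama_kv_cache_sketch> kv_self;

    if (!recurrent) {
        // downcast to reach unified-only members (mirrors the static_cast in the hunk)
        auto * kv = static_cast<unified_cache_sketch *>(create_memory(false));
        const uint32_t pad = kv->get_padding();
        (void) pad;
        kv_self.reset(kv);
    } else {
        kv_self.reset(static_cast<recurrent_cache_sketch *>(create_memory(true)));
    }
}
```

The downcast is only safe because the same `llama_model_is_recurrent()` predicate decides both which concrete cache `create_memory()` builds and which branch the constructor takes.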
@@ -305,7 +318,7 @@ llama_context::llama_context(
        int n_nodes_tg = -1;

        // simulate full KV cache
-       kv_self->n = kv_self->size;
+       kv_self->set_full();

        cross.v_embd.clear();

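`set_full()` replaces the direct `kv_self->n = kv_self->size` write used when reserving the worst-case graphs. A hedged sketch of what such a method plausibly encapsulates, assuming the unified cache still exposes the same `n`/`size` members shown elsewhere in this diff (the real method may do more):

```cpp
#include <cstdint>

// Illustrative only: member names follow the diff, everything else is assumed.
struct unified_cache_sketch {
    uint32_t size = 4096;  // total number of KV cells
    uint32_t n    = 0;     // number of cells the attention graph will span

    void set_full() {
        // pretend every cell is occupied so graph reservation sees the worst case
        n = size;
    }
};

int main() {
    unified_cache_sketch kv;
    kv.set_full();  // kv.n == 4096
}
```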
@@ -557,7 +570,9 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(

    // GGML_ASSERT(kv_self->size == n_ctx);

-    auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
+    const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
+
+    auto inp = std::make_unique<llm_graph_input_k_shift>(kv);

     inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
     ggml_set_input(inp->k_shift);
@@ -573,16 +588,16 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
         const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;

-        ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
+        ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il);

         ggml_tensor * k =
-            ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_head_kv, kv_self->size,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+            ggml_view_3d(ctx0, kv->k_l[il],
+                n_embd_head_k, n_head_kv, kv->size,
+                ggml_row_size(kv->k_l[il]->type, n_embd_head_k),
+                ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
                 0);

-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
+        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv->k_l[il]->buffer);

         ggml_build_forward_expand(gf, cur);
     }
@@ -597,9 +612,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

+    auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
     const auto & hparams = model.hparams;

-    const auto & ids = kv_self->defrag_info.ids;
+    const auto & ids = kv->defrag_info.ids;

 #if 0
     // CPU defrag
@@ -689,40 +706,40 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

-       ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
+       ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il],
                n_embd_k_gqa, nm,
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
+               ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
+               ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i));

-       ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
+       ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il],
                n_embd_k_gqa, nm,
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
+               ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
+               ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id));

        ggml_tensor * view_v_src;
        ggml_tensor * view_v_dst;

        if (cparams.flash_attn) {
            // NOTE: the V cache is not transposed when using flash attention
-           view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+           view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
                    n_embd_v_gqa, nm,
-                   ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                   ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
+                   ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
+                   ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i));

-           view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+           view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
                    n_embd_v_gqa, nm,
-                   ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                   ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
+                   ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
+                   ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id));
        } else {
-           view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+           view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
                    nm, n_embd_v_gqa,
-                   ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                   ggml_row_size(kv_self->v_l[il]->type, i));
+                   ggml_row_size(kv->v_l[il]->type, kv->size),
+                   ggml_row_size(kv->v_l[il]->type, i));

-           view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+           view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
                    nm, n_embd_v_gqa,
-                   ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                   ggml_row_size(kv_self->v_l[il]->type, id));
+                   ggml_row_size(kv->v_l[il]->type, kv->size),
+                   ggml_row_size(kv->v_l[il]->type, id));
        }

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
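Each defrag copy operates on a 2D view of a layer's K or V buffer; the last two arguments of `ggml_view_2d` are the row stride in bytes and the byte offset of the first row being moved, both computed with `ggml_row_size`. A plain-C++ sketch of that offset arithmetic under assumed sizes (no ggml dependency; `row_bytes` stands in for what `ggml_row_size` would return for a non-quantized type):

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // Illustrative numbers: 128-wide K rows stored as 2-byte elements (e.g. F16).
    const size_t n_embd_k_gqa = 128;
    const size_t elem_size    = 2;
    const size_t row_bytes    = n_embd_k_gqa * elem_size;

    const size_t i  = 700;  // index of the first source cell
    const size_t id = 40;   // index of the first destination cell
    const size_t nm = 8;    // number of consecutive cells being moved

    // The source and destination views each cover nm rows of row_bytes bytes,
    // starting at these byte offsets into the layer's K buffer.
    printf("src offset = %zu, dst offset = %zu, bytes moved = %zu\n",
           i * row_bytes, id * row_bytes, nm * row_bytes);
    return 0;
}
```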
@@ -739,13 +756,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
 }

 void llama_context::kv_self_update() {
-    auto & kv = kv_self;
-
     bool need_reserve = false;

-    if (kv->has_shift) {
-        if (!kv->get_can_shift()) {
-            GGML_ABORT("The current context does not support K-shift");
+    if (kv_self->get_has_shift()) {
+        if (!kv_self->get_can_shift()) {
+            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }

         LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
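With the local `kv` alias gone, the shift path asks the cache itself whether a shift is pending and whether it is supported, which lets a recurrent cache refuse cleanly. A small sketch of that guard, assuming `get_has_shift()`/`get_can_shift()` are virtuals on the cache interface (the names come from the diff; the bodies here are invented):

```cpp
#include <cstdio>
#include <cstdlib>

// Assumed interface shape; only the method names are taken from the diff.
struct kv_cache_iface {
    virtual ~kv_cache_iface() = default;
    virtual bool get_has_shift() const = 0;
    virtual bool get_can_shift() const = 0;
};

struct recurrent_cache_sketch : kv_cache_iface {
    bool get_has_shift() const override { return true; }   // pretend a shift was requested
    bool get_can_shift() const override { return false; }  // recurrent states cannot be K-shifted
};

static void kv_self_update_sketch(const kv_cache_iface & kv) {
    if (kv.get_has_shift()) {
        if (!kv.get_can_shift()) {
            // stands in for GGML_ABORT in the real code
            fprintf(stderr, "The current KV cache / model configuration does not support K-shift\n");
            abort();
        }
        // ... apply the K-shift ...
    }
}

int main() {
    recurrent_cache_sketch kv;
    kv_self_update_sketch(kv);  // aborts: the recurrent cache reports it cannot shift
}
```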
@@ -768,6 +783,8 @@ void llama_context::kv_self_update() {
         }

         {
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
             kv->has_shift = false;

             for (uint32_t i = 0; i < kv->size; ++i) {
@@ -777,9 +794,11 @@ void llama_context::kv_self_update() {
     }

     // defragment the KV cache if needed
-    if (kv->do_defrag) {
+    if (kv_self->get_do_defrag()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
         if (kv->defrag_prepare(graph_max_nodes())) {
             ggml_backend_sched_reset(sched.get());

@@ -808,7 +827,7 @@ void llama_context::kv_self_update() {
        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        // simulate full KV cache
-       kv_self->n = kv_self->size;
+       kv_self->set_full();

        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -1028,8 +1047,8 @@ int llama_context::encode(llama_batch & inp_batch) {
     }

     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);

     const llama_batch & batch = batch_allocr.batch;
     const int32_t n_tokens = batch.n_tokens;
@@ -1193,8 +1212,8 @@ int llama_context::decode(llama_batch & inp_batch) {
     }

     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);

     const llama_batch & batch = batch_allocr.batch;

@@ -1249,8 +1268,10 @@ int llama_context::decode(llama_batch & inp_batch) {

     const bool logits_all = n_outputs_all == n_tokens_all;

+    const bool is_recurrent = llama_model_is_recurrent(&model);
+
     sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self->recurrent,
+            /* simple_split */ !is_recurrent,
             /* logits_all   */ logits_all);

     // reserve output buffer
@@ -1269,7 +1290,7 @@ int llama_context::decode(llama_batch & inp_batch) {

         const auto & n_ubatch = cparams.n_ubatch;

-        if (kv_self->recurrent) {
+        if (is_recurrent) {
             if (embd_pooled) {
                 // Pooled embeddings cannot be split across ubatches (yet)
                 ubatch = sbatch.split_seq(cparams.n_ubatch);
@@ -1307,17 +1328,19 @@ int llama_context::decode(llama_batch & inp_batch) {
             return 1;
         }

-        if (!kv_self->recurrent) {
+        if (!is_recurrent) {
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
             // if we start defragmenting the cache, the benefit from this will be more important
-            const uint32_t pad = kv_self->get_padding(cparams);
-            kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
+            const uint32_t pad = kv->get_padding(cparams);
+            kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad)));
+
+            // printf("kv.n = %5d, kv.used = %5d, kv.head = %5d\n", kv->n, kv->used, kv->head);
         }
     }

-    // printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
-
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

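The heuristic keeps `kv->n`, the number of cells the attention graph actually spans, at the smallest padding-aligned value that still covers every used cell. A worked numeric example under assumed values (cache size, padding, and `cell_max()` are made up; `GGML_PAD` is the ggml rounding macro, which requires a power-of-two pad):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// ggml's rounding macro: rounds x up to a multiple of n (n must be a power of two).
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    // Illustrative values only.
    const uint32_t size     = 4096;  // total KV cells
    const uint32_t pad      = 256;   // padding granularity reported by get_padding()
    const uint32_t cell_max = 70;    // highest occupied cell, as cell_max() would report

    const uint32_t n = std::min(size, std::max(pad, (uint32_t) GGML_PAD(cell_max, pad)));

    // Prints n = 256: attention covers 256 cells instead of the full 4096.
    printf("n = %u\n", n);
}
```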
@@ -1457,10 +1480,12 @@ int llama_context::decode(llama_batch & inp_batch) {
     // synchronize();

     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+    if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
         // - do not defrag small contexts (i.e. < 2048 tokens)
         // - count the padding towards the number of used tokens
-        const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
+        const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f;

         // queue defragmentation for next llama_kv_cache_update
         if (fragmentation > cparams.defrag_thold) {
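The recurrent case is now excluded from the defrag check entirely, and the unified cache is downcast only inside this block. For intuition about the trigger itself: fragmentation is measured as `1 - (used + padding) / n`, and only once the window spans at least 2048 cells. A quick numeric sketch with made-up values:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // Illustrative values only.
    const float n    = 4096.0f;  // cells currently spanned by the attention window
    const float used = 2048.0f;  // cells actually holding tokens
    const float pad  = 256.0f;   // padding counted towards the used tokens

    const float fragmentation = n >= 2048.0f ? std::max(0.0f, 1.0f - (used + pad)/n) : 0.0f;

    // Prints 0.4375; with a defrag_thold of, say, 0.1 this would queue a defrag
    // on the next llama_kv_cache_update.
    printf("fragmentation = %.4f\n", fragmentation);
}
```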