@@ -177,65 +177,35 @@ llama_context::llama_context(
     }

     // init the memory module
-    // TODO: for now, always create a unified KV cache
     if (!hparams.vocab_only) {
-        uint32_t kv_size = 0;
-        ggml_type type_k = params.type_k;
-        ggml_type type_v = params.type_v;
+        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

         if (!llama_model_is_recurrent(&model)) {
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
             cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));

             LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

-            kv_size = cparams.n_ctx;
-            type_k = params.type_k;
-            type_v = params.type_v;
-
             llama_memory_params params_mem = {
-                /*.type_k      =*/ type_k,
-                /*.type_v      =*/ type_v,
+                /*.type_k      =*/ params.type_k,
+                /*.type_v      =*/ params.type_v,
                 /*.v_trans     =*/ !cparams.flash_attn,
                 /*.offload_kqv =*/ cparams.offload_kqv,
-                /*.kv_size     =*/ kv_size,
+                /*.kv_size     =*/ cparams.n_ctx,
             };

-            auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory(params_mem));
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
         } else {
-            // Mamba needs at least as many KV cells as there are sequences kept at any time
-            kv_size = std::max((uint32_t) 1, params.n_seq_max);
-            // it's probably best to keep as much precision as possible for the states
-            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-
             llama_memory_params params_mem = {
-                /*.type_k      =*/ type_k,
-                /*.type_v      =*/ type_v,
+                /*.type_k      =*/ GGML_TYPE_F32, // required by ggml_ssm_conv for Mamba's conv_states
+                /*.type_v      =*/ GGML_TYPE_F32, // required by ggml_ssm_scan for Mamba's ssm_states
                 /*.v_trans     =*/ false, // unused
-                /*.offload_kqv =*/ params.offload_kqv,
-                /*.kv_size     =*/ kv_size,
+                /*.offload_kqv =*/ cparams.offload_kqv,
+                /*.kv_size     =*/ std::max((uint32_t) 1, params.n_seq_max), // Mamba needs at least as many KV cells as there are sequences kept at any time
             };

-            auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory(params_mem));
-
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
         }

-        {
-            const size_t memory_size_k = kv_self->size_k_bytes();
-            const size_t memory_size_v = kv_self->size_v_bytes();
-
-            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
     }

     // init backends
@@ -326,6 +296,8 @@ llama_context::llama_context(
         int n_nodes_tg = -1;

         // simulate full KV cache
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
         kv_self->set_full();

         cross.v_embd.clear();
@@ -477,11 +449,13 @@ uint32_t llama_context::n_threads_batch() const {
 }

 llama_kv_cache * llama_context::get_kv_self() {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }

 const llama_kv_cache * llama_context::get_kv_self() const {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }

 ggml_tensor * llama_context::build_rope_shift(
@@ -578,7 +552,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(

     // GGML_ASSERT(kv_self->size == n_ctx);

-    const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
+    const auto * kv = static_cast<const llama_kv_cache_unified *>(memory.get());

     auto inp = std::make_unique<llm_graph_input_k_shift>(kv);

@@ -620,7 +594,7 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
         ggml_cgraph * gf) const {
     auto res = std::make_unique<llm_graph_result>();

-    auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+    auto * kv = static_cast<llama_kv_cache_unified *>(memory.get());

     const auto & hparams = model.hparams;

@@ -766,6 +740,8 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
 void llama_context::kv_self_update() {
     bool need_reserve = false;

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     if (kv_self->get_has_shift()) {
         if (!kv_self->get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -791,7 +767,7 @@ void llama_context::kv_self_update() {
         }

         {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

             kv->has_shift = false;

@@ -805,7 +781,7 @@ void llama_context::kv_self_update() {
     if (kv_self->get_do_defrag()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

         if (kv->defrag_prepare(graph_max_nodes())) {
             ggml_backend_sched_reset(sched.get());
@@ -1054,6 +1030,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         return -1;
     }

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1219,6 +1197,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     // temporary allocate memory for the input batch if needed
     // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1233,7 +1213,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;

-    llama_kv_cache_guard kv_guard(kv_self.get());
+    llama_kv_cache_guard kv_guard(kv_self);

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

@@ -1337,7 +1317,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }

         if (!is_recurrent) {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

             // a heuristic, to avoid attending the full cache if it is not yet utilized
             // after enough generations, the benefit from this heuristic disappears
@@ -1489,7 +1469,7 @@ int llama_context::decode(llama_batch & inp_batch) {

     // decide if we need to defrag the kv cache
     if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

         // - do not defrag small contexts (i.e. < 2048 tokens)
         // - count the padding towards the number of used tokens
@@ -1662,7 +1642,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
-            /*.memory      =*/ kv_self.get(),
+            /*.memory      =*/ memory.get(),
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
@@ -2121,6 +2101,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     }

     LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_write(io);

     return io.n_bytes();
@@ -2205,6 +2187,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     }

     LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_read(io);

     return io.n_bytes();
@@ -2213,6 +2197,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_write(io, seq_id);

     return io.n_bytes();
@@ -2221,6 +2207,8 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_read(io, seq_id);

     return io.n_bytes();
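For reference, every hunk above relies on the same pattern: the concrete KV-cache view is derived on demand from the generic memory module with a static_cast. A minimal sketch of that pattern, assuming the memory member is a std::unique_ptr<llama_memory_i> holding the object returned by model.create_memory() (the interface name and the helper kv_self_from_memory are assumptions for illustration, not part of this diff):

    // sketch only: centralizes the cast used throughout the hunks above
    static llama_kv_cache * kv_self_from_memory(llama_memory_i * mem) {
        // for non-vocab-only models, create_memory() produces either a
        // llama_kv_cache_unified or a llama_kv_cache_recurrent, both of which
        // are llama_kv_cache implementations, so the downcast is valid here
        return static_cast<llama_kv_cache *>(mem);
    }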