@@ -177,65 +177,35 @@ llama_context::llama_context(
    }

    // init the memory module
-    // TODO: for now, always create a unified KV cache
    if (!hparams.vocab_only) {
-        uint32_t kv_size = 0;
-        ggml_type type_k = params.type_k;
-        ggml_type type_v = params.type_v;
+        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

        if (!llama_model_is_recurrent(&model)) {
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
            cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));

            LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

-            kv_size = cparams.n_ctx;
-            type_k = params.type_k;
-            type_v = params.type_v;
-
            llama_memory_params params_mem = {
-                /*.type_k      =*/ type_k,
-                /*.type_v      =*/ type_v,
+                /*.type_k      =*/ params.type_k,
+                /*.type_v      =*/ params.type_v,
                /*.v_trans     =*/ !cparams.flash_attn,
                /*.offload_kqv =*/ cparams.offload_kqv,
-                /*.kv_size     =*/ kv_size,
+                /*.kv_size     =*/ cparams.n_ctx,
            };

-            auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory(params_mem));
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
        } else {
-            // Mamba needs at least as many KV cells as there are sequences kept at any time
-            kv_size = std::max((uint32_t) 1, params.n_seq_max);
-            // it's probably best to keep as much precision as possible for the states
-            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-
            llama_memory_params params_mem = {
-                /*.type_k      =*/ type_k,
-                /*.type_v      =*/ type_v,
+                /*.type_k      =*/ GGML_TYPE_F32, // required by ggml_ssm_conv for Mamba's conv_states
+                /*.type_v      =*/ GGML_TYPE_F32, // required by ggml_ssm_scan for Mamba's ssm_states
                /*.v_trans     =*/ false, // unused
-                /*.offload_kqv =*/ params.offload_kqv,
-                /*.kv_size     =*/ kv_size,
+                /*.offload_kqv =*/ cparams.offload_kqv,
+                /*.kv_size     =*/ std::max((uint32_t) 1, params.n_seq_max), // Mamba needs at least as many KV cells as there are sequences kept at any time
            };

-            auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory(params_mem));
-
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
-            kv_self.reset(kv);
+            memory.reset(model.create_memory(params_mem));
        }

-        {
-            const size_t memory_size_k = kv_self->size_k_bytes();
-            const size_t memory_size_v = kv_self->size_v_bytes();
-
-            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
    }

    // init backends
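
Taken together, this first hunk swaps the context's concrete kv_self member for a single memory handle that owns whichever cache model.create_memory() builds. Below is a minimal C++ sketch of that ownership pattern; llama_memory_i, context_sketch, and init_memory are illustrative names assumed here, not the project's real declarations.

// Sketch only: mirrors the ownership change in the hunk above, not the real API.
#include <memory>

struct llama_memory_i {
    virtual ~llama_memory_i() = default;
};

struct llama_kv_cache_unified   : llama_memory_i { /* per-token attention KV cells */ };
struct llama_kv_cache_recurrent : llama_memory_i { /* per-sequence conv/ssm states */ };

struct context_sketch {
    // before the change: a typed kv_self smart pointer; after: one polymorphic handle
    std::unique_ptr<llama_memory_i> memory;

    void init_memory(bool recurrent) {
        if (!recurrent) {
            memory.reset(new llama_kv_cache_unified());   // attention models
        } else {
            memory.reset(new llama_kv_cache_recurrent()); // Mamba-style models
        }
    }
};
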
@@ -326,6 +296,8 @@ llama_context::llama_context(
        int n_nodes_tg = -1;

        // simulate full KV cache
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
        kv_self->set_full();

        cross.v_embd.clear();
@@ -477,11 +449,13 @@ uint32_t llama_context::n_threads_batch() const {
}

llama_kv_cache * llama_context::get_kv_self() {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
}

const llama_kv_cache * llama_context::get_kv_self() const {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
}

ggml_tensor * llama_context::build_rope_shift(
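
Since memory now holds only the abstract type, the accessors above recover the KV-cache view with a static_cast. The sketch below shows just that cast shape, assuming llama_kv_cache derives from an abstract memory interface; the class bodies and the simplified const overload are placeholders, not the real headers.

// Sketch: typed accessors over the generic memory handle, as in get_kv_self().
#include <memory>

struct llama_memory_i { virtual ~llama_memory_i() = default; };
struct llama_kv_cache : llama_memory_i { /* KV-cache operations */ };

struct context_sketch {
    std::unique_ptr<llama_memory_i> memory;

    // valid only because the constructor always stores a llama_kv_cache here
    llama_kv_cache * get_kv_self() {
        return static_cast<llama_kv_cache *>(memory.get());
    }

    const llama_kv_cache * get_kv_self() const {
        return static_cast<const llama_kv_cache *>(memory.get());
    }
};
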
@@ -567,7 +541,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(

    // GGML_ASSERT(kv_self->size == n_ctx);

-    const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
+    const auto * kv = static_cast<const llama_kv_cache_unified *>(memory.get());

    auto inp = std::make_unique<llm_graph_input_k_shift>(kv);
@@ -609,7 +583,7 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
        ggml_cgraph * gf) const {
    auto res = std::make_unique<llm_graph_result>();

-    auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+    auto * kv = static_cast<llama_kv_cache_unified *>(memory.get());

    const auto & hparams = model.hparams;
@@ -755,6 +729,8 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
void llama_context::kv_self_update() {
    bool need_reserve = false;

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
    if (kv_self->get_has_shift()) {
        if (!kv_self->get_can_shift()) {
            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
@@ -780,7 +756,7 @@ void llama_context::kv_self_update() {
        }

        {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

            kv->has_shift = false;
@@ -794,7 +770,7 @@ void llama_context::kv_self_update() {
    if (kv_self->get_do_defrag()) {
        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

        if (kv->defrag_prepare(graph_max_nodes())) {
            ggml_backend_sched_reset(sched.get());
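
The kv_self_update() hunks above share one pattern: fetch a local llama_kv_cache pointer from memory once at the top of the function, then cast that local further to llama_kv_cache_unified only where unified-specific state such as has_shift is touched. A condensed sketch of that call shape, using stub types in place of the real classes:

// Sketch: the per-function casting pattern after the refactor; class bodies are stubs.
#include <memory>

struct llama_memory_i { virtual ~llama_memory_i() = default; };

struct llama_kv_cache : llama_memory_i {
    virtual bool get_has_shift() const = 0;
};

struct llama_kv_cache_unified : llama_kv_cache {
    bool has_shift = false;
    bool get_has_shift() const override { return has_shift; }
};

struct context_sketch {
    std::unique_ptr<llama_memory_i> memory;

    void kv_self_update() {
        // one generic fetch per member function ...
        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

        if (kv_self->get_has_shift()) {
            // ... and a narrower cast only where unified-specific state is touched
            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
            kv->has_shift = false;
        }
    }
};
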
@@ -1043,6 +1019,8 @@ int llama_context::encode(llama_batch & inp_batch) {
        return -1;
    }

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
    // temporary allocate memory for the input batch if needed
    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1208,6 +1186,8 @@ int llama_context::decode(llama_batch & inp_batch) {
        return -1;
    }

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
    // temporary allocate memory for the input batch if needed
    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
@@ -1222,7 +1202,7 @@ int llama_context::decode(llama_batch & inp_batch) {
    const int64_t n_tokens_all = batch.n_tokens;
    const int64_t n_embd       = hparams.n_embd;

-    llama_kv_cache_guard kv_guard(kv_self.get());
+    llama_kv_cache_guard kv_guard(kv_self);

    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
@@ -1326,7 +1306,7 @@ int llama_context::decode(llama_batch & inp_batch) {
        }

        if (!is_recurrent) {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

            // a heuristic, to avoid attending the full cache if it is not yet utilized
            // after enough generations, the benefit from this heuristic disappears
@@ -1478,7 +1458,7 @@ int llama_context::decode(llama_batch & inp_batch) {

    // decide if we need to defrag the kv cache
    if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);

        // - do not defrag small contexts (i.e. < 2048 tokens)
        // - count the padding towards the number of used tokens
@@ -1651,7 +1631,7 @@ llm_graph_result_ptr llama_context::graph_build(
        /*.backend_cpu =*/ backend_cpu,
        /*.cvec        =*/ &cvec,
        /*.loras       =*/ &loras,
-        /*.memory      =*/ kv_self.get(),
+        /*.memory      =*/ memory.get(),
        /*.cross       =*/ &cross,
        /*.n_outputs   =*/ n_outputs,
        /*.cb          =*/ graph_get_cb(),
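
In graph_build() the parameter block now receives the erased pointer directly, so graph construction no longer depends on the concrete cache type the context created. A small sketch of that hand-off, with the parameter struct trimmed to the one field that changed and its exact type assumed:

// Sketch: graph params carry the abstract memory pointer, as in the hunk above.
#include <memory>

struct llama_memory_i { virtual ~llama_memory_i() = default; };

// trimmed stand-in for the parameter block built in graph_build()
struct graph_params_sketch {
    const llama_memory_i * memory; // previously a concrete KV-cache pointer
};

graph_params_sketch make_graph_params(const std::unique_ptr<llama_memory_i> & memory) {
    return {
        /*.memory =*/ memory.get(), // no cast: the field itself is now abstract
    };
}
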
@@ -2110,6 +2090,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
    }

    LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
    kv_self->state_write(io);

    return io.n_bytes();
@@ -2194,6 +2176,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
    }

    LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
    kv_self->state_read(io);

    return io.n_bytes();
@@ -2202,6 +2186,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
    GGML_UNUSED(seq_id);

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
    kv_self->state_write(io, seq_id);

    return io.n_bytes();
@@ -2210,6 +2196,8 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
    GGML_UNUSED(seq_id);

+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
    kv_self->state_read(io, seq_id);

    return io.n_bytes();