
Commit b018bb9

kv-cache : serparate recurrent vs non-recurrent impl (wip)

ggml-ci

1 parent: 971f245

File tree

7 files changed: +1241, -382 lines
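What this commit does: `llama_context` no longer owns a concrete `llama_kv_cache_unified`; it owns a generic `llama_kv_cache` and picks either the unified (attention) implementation or the recurrent (Mamba/RWKV state) implementation when the context is created. The KV-cache headers that actually define the split are among the 7 changed files but are not shown in this excerpt, so the sketch below is only an assumption of the base interface, with the method names taken from the call sites visible in the diffs that follow.

#include <cstdint>

using llama_pos = int32_t; // assumption: llama_pos is a 32-bit position type, as in llama.h

// Hypothetical shape of the common interface (not the real header).
struct llama_kv_cache {
    virtual ~llama_kv_cache() = default;

    virtual bool get_has_shift() const = 0; // is a K-shift pending?
    virtual bool get_do_defrag() const = 0; // is a defrag pending?
    virtual bool get_can_shift() const = 0; // does this cache support K-shift at all?

    virtual llama_pos get_pos_max() const = 0; // highest position stored in the cache
    virtual void      set_full()          = 0; // pretend the cache is full (used when reserving worst-case graphs)
};

// Attention models: contiguous per-layer K/V tensors (k_l/v_l), n/size/used/head bookkeeping,
// K-shift and defrag support (defrag_info, cbs.get_rope_factors, ...).
struct llama_kv_cache_unified : llama_kv_cache { /* ... */ };

// Recurrent models (Mamba/RWKV): per-sequence states kept in cells with src indices,
// F32 conv/ssm states, no K-shift or defrag.
struct llama_kv_cache_recurrent : llama_kv_cache { /* ... */ };

Call sites that only need the shared behavior go through the base; unified-only paths (K-shift, defrag, the cell heuristics) downcast after checking llama_model_is_recurrent(&model), as the diffs below show.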

src/llama-context.cpp

Lines changed: 80 additions & 55 deletions
@@ -179,24 +179,37 @@ llama_context::llama_context(
    // init the memory module
    // TODO: for now, always create a unified KV cache
    if (!hparams.vocab_only) {
-       kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
+       uint32_t kv_size = 0;
+       ggml_type type_k = params.type_k;
+       ggml_type type_v = params.type_v;

-       LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
+       if (!llama_model_is_recurrent(&model)) {
+           //kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
+           auto * kv = static_cast<llama_kv_cache_unified *>(model.create_memory());

-       cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
+           LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

-       LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+           cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams));

-       uint32_t kv_size = cparams.n_ctx;
-       ggml_type type_k = params.type_k;
-       ggml_type type_v = params.type_v;
+           LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+           kv_size = cparams.n_ctx;
+           type_k = params.type_k;
+           type_v = params.type_v;
+
+           kv_self.reset(kv);
+       } else {
+           auto * kv = static_cast<llama_kv_cache_recurrent *>(model.create_memory());
+
+           LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);

-       if (llama_model_is_recurrent(&model)) {
            // Mamba needs at least as many KV cells as there are sequences kept at any time
            kv_size = std::max((uint32_t) 1, params.n_seq_max);
            // it's probably best to keep as much precision as possible for the states
            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+
+           kv_self.reset(kv);
        }

        GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
@@ -305,7 +318,7 @@ llama_context::llama_context(
        int n_nodes_tg = -1;

        // simulate full KV cache
-       kv_self->n = kv_self->size;
+       kv_self->set_full();

        cross.v_embd.clear();

@@ -557,7 +570,9 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(

    //GGML_ASSERT(kv_self->size == n_ctx);

-   auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
+   const auto * kv = static_cast<const llama_kv_cache_unified *>(kv_self.get());
+
+   auto inp = std::make_unique<llm_graph_input_k_shift>(kv);

    inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
    ggml_set_input(inp->k_shift);
@@ -573,16 +588,16 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
        const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
        const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;

-       ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
+       ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il);

        ggml_tensor * k =
-           ggml_view_3d(ctx0, kv_self->k_l[il],
-               n_embd_head_k, n_head_kv, kv_self->size,
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
+           ggml_view_3d(ctx0, kv->k_l[il],
+               n_embd_head_k, n_head_kv, kv->size,
+               ggml_row_size(kv->k_l[il]->type, n_embd_head_k),
+               ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
                0);

-       ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
+       ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv->k_l[il]->buffer);

        ggml_build_forward_expand(gf, cur);
    }
@@ -597,9 +612,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
        ggml_cgraph * gf) const {
    auto res = std::make_unique<llm_graph_result>();

+   auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
    const auto & hparams = model.hparams;

-   const auto & ids = kv_self->defrag_info.ids;
+   const auto & ids = kv->defrag_info.ids;

#if 0
    // CPU defrag
@@ -689,40 +706,40 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

-       ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
+       ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il],
                n_embd_k_gqa, nm,
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
+               ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
+               ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i));

-       ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
+       ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il],
                n_embd_k_gqa, nm,
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-               ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
+               ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
+               ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id));

        ggml_tensor * view_v_src;
        ggml_tensor * view_v_dst;

        if (cparams.flash_attn) {
            // NOTE: the V cache is not transposed when using flash attention
-           view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+           view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
                    n_embd_v_gqa, nm,
-                   ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                   ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
+                   ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
+                   ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i));

-           view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+           view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
                    n_embd_v_gqa, nm,
-                   ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                   ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
+                   ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
+                   ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id));
        } else {
-           view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
+           view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
                    nm, n_embd_v_gqa,
-                   ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                   ggml_row_size(kv_self->v_l[il]->type, i));
+                   ggml_row_size(kv->v_l[il]->type, kv->size),
+                   ggml_row_size(kv->v_l[il]->type, i));

-           view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
+           view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
                    nm, n_embd_v_gqa,
-                   ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                   ggml_row_size(kv_self->v_l[il]->type, id));
+                   ggml_row_size(kv->v_l[il]->type, kv->size),
+                   ggml_row_size(kv->v_l[il]->type, id));
        }

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
@@ -739,13 +756,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag(
}

void llama_context::kv_self_update() {
-   auto & kv = kv_self;
-
    bool need_reserve = false;

-   if (kv->has_shift) {
-       if (!kv->get_can_shift()) {
-           GGML_ABORT("The current context does not support K-shift");
+   if (kv_self->get_has_shift()) {
+       if (!kv_self->get_can_shift()) {
+           GGML_ABORT("The current KV cache / model configuration does not support K-shift");
        }

        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
@@ -768,6 +783,8 @@ void llama_context::kv_self_update() {
        }

        {
+           auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
            kv->has_shift = false;

            for (uint32_t i = 0; i < kv->size; ++i) {
@@ -777,9 +794,11 @@ void llama_context::kv_self_update() {
    }

    // defragment the KV cache if needed
-   if (kv->do_defrag) {
+   if (kv_self->get_do_defrag()) {
        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);

+       auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
        if (kv->defrag_prepare(graph_max_nodes())) {
            ggml_backend_sched_reset(sched.get());

@@ -808,7 +827,7 @@ void llama_context::kv_self_update() {
        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        // simulate full KV cache
-       kv_self->n = kv_self->size;
+       kv_self->set_full();

        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -1028,8 +1047,8 @@ int llama_context::encode(llama_batch & inp_batch) {
    }

    // temporary allocate memory for the input batch if needed
-   // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-   llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+   // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+   llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;
    const int32_t n_tokens = batch.n_tokens;
@@ -1193,8 +1212,8 @@ int llama_context::decode(llama_batch & inp_batch) {
    }

    // temporary allocate memory for the input batch if needed
-   // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-   llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+   // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+   llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);

    const llama_batch & batch = batch_allocr.batch;

@@ -1249,8 +1268,10 @@ int llama_context::decode(llama_batch & inp_batch) {

    const bool logits_all = n_outputs_all == n_tokens_all;

+   const bool is_recurrent = llama_model_is_recurrent(&model);
+
    sbatch.from_batch(batch, n_embd,
-           /* simple_split */ !kv_self->recurrent,
+           /* simple_split */ !is_recurrent,
            /* logits_all */ logits_all);

    // reserve output buffer
@@ -1269,7 +1290,7 @@ int llama_context::decode(llama_batch & inp_batch) {

    const auto & n_ubatch = cparams.n_ubatch;

-   if (kv_self->recurrent) {
+   if (is_recurrent) {
        if (embd_pooled) {
            // Pooled embeddings cannot be split across ubatches (yet)
            ubatch = sbatch.split_seq(cparams.n_ubatch);
@@ -1307,17 +1328,19 @@ int llama_context::decode(llama_batch & inp_batch) {
                return 1;
            }

-           if (!kv_self->recurrent) {
+           if (!is_recurrent) {
+               auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
                // a heuristic, to avoid attending the full cache if it is not yet utilized
                // after enough generations, the benefit from this heuristic disappears
                // if we start defragmenting the cache, the benefit from this will be more important
-               const uint32_t pad = kv_self->get_padding(cparams);
-               kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
+               const uint32_t pad = kv->get_padding(cparams);
+               kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad)));
+
+               //printf("kv.n = %5d, kv.used = %5d, kv.head = %5d\n", kv->n, kv->used, kv->head);
            }
        }

-       //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
-
        ggml_backend_sched_reset(sched.get());
        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

@@ -1457,10 +1480,12 @@ int llama_context::decode(llama_batch & inp_batch) {
    //synchronize();

    // decide if we need to defrag the kv cache
-   if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+   if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+       auto * kv = static_cast<llama_kv_cache_unified *>(kv_self.get());
+
        // - do not defrag small contexts (i.e. < 2048 tokens)
        // - count the padding towards the number of used tokens
-       const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
+       const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f;

        // queue defragmentation for next llama_kv_cache_update
        if (fragmentation > cparams.defrag_thold) {
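A note on the defrag trigger in the last hunk: the snippet below is a self-contained re-statement of the same arithmetic with made-up numbers (in llama.cpp the values come from the unified cache's `n`, `used` and `get_padding(cparams)`, and the threshold from `cparams.defrag_thold`); it is an illustration, not project code.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// fragmentation = fraction of the attended window that is neither used nor padding
static float fragmentation(uint32_t n, uint32_t used, uint32_t pad) {
    // windows smaller than 2048 cells are never defragmented;
    // padding counts towards the "used" cells
    return n >= 2048 ? std::max(0.0f, 1.0f - float(used + pad)/float(n)) : 0.0f;
}

int main() {
    const float defrag_thold = 0.1f; // example threshold
    const float f = fragmentation(/*n =*/ 4096, /*used =*/ 3000, /*pad =*/ 32);
    printf("fragmentation = %.3f -> defrag %s\n", f, f > defrag_thold ? "queued" : "skipped");
    return 0;
}

With these numbers roughly 26% of the attended cells are holes, so a defrag would be queued for the next kv_self_update().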

src/llama-context.h

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ struct llama_context {

    llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably

-   std::unique_ptr<llama_kv_cache_unified> kv_self;
+   std::unique_ptr<llama_kv_cache> kv_self;

    // TODO: remove
    bool logits_all = false;
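With `kv_self` now held through the base class, the unified-only paths in llama-context.cpp above (K-shift, defrag, the `n`/`cell_max` heuristic) downcast after checking the model type. A self-contained toy of that ownership/downcast pattern follows; the types here are stand-ins, not the real llama.cpp classes.

#include <cstdint>
#include <cstdio>
#include <memory>

struct kv_cache_base {                      // stand-in for llama_kv_cache
    virtual ~kv_cache_base() = default;
    virtual void set_full() = 0;            // shared behavior goes through the base
};

struct kv_cache_unified : kv_cache_base {   // stand-in for llama_kv_cache_unified
    uint32_t n = 0, size = 4096;
    void set_full() override { n = size; }
};

struct kv_cache_recurrent : kv_cache_base { // stand-in for llama_kv_cache_recurrent
    void set_full() override { /* recurrent caches track per-sequence states instead */ }
};

int main() {
    const bool is_recurrent = false; // would come from llama_model_is_recurrent(&model)

    std::unique_ptr<kv_cache_base> kv_self;
    if (!is_recurrent) {
        kv_self = std::make_unique<kv_cache_unified>();
    } else {
        kv_self = std::make_unique<kv_cache_recurrent>();
    }

    kv_self->set_full(); // virtual dispatch for behavior both caches provide

    if (!is_recurrent) { // implementation-specific members require a guarded downcast
        auto * kv = static_cast<kv_cache_unified *>(kv_self.get());
        printf("kv->n = %u\n", kv->n);
    }
    return 0;
}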

src/llama-graph.cpp

Lines changed: 7 additions & 9 deletions
@@ -274,7 +274,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {

            //////////////////////////////////////////////
            // TODO: this should not mutate the KV cache !
-           llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
+           llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];

            // prevent out-of-bound sources
            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
@@ -307,7 +307,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {

            //////////////////////////////////////////////
            // TODO: this should not mutate the KV cache !
-           llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
+           llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];

            data[i] = (float) (kv_cell.src >= 0);

@@ -1079,7 +1079,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
}

ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-   const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+   const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);

@@ -1096,7 +1096,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
}

ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-   const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+   const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);

@@ -1408,8 +1408,6 @@ ggml_tensor * llm_graph_context::build_attn(

    // store to KV cache
    {
-       GGML_ASSERT(!kv_self->recurrent);
-
        const auto kv_head = kv_self->head;

        GGML_ASSERT(kv_self->size == n_ctx);
@@ -1559,7 +1557,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
        ggml_tensor * state_mask,
        int32_t n_state,
        int32_t n_seqs) const {
-   const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+   const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

    const auto n_kv = kv_self->n;
    const auto kv_head = kv_self->head;
@@ -1591,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
        ggml_tensor * state_mask,
        const llama_ubatch & ubatch,
        int il) const {
-   const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+   const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

    const auto token_shift_count = hparams.token_shift_count;

@@ -1612,7 +1610,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
        ggml_tensor * token_shift,
        const llama_ubatch & ubatch,
        int il) const {
-   const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+   const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

    const auto token_shift_count = hparams.token_shift_count;
    const auto n_embd = hparams.n_embd;
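The s_copy/s_mask hunks above only retarget the cast from the unified to the recurrent cache; judging from the visible lines, the inputs keep encoding, per cell, which cell's state to copy in (`src`) and whether a valid state exists at all. A rough, self-contained illustration of that encoding (the cell layout is invented; the real `llama_kv_cell` carries more fields):

#include <cstdint>
#include <cstdio>
#include <vector>

struct toy_cell { int32_t src = -1; }; // index of the cell whose state feeds this one, -1 = no state

int main() {
    std::vector<toy_cell> cells = { {2}, {-1}, {2}, {0} }; // hypothetical recurrent cache of 4 cells

    for (size_t i = 0; i < cells.size(); ++i) {
        int32_t src = cells[i].src;
        if (src < 0 || (size_t) src >= cells.size()) {
            src = (int32_t) i; // out-of-bound or missing sources fall back to the cell itself
        }
        const float mask = cells[i].src >= 0 ? 1.0f : 0.0f; // s_mask: is there a valid state?
        printf("cell %zu: s_copy = %d, s_mask = %.0f\n", i, src, mask);
    }
    return 0;
}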
