Commit 7a359a2

context : hide kv cache details in implementation
ggml-ci
1 parent 61710fc commit 7a359a2

File tree

5 files changed: +415 -351 lines changed


src/llama-context.cpp

Lines changed: 10 additions & 327 deletions
@@ -436,338 +436,21 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-ggml_tensor * llama_context::build_rope_shift(
-        ggml_context * ctx0,
-        ggml_tensor * cur,
-        ggml_tensor * shift,
-        ggml_tensor * factors,
-        float freq_base,
-        float freq_scale) const {
-    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-
-    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
-    const auto & yarn_beta_fast = cparams.yarn_beta_fast;
-    const auto & yarn_beta_slow = cparams.yarn_beta_slow;
-
-    const auto & hparams = model.hparams;
-
-    const auto & n_rot = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type;
-
-    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
-
-    ggml_tensor * tmp;
-
-    if (ggml_is_quantized(cur->type)) {
-        // dequantize to f32 -> RoPE -> quantize back
-        tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
-
-        tmp = ggml_rope_ext(ctx0, tmp,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-
-        tmp = ggml_cpy(ctx0, tmp, cur);
-    } else {
-        // we rotate only the first n_rot dimensions
-        tmp = ggml_rope_ext_inplace(ctx0, cur,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-    }
-
-    return tmp;
-}
-
-class llm_graph_input_k_shift : public llm_graph_input_i {
-public:
-    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
-    virtual ~llm_graph_input_k_shift() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * k_shift; // I32 [kv_size]
-
-    const llama_kv_cache_unified * kv_self;
-};
-
-void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    if (k_shift) {
-        assert(ggml_backend_buffer_is_host(k_shift->buffer));
-
-        int32_t * data = (int32_t *) k_shift->data;
-
-        for (uint32_t i = 0; i < kv_self->size; ++i) {
-            data[i] = kv_self->cells[i].delta;
-        }
-    }
-}
-
-llm_graph_result_ptr llama_context::build_kv_self_shift(
-        ggml_context * ctx0,
-        ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
-
-    const auto & hparams = model.hparams;
-
-    const auto & n_layer = hparams.n_layer;
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    //const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    //GGML_ASSERT(kv_self->size == n_ctx);
-
-    const auto * kv = static_cast<const llama_kv_cache_unified *>(memory.get());
-
-    auto inp = std::make_unique<llm_graph_input_k_shift>(kv);
-
-    inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
-    ggml_set_input(inp->k_shift);
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const int64_t n_head_kv = hparams.n_head_kv(il);
-        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
-        const bool is_swa = hparams.is_swa(il);
-
-        // note: the swa rope params could become part of the cparams in the future
-        //       if we decide to make them configurable, like the non-sliding ones
-        const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
-        const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
-
-        ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il);
-
-        ggml_tensor * k =
-            ggml_view_3d(ctx0, kv->k_l[il],
-                n_embd_head_k, n_head_kv, kv->size,
-                ggml_row_size(kv->k_l[il]->type, n_embd_head_k),
-                ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
-                0);
-
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
-
-        ggml_build_forward_expand(gf, cur);
-    }
-
-    res->add_input(std::move(inp));
-
-    return res;
-}
-
-llm_graph_result_ptr llama_context::build_kv_self_defrag(
-        ggml_context * ctx0,
-        ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
-
-    auto * kv = static_cast<llama_kv_cache_unified *>(memory.get());
-
-    const auto & hparams = model.hparams;
-
-    const auto & ids = kv->defrag_info.ids;
-
-#if 0
-    // CPU defrag
-    //
-    // TODO: optimizations are possible:
-    //       - multiple threads
-    //       - avoid copying to the host memory when already there
-    //
-    // likely not worth the effort, as we have ggml_graph based defrag
-    //
-
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
-    const uint32_t kv_size = size;
-
-    std::vector<uint8_t> buf_k;
-    std::vector<uint8_t> buf_v;
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
-        const size_t v_size_el = ggml_type_size(v_l[il]->type);
-        const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
-        buf_k.resize(k_size);
-        buf_v.resize(v_size);
-
-        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
-        // batch move [i, i+nm) to [id, id+nm)
-        // note: cells can move only to a lower index
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == n_kv) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < n_kv && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            // move keys
-            {
-                const int64_t os = i*k_size_row;
-                const int64_t od = id*k_size_row;
-
-                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
-            }
-
-            // move values (note: they are transposed)
-            {
-                const int64_t os = i;
-                const int64_t od = id;
-
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
-                }
-            }
-
-            i += nm - 1;
-        }
-
-        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
-    }
-#else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
-        for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il],
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i));
-
-            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il],
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id));
-
-            ggml_tensor * view_v_src;
-            ggml_tensor * view_v_dst;
-
-            if (cparams.flash_attn) {
-                // NOTE: the V cache is not transposed when using flash attention
-                view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i));
-
-                view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id));
-            } else {
-                view_v_src = ggml_view_2d(ctx0, kv->v_l[il],
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(kv->v_l[il]->type, kv->size),
-                        ggml_row_size(kv->v_l[il]->type, i));
-
-                view_v_dst = ggml_view_2d(ctx0, kv->v_l[il],
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(kv->v_l[il]->type, kv->size),
-                        ggml_row_size(kv->v_l[il]->type, id));
-            }
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
-        }
-
-        i += nm - 1;
-    }
-
-    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-#endif
-
-    return res;
-}
-
 void llama_context::kv_self_update() {
     bool need_reserve = false;
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (kv_self->get_has_shift()) {
-        if (!kv_self->get_can_shift()) {
-            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
-        }
-
-        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
-
-        // apply K-shift if needed
-        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched.get());
-
-            auto * gf = graph_init();
-
-            auto res = build_kv_self_shift(ctx_compute.get(), gf);
-
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-            res->set_inputs(nullptr);
-
-            graph_compute(gf, false);
-
-            need_reserve = true;
-        }
-
-        {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
-
-            kv->has_shift = false;
-
-            for (uint32_t i = 0; i < kv->size; ++i) {
-                kv->cells[i].delta = 0;
-            }
-        }
-    }
-
-    // defragment the KV cache if needed
-    if (kv_self->get_do_defrag()) {
-        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
-        auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
-
-        if (kv->defrag_prepare(graph_max_nodes())) {
-            ggml_backend_sched_reset(sched.get());
-
-            auto * gf = graph_init();
-
-            auto res = build_kv_self_defrag(ctx_compute.get(), gf);
-
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-            res->set_inputs(nullptr);
-
-            graph_compute(gf, false);
-
-            need_reserve = true;
-        }
-
-        kv->do_defrag = false;
-    }
+    need_reserve = kv_self->update({
+        /*.arch =*/ model.arch,
+        /*.cparams =*/ cparams,
+        /*.sched =*/ sched.get(),
+        /*.backends =*/ backends,
+        /*.n_max_nodes =*/ graph_max_nodes(),
+        /*.get_ctx_compute =*/ [this]() { return ctx_compute.get(); },
+        /*.graph_init =*/ [this]() { return graph_init(); },
+        /*.graph_compute =*/ [this](ggml_cgraph * gf) { graph_compute(gf, false); },
+    });
 
     // reserve a worst case graph if needed
     if (need_reserve) {
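The single `update()` call above takes over the inlined K-shift and defrag logic, and its boolean result replaces the old `need_reserve` bookkeeping. As a rough orientation only, the parameter block it receives plausibly looks like the sketch below; the field names are inferred from the `/*.name =*/` comments at the call site, the types are guesses, and the real declaration lives in the KV-cache headers that are not part of this excerpt.

// Hypothetical sketch -- names inferred from the call site, not the actual declaration.
#include <functional>
#include <vector>

struct llama_kv_cache_update_params_sketch {
    llm_arch              arch;        // model architecture (e.g. LLM_ARCH_DEEPSEEK2)
    const llama_cparams & cparams;     // rope freq base/scale, n_ctx, flash_attn, ...

    ggml_backend_sched_t  sched;       // scheduler used to allocate/run the shift/defrag graphs
    std::vector<ggml_backend_ptr> & backends; // assumed container type

    int32_t               n_max_nodes; // graph_max_nodes()

    // callbacks that let the cache build and run its own graphs
    // without reaching into llama_context internals:
    std::function<ggml_context *()>    get_ctx_compute;
    std::function<ggml_cgraph  *()>    graph_init;
    std::function<void(ggml_cgraph *)> graph_compute;
};

// assumed shape of the entry point: returns true when the context must
// re-reserve its worst-case graph after the cache updated its tensors
// virtual bool update(const llama_kv_cache_update_params_sketch & params) = 0;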

src/llama-context.h

Lines changed: 0 additions & 17 deletions
@@ -159,23 +159,6 @@ struct llama_context {
 
     llm_graph_cb graph_get_cb() const;
 
-    // used by kv_self_update()
-    ggml_tensor * build_rope_shift(
-            ggml_context * ctx0,
-            ggml_tensor * cur,
-            ggml_tensor * shift,
-            ggml_tensor * factors,
-            float freq_base,
-            float freq_scale) const;
-
-    llm_graph_result_ptr build_kv_self_shift(
-            ggml_context * ctx0,
-            ggml_cgraph * gf) const;
-
-    llm_graph_result_ptr build_kv_self_defrag(
-            ggml_context * ctx0,
-            ggml_cgraph * gf) const;
-
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i & io);

src/llama-graph.h

Lines changed: 4 additions & 4 deletions
@@ -353,8 +353,8 @@ struct llm_graph_params {
     const llama_cparams & cparams;
     const llama_ubatch & ubatch;
 
-    ggml_backend_sched * sched;
-    ggml_backend * backend_cpu;
+    ggml_backend_sched_t sched;
+    ggml_backend_t backend_cpu;
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
@@ -405,9 +405,9 @@ struct llm_graph_context {
 
     ggml_context * ctx0 = nullptr;
 
-    ggml_backend_sched * sched;
+    ggml_backend_sched_t sched;
 
-    ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
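For context, `ggml_backend_t` and `ggml_backend_sched_t` are the opaque handle typedefs from ggml's public backend API, so this hunk only switches to the public aliases rather than changing behavior. They are declared in ggml-backend.h roughly as:

// approximate reproduction of the ggml backend handle typedefs
typedef struct ggml_backend       * ggml_backend_t;
typedef struct ggml_backend_sched * ggml_backend_sched_t;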
