Commit 70efeb7

cont : move kv_self update to llama_context
ggml-ci
1 parent 68fd1b4 commit 70efeb7

File tree

3 files changed: +157 -154 lines changed

src/llama-context.cpp

Lines changed: 119 additions & 0 deletions
@@ -32,6 +32,38 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
+enum ggml_status llama_context::compute_graph(
+            ggml_cgraph * graph,
+                   bool   batched) {
+    int n_threads        = batched ? cparams.n_threads_batch : cparams.n_threads;
+    ggml_threadpool_t tp = batched ? threadpool_batch        : threadpool;
+
+    if (backend_cpu != nullptr) {
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
+        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
+        set_threadpool_fn(backend_cpu, tp);
+    }
+
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
+    }
+
+    auto status = ggml_backend_sched_graph_compute_async(sched.get(), graph);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
+    }
+
+    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
+
+    return status;
+}
+
+
+llama_pos llama_context::pos_max() const {
+    return kv_self.pos_max();
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;
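For orientation only, here is a minimal caller-side sketch of how the new compute_graph() helper is meant to be used. It is not part of this commit: it assumes a llama_context `ctx` whose scheduler already holds an allocated graph `gf`, and an `n_tokens` count taken from the current ubatch.

    // hypothetical usage sketch, not from this diff
    const bool batched = n_tokens > 1; // pick the batch threadpool / thread count for multi-token ubatches
    const enum ggml_status status = ctx.compute_graph(gf, batched);
    if (status != GGML_STATUS_SUCCESS) {
        // compute_graph() already logged the failure; just propagate it
        return status;
    }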
@@ -540,6 +572,93 @@ ggml_tensor * llama_context::build_lora_mm_id(
     return res;
 }
 
+bool llama_context::kv_self_update() {
+    bool need_reserve = false;
+
+    auto & kv = kv_self;
+
+    if (kv.has_shift) {
+        if (!kv.can_shift) {
+            GGML_ABORT("The current context does not support K-shift");
+        }
+
+        // apply K-shift if needed
+        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            prepare_k_shift();
+
+            ggml_backend_sched_reset(sched.get());
+
+            struct ggml_init_params params = {
+                /*.mem_size   =*/ buf_compute_meta.size(),
+                /*.mem_buffer =*/ buf_compute_meta.data(),
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx0 = ggml_init(params);
+
+            reset();
+
+            ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+            build_k_shift(ctx0, gf);
+
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+            set_inputs({});
+
+            compute_graph(gf, false);
+
+            ggml_free(ctx0);
+
+            need_reserve = true;
+        }
+
+        {
+            kv.has_shift = false;
+
+            for (uint32_t i = 0; i < kv.size; ++i) {
+                kv.cells[i].delta = 0;
+            }
+        }
+    }
+
+    // defragment the KV cache if needed
+    if (kv.do_defrag) {
+        prepare_defrag();
+
+        ggml_backend_sched_reset(sched.get());
+
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ buf_compute_meta.size(),
+            /*.mem_buffer =*/ buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context * ctx0 = ggml_init(params);
+
+        reset();
+
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        build_defrag(ctx0, gf);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        // no input
+        //set_inputs({});
+
+        compute_graph(gf, false);
+
+        ggml_free(ctx0);
+
+        need_reserve = true;
+
+        kv.do_defrag = false;
+    }
+
+    return need_reserve;
+}
+
 void llama_context::build_attn_inp(
         ggml_context * ctx0,
                int32_t n_tokens,
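The header comment below ("return true if need to reserve new worst-case graph") describes the contract of kv_self_update(): after a K-shift or a defrag pass the graph topology can change, so the caller should reserve a new worst-case graph before the next decode. A hedged sketch of that caller-side handling, assuming a hypothetical build_worst_case_graph() helper (this commit does not show the calling code):

    // hypothetical caller-side sketch, not from this diff
    if (ctx.kv_self_update()) {
        ggml_backend_sched_reset(ctx.sched.get());
        // build_worst_case_graph() is a placeholder for however the caller
        // constructs the worst-case graph for the current batch parameters
        ggml_cgraph * gf = ctx.build_worst_case_graph();
        if (!ggml_backend_sched_reserve(ctx.sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to reserve the worst-case graph\n", __func__);
        }
    }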

src/llama-context.h

Lines changed: 10 additions & 0 deletions
@@ -79,6 +79,13 @@ struct llama_context {
     ggml_abort_callback abort_callback      = nullptr;
     void *              abort_callback_data = nullptr;
 
+    // returns the result of ggml_backend_sched_graph_compute_async execution
+    enum ggml_status compute_graph(
+            ggml_cgraph * graph,
+                   bool   batched);
+
+    llama_pos pos_max() const;
+
     void reset();
 
     void prepare_k_shift();
@@ -129,6 +136,9 @@
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
     struct ggml_tensor * inp_K_shift;       // I32 [kv_size]
 
+    // return true if need to reserve new worst-case graph
+    bool kv_self_update();
+
     void build_attn_inp(
             ggml_context * ctx0,
                  int32_t   n_tokens,
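For completeness, a wiring sketch (also not part of this diff) of how a public entry point such as llama_kv_cache_update() could now forward to the new member function; the actual call site lives in a file not shown in this commit view.

    // hypothetical forwarding sketch, not from this diff
    void llama_kv_cache_update(struct llama_context * ctx) {
        const bool need_reserve = ctx->kv_self_update();
        GGML_UNUSED(need_reserve); // in the real decode path this drives the graph re-reservation
    }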
