@@ -32,6 +32,38 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
+enum ggml_status llama_context::compute_graph(
+            ggml_cgraph * graph,
+                   bool   batched) {
+    int n_threads        = batched ? cparams.n_threads_batch : cparams.n_threads;
+    ggml_threadpool_t tp = batched ? threadpool_batch        : threadpool;
+
+    if (backend_cpu != nullptr) {
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
+        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
+        set_threadpool_fn(backend_cpu, tp);
+    }
+
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
+    }
+
+    auto status = ggml_backend_sched_graph_compute_async(sched.get(), graph);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
+    }
+
+    //fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
+
+    return status;
+}
+
+
+llama_pos llama_context::pos_max() const {
+    return kv_self.pos_max();
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;
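The new compute_graph helper gathers the per-call thread setup (CPU threadpool plus per-backend thread counts) and the asynchronous scheduler dispatch in one place. A minimal sketch of the intended call pattern, assuming a member function of llama_context with access to sched, buf_compute_meta and model; the helper name run_aux_graph_example and the build_some_graph call are hypothetical stand-ins, and the kv_self_update hunk below follows exactly this sequence with build_k_shift/build_defrag:

// Sketch only: reset -> build -> alloc -> set inputs -> compute around the new
// compute_graph() helper. build_some_graph() stands in for any build_* method.
void llama_context::run_aux_graph_example() {             // hypothetical helper
    ggml_backend_sched_reset(sched.get());                 // drop the previous graph/splits

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_compute_meta.size(),
        /*.mem_buffer =*/ buf_compute_meta.data(),
        /*.no_alloc   =*/ true,                            // tensor data lives in backend buffers
    };
    ggml_context * ctx0 = ggml_init(params);

    reset();                                               // clear cached input tensors

    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
    build_some_graph(ctx0, gf);                            // hypothetical build_* call

    ggml_backend_sched_alloc_graph(sched.get(), gf);       // allocate graph tensors on the backends
    set_inputs({});                                        // upload input data (if any)

    compute_graph(gf, /*batched =*/ false);                // thread setup + async dispatch

    ggml_free(ctx0);                                       // meta context only; no tensor data freed
}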
@@ -540,6 +572,93 @@ ggml_tensor * llama_context::build_lora_mm_id(
     return res;
 }
 
+bool llama_context::kv_self_update() {
+    bool need_reserve = false;
+
+    auto & kv = kv_self;
+
+    if (kv.has_shift) {
+        if (!kv.can_shift) {
+            GGML_ABORT("The current context does not support K-shift");
+        }
+
+        // apply K-shift if needed
+        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            prepare_k_shift();
+
+            ggml_backend_sched_reset(sched.get());
+
+            struct ggml_init_params params = {
+                /*.mem_size   =*/ buf_compute_meta.size(),
+                /*.mem_buffer =*/ buf_compute_meta.data(),
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx0 = ggml_init(params);
+
+            reset();
+
+            ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+            build_k_shift(ctx0, gf);
+
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+            set_inputs({});
+
+            compute_graph(gf, false);
+
+            ggml_free(ctx0);
+
+            need_reserve = true;
+        }
+
+        {
+            kv.has_shift = false;
+
+            for (uint32_t i = 0; i < kv.size; ++i) {
+                kv.cells[i].delta = 0;
+            }
+        }
+    }
+
+    // defragment the KV cache if needed
+    if (kv.do_defrag) {
+        prepare_defrag();
+
+        ggml_backend_sched_reset(sched.get());
+
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ buf_compute_meta.size(),
+            /*.mem_buffer =*/ buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context * ctx0 = ggml_init(params);
+
+        reset();
+
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        build_defrag(ctx0, gf);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        // no input
+        //set_inputs({});
+
+        compute_graph(gf, false);
+
+        ggml_free(ctx0);
+
+        need_reserve = true;
+
+        kv.do_defrag = false;
+    }
+
+    return need_reserve;
+}
+
 void llama_context::build_attn_inp(
         ggml_context * ctx0,
              int32_t   n_tokens,
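kv_self_update returns whether an auxiliary graph (K-shift or defrag) was actually computed; in that case the scheduler's split and buffer state no longer matches the regular decode graph. A hedged sketch of how a caller might consume the flag, assuming a hypothetical build_worst_case_graph() helper and direct access to the context's sched; the real call sites in the decode path are outside this hunk:

// Sketch only: after K-shift/defrag graphs ran, re-reserve scheduler buffers
// against a worst-case graph before the next decode. build_worst_case_graph()
// is a hypothetical stand-in for building the largest graph the context may need.
if (ctx.kv_self_update()) {
    ggml_cgraph * gf_reserve = ctx.build_worst_case_graph();   // hypothetical
    ggml_backend_sched_reserve(ctx.sched.get(), gf_reserve);   // pre-allocate worst-case buffers
}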