@@ -60,6 +60,7 @@ extern "C" {
60
60
struct llama_model ;
61
61
struct llama_context ;
62
62
struct llama_sampler ;
63
+ struct llama_kv_cache ;
63
64
64
65
typedef int32_t llama_pos;
65
66
typedef int32_t llama_token;
@@ -460,8 +461,9 @@ extern "C" {
460
461
461
462
DEPRECATED (LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
462
463
463
- LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
464
- LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
464
+ LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
465
+ LLAMA_API struct llama_kv_cache * llama_get_kv_cache ( struct llama_context * ctx);
466
+ LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
465
467
466
468
LLAMA_API const struct llama_vocab * llama_model_get_vocab (const struct llama_model * model);
467
469
LLAMA_API enum llama_rope_type llama_model_rope_type (const struct llama_model * model);
@@ -576,7 +578,7 @@ extern "C" {
576
578
// KV cache
577
579
//
578
580
579
- // TODO: remove llama_kv_cache_view_* API
581
+ // TODO: start using struct llama_kv_cache
580
582
581
583
// Information associated with an individual cell in the KV cache view.
582
584
struct llama_kv_cache_view_cell {
@@ -631,41 +633,47 @@ extern "C" {
631
633
632
634
// Returns the number of tokens in the KV cache (slow, use only for debug)
633
635
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
634
- LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx);
636
+ LLAMA_API int32_t llama_kv_cache_n_tokens (const struct llama_kv_cache * kv);
637
+
638
+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx),
639
+ "use llama_kv_cache_n_tokens instead");
635
640
636
641
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
637
- LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx);
642
+ LLAMA_API int32_t llama_kv_cache_used_cells (const struct llama_kv_cache * kv);
643
+
644
+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx),
645
+ "use llama_kv_cache_used_cells instead");
638
646
639
647
// Clear the KV cache - both cell info is erased and KV data is zeroed
640
648
LLAMA_API void llama_kv_cache_clear (
641
- struct llama_context * ctx );
649
+ struct llama_kv_cache * kv );
642
650
643
651
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
644
652
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
645
653
// seq_id < 0 : match any sequence
646
654
// p0 < 0 : [0, p1]
647
655
// p1 < 0 : [p0, inf)
648
656
LLAMA_API bool llama_kv_cache_seq_rm (
649
- struct llama_context * ctx ,
650
- llama_seq_id seq_id,
651
- llama_pos p0,
652
- llama_pos p1);
657
+ struct llama_kv_cache * kv ,
658
+ llama_seq_id seq_id,
659
+ llama_pos p0,
660
+ llama_pos p1);
653
661
654
662
// Copy all tokens that belong to the specified sequence to another sequence
655
663
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
656
664
// p0 < 0 : [0, p1]
657
665
// p1 < 0 : [p0, inf)
658
666
LLAMA_API void llama_kv_cache_seq_cp (
659
- struct llama_context * ctx ,
660
- llama_seq_id seq_id_src,
661
- llama_seq_id seq_id_dst,
662
- llama_pos p0,
663
- llama_pos p1);
667
+ struct llama_kv_cache * kv ,
668
+ llama_seq_id seq_id_src,
669
+ llama_seq_id seq_id_dst,
670
+ llama_pos p0,
671
+ llama_pos p1);
664
672
665
673
// Removes all tokens that do not belong to the specified sequence
666
674
LLAMA_API void llama_kv_cache_seq_keep (
667
- struct llama_context * ctx ,
668
- llama_seq_id seq_id);
675
+ struct llama_kv_cache * kv ,
676
+ llama_seq_id seq_id);
669
677
670
678
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
671
679
// If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -674,11 +682,11 @@ extern "C" {
674
682
// p0 < 0 : [0, p1]
675
683
// p1 < 0 : [p0, inf)
676
684
LLAMA_API void llama_kv_cache_seq_add (
677
- struct llama_context * ctx ,
678
- llama_seq_id seq_id,
679
- llama_pos p0,
680
- llama_pos p1,
681
- llama_pos delta);
685
+ struct llama_kv_cache * kv ,
686
+ llama_seq_id seq_id,
687
+ llama_pos p0,
688
+ llama_pos p1,
689
+ llama_pos delta);
682
690
683
691
// Integer division of the positions by factor of `d > 1`
684
692
// If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -687,31 +695,28 @@ extern "C" {
687
695
// p0 < 0 : [0, p1]
688
696
// p1 < 0 : [p0, inf)
689
697
LLAMA_API void llama_kv_cache_seq_div (
690
- struct llama_context * ctx ,
691
- llama_seq_id seq_id,
692
- llama_pos p0,
693
- llama_pos p1,
694
- int d);
698
+ struct llama_kv_cache * kv ,
699
+ llama_seq_id seq_id,
700
+ llama_pos p0,
701
+ llama_pos p1,
702
+ int d);
695
703
696
704
// Returns the largest position present in the KV cache for the specified sequence
697
705
LLAMA_API llama_pos llama_kv_cache_seq_pos_max (
698
- struct llama_context * ctx,
699
- llama_seq_id seq_id);
700
-
701
- // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
702
- // how to avoid this?
706
+ struct llama_kv_cache * kv,
707
+ llama_seq_id seq_id);
703
708
704
709
// Defragment the KV cache
705
710
// This will be applied:
706
711
// - lazily on next llama_decode()
707
712
// - explicitly with llama_kv_cache_update()
708
- LLAMA_API void llama_kv_cache_defrag (struct llama_context * ctx);
709
-
710
- // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
711
- LLAMA_API void llama_kv_cache_update (struct llama_context * ctx);
713
+ LLAMA_API void llama_kv_cache_defrag (struct llama_kv_cache * kv);
712
714
713
715
// Check if the context supports KV cache shifting
714
- LLAMA_API bool llama_kv_cache_can_shift (struct llama_context * ctx);
716
+ LLAMA_API bool llama_kv_cache_can_shift (const struct llama_kv_cache * kv);
717
+
718
+ // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
719
+ LLAMA_API void llama_update_kv_cache (struct llama_context * ctx, struct llama_kv_cache * kv);
715
720
716
721
//
717
722
// State / sessions
0 commit comments