Skip to content

Commit f149a8e

Browse files
committed
memory : merge llama_kv_cache into llama_memory + new llama_memory API
ggml-ci
1 parent 9e31bec commit f149a8e

11 files changed

+324
-220
lines changed

include/llama.h

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ extern "C" {
6161
struct llama_model;
6262
struct llama_context;
6363
struct llama_sampler;
64-
struct llama_kv_cache;
64+
65+
typedef struct llama_memory_i * llama_memory_t;
66+
67+
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
6568

6669
typedef int32_t llama_pos;
6770
typedef int32_t llama_token;
@@ -493,9 +496,11 @@ extern "C" {
493496
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
494497

495498
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
496-
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
499+
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
497500
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
498501

502+
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
503+
499504
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
500505
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
501506

@@ -609,7 +614,78 @@ extern "C" {
609614
int32_t il_end);
610615

611616
//
612-
// KV cache
617+
// Memory
618+
//
619+
620+
// Clear the memory contents
621+
LLAMA_API void llama_memory_clear(llama_memory_t mem);
622+
623+
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
624+
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
625+
// seq_id < 0 : match any sequence
626+
// p0 < 0 : [0, p1]
627+
// p1 < 0 : [p0, inf)
628+
LLAMA_API bool llama_memory_seq_rm(
629+
llama_memory_t mem,
630+
llama_seq_id seq_id,
631+
llama_pos p0,
632+
llama_pos p1);
633+
634+
// Copy all tokens that belong to the specified sequence to another sequence
635+
// p0 < 0 : [0, p1]
636+
// p1 < 0 : [p0, inf)
637+
LLAMA_API void llama_memory_seq_cp(
638+
llama_memory_t mem,
639+
llama_seq_id seq_id_src,
640+
llama_seq_id seq_id_dst,
641+
llama_pos p0,
642+
llama_pos p1);
643+
644+
// Removes all tokens that do not belong to the specified sequence
645+
LLAMA_API void llama_memory_seq_keep(
646+
llama_memory_t mem,
647+
llama_seq_id seq_id);
648+
649+
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
650+
// p0 < 0 : [0, p1]
651+
// p1 < 0 : [p0, inf)
652+
LLAMA_API void llama_memory_seq_add(
653+
llama_memory_t mem,
654+
llama_seq_id seq_id,
655+
llama_pos p0,
656+
llama_pos p1,
657+
llama_pos delta);
658+
659+
// Integer division of the positions by factor of `d > 1`
660+
// p0 < 0 : [0, p1]
661+
// p1 < 0 : [p0, inf)
662+
LLAMA_API void llama_memory_seq_div(
663+
llama_memory_t mem,
664+
llama_seq_id seq_id,
665+
llama_pos p0,
666+
llama_pos p1,
667+
int d);
668+
669+
// Returns the smallest position present in the memory for the specified sequence
670+
// This is typically non-zero only for SWA caches
671+
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
672+
// Return -1 if the sequence is empty
673+
LLAMA_API llama_pos llama_memory_seq_pos_min(
674+
llama_memory_t mem,
675+
llama_seq_id seq_id);
676+
677+
// Returns the largest position present in the memory for the specified sequence
678+
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
679+
// Return -1 if the sequence is empty
680+
LLAMA_API llama_pos llama_memory_seq_pos_max(
681+
llama_memory_t mem,
682+
llama_seq_id seq_id);
683+
684+
// Check if the memory supports shifting
685+
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
686+
687+
//
688+
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
613689
//
614690

615691
// Returns the number of tokens in the KV cache (slow, use only for debug)
@@ -623,7 +699,7 @@ extern "C" {
623699

624700
// Clear the KV cache - both cell info is erased and KV data is zeroed
625701
LLAMA_API void llama_kv_self_clear(
626-
struct llama_context * ctx);
702+
struct llama_context * ctx);
627703

628704
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
629705
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
@@ -694,22 +770,22 @@ extern "C" {
694770
// Defragment the KV cache
695771
// This will be applied:
696772
// - lazily on next llama_decode()
697-
LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
773+
DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
698774
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
699775

700776
// Check if the context supports KV cache shifting
701777
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
702778

703779
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
704-
LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
780+
DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
705781
"simply remove this call, updates are applied lazily on the next llama_decode()");
706782

707783
//
708784
// State / sessions
709785
//
710786

711787
// Returns the *actual* size in bytes of the state
712-
// (logits, embedding and kv_cache)
788+
// (logits, embedding and memory)
713789
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
714790
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
715791
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -765,12 +841,12 @@ extern "C" {
765841
size_t n_token_count),
766842
"use llama_state_save_file instead");
767843

768-
// Get the exact size needed to copy the KV cache of a single sequence
844+
// Get the exact size needed to copy the state of a single sequence
769845
LLAMA_API size_t llama_state_seq_get_size(
770846
struct llama_context * ctx,
771847
llama_seq_id seq_id);
772848

773-
// Copy the KV cache of a single sequence into the specified buffer
849+
// Copy the state of a single sequence into the specified buffer
774850
LLAMA_API size_t llama_state_seq_get_data(
775851
struct llama_context * ctx,
776852
uint8_t * dst,
@@ -836,16 +912,16 @@ extern "C" {
836912
// For encode-decoder contexts, processes the batch using the encoder.
837913
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
838914
// 0 - success
839-
// < 0 - error. the KV cache state is restored to the state before this call
915+
// < 0 - error. the memory state is restored to the state before this call
840916
LLAMA_API int32_t llama_encode(
841917
struct llama_context * ctx,
842918
struct llama_batch batch);
843919

844920
// Process a batch of tokens.
845-
// Requires KV cache.
921+
// Requires the context to have a memory.
846922
// For encode-decoder contexts, processes the batch using the decoder.
847923
// Positive return values does not mean a fatal error, but rather a warning.
848-
// Upon non-zero return values, the KV cache state is restored to the state before this call
924+
// Upon non-zero return values, the memory state is restored to the state before this call
849925
// 0 - success
850926
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
851927
// 2 - aborted

src/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ add_library(llama
2020
llama-hparams.cpp
2121
llama-impl.cpp
2222
llama-io.cpp
23-
llama-kv-cache.cpp
2423
llama-kv-cache-unified.cpp
2524
llama-kv-cache-unified-iswa.cpp
2625
llama-kv-cache-recurrent.cpp

0 commit comments

Comments
 (0)