memory : migrate from llama_kv_cache to more generic llama_memory #14006

Merged · 2 commits · Jun 5, 2025
include/llama.h — 100 changes: 88 additions & 12 deletions
@@ -61,7 +61,10 @@ extern "C" {
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
-    struct llama_kv_cache;
+
+    typedef struct llama_memory_i * llama_memory_t;
+
+    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -493,9 +496,11 @@ extern "C" {
     DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
 
     LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
-    LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx);
+    LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
     LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
+    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
+
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@@ -609,7 +614,78 @@ extern "C" {
             int32_t il_end);
 
     //
-    // KV cache
+    // Memory
     //
+
+    // Clear the memory contents
+    LLAMA_API void llama_memory_clear(llama_memory_t mem);
+
+    // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+    //   seq_id < 0 : match any sequence
+    //   p0 < 0     : [0,  p1]
+    //   p1 < 0     : [p0, inf)
+    LLAMA_API bool llama_memory_seq_rm(
+            llama_memory_t mem,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1);
+
+    // Copy all tokens that belong to the specified sequence to another sequence
+    //   p0 < 0 : [0,  p1]
+    //   p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_cp(
+            llama_memory_t mem,
+            llama_seq_id seq_id_src,
+            llama_seq_id seq_id_dst,
+            llama_pos p0,
+            llama_pos p1);
+
+    // Removes all tokens that do not belong to the specified sequence
+    LLAMA_API void llama_memory_seq_keep(
+            llama_memory_t mem,
+            llama_seq_id seq_id);
+
+    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+    //   p0 < 0 : [0,  p1]
+    //   p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_add(
+            llama_memory_t mem,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1,
+            llama_pos delta);
+
+    // Integer division of the positions by factor of `d > 1`
+    //   p0 < 0 : [0,  p1]
+    //   p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_div(
+            llama_memory_t mem,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1,
+            int d);
+
+    // Returns the smallest position present in the memory for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_min(
+            llama_memory_t mem,
+            llama_seq_id seq_id);
+
+    // Returns the largest position present in the memory for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_max(
+            llama_memory_t mem,
+            llama_seq_id seq_id);
+
+    // Check if the memory supports shifting
+    LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
+
+    //
+    // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
+    //
 
     // Returns the number of tokens in the KV cache (slow, use only for debug)
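
For illustration, a minimal sketch of how the new calls above could drive a basic context shift, assuming a valid llama_context * ctx obtained elsewhere; the choice of sequence 0 and the n_keep/n_discard values are hypothetical:

    // free space in sequence 0 by dropping the oldest tokens past n_keep
    // and shifting the remainder down (uses only the API declared above)
    llama_memory_t mem = llama_get_memory(ctx);

    if (llama_memory_can_shift(mem)) {
        const llama_pos p_min = llama_memory_seq_pos_min(mem, 0); // -1 if the sequence is empty
        const llama_pos p_max = llama_memory_seq_pos_max(mem, 0);

        const llama_pos n_keep    = 32;                   // hypothetical: tokens to keep at the start
        const llama_pos n_discard = (p_max - n_keep) / 2; // hypothetical: drop half of the rest

        if (p_min >= 0 && n_discard > 0) {
            // erase [n_keep, n_keep + n_discard) ...
            llama_memory_seq_rm (mem, 0, n_keep, n_keep + n_discard);
            // ... then shift [n_keep + n_discard, inf) down by n_discard
            llama_memory_seq_add(mem, 0, n_keep + n_discard, -1, -n_discard);
        }
    }
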
@@ -623,7 +699,7 @@ extern "C" {
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_self_clear(
-            struct llama_context * ctx);
+            struct llama_context * ctx);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
@@ -694,22 +770,22 @@ extern "C" {
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
        "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
 
     // Check if the context supports KV cache shifting
     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
        "simply remove this call, updates are applied lazily on the next llama_decode()");
 
     //
     // State / sessions
     //
 
     // Returns the *actual* size in bytes of the state
-    // (logits, embedding and kv_cache)
+    // (logits, embedding and memory)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -765,12 +841,12 @@ extern "C" {
             size_t n_token_count),
        "use llama_state_save_file instead");
 
-    // Get the exact size needed to copy the KV cache of a single sequence
+    // Get the exact size needed to copy the state of a single sequence
     LLAMA_API size_t llama_state_seq_get_size(
             struct llama_context * ctx,
             llama_seq_id seq_id);
 
-    // Copy the KV cache of a single sequence into the specified buffer
+    // Copy the state of a single sequence into the specified buffer
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
             uint8_t * dst,
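
These two declarations imply the usual size-then-copy pattern; a sketch, with illustrative buffer handling and assuming the trailing size/seq_id parameters of llama_state_seq_get_data as in the current header:

    // serialize the full state of sequence 0 into a heap buffer
    // (requires <stdlib.h> for malloc/free)
    const size_t n_bytes = llama_state_seq_get_size(ctx, 0);

    uint8_t * buf = (uint8_t *) malloc(n_bytes);
    if (buf != NULL) {
        const size_t n_copied = llama_state_seq_get_data(ctx, buf, n_bytes, 0);
        // the n_copied bytes can be written to disk, or restored into another
        // sequence/context with the matching llama_state_seq_set_data call
        free(buf);
    }
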
@@ -836,16 +912,16 @@ extern "C" {
     // For encode-decoder contexts, processes the batch using the encoder.
     // Can store the encoder output internally for later use by the decoder's cross-attention layers.
     //   0 - success
-    //   < 0 - error. the KV cache state is restored to the state before this call
+    //   < 0 - error. the memory state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
             struct llama_batch batch);
 
     // Process a batch of tokens.
-    // Requires KV cache.
+    // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the memory state is restored to the state before this call
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
     //   2 - aborted
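
Given the return-code contract above, caller-side handling might look like the following sketch; tokens/n_tokens are assumed to come from tokenization elsewhere, and llama_batch_get_one is the single-sequence batch helper from the current header:

    // decode one batch and act on the documented return codes
    const int32_t ret = llama_decode(ctx, llama_batch_get_one(tokens, n_tokens));

    if (ret == 0) {
        // success - logits/embeddings are ready to read
    } else if (ret > 0) {
        // warning (1 = no KV slot, 2 = aborted); the memory state was restored,
        // so it is safe to retry, e.g. with a smaller batch or a larger context
    } else {
        // ret < 0: fatal error for this context
    }
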
src/CMakeLists.txt — 1 change: 0 additions & 1 deletion
@@ -20,7 +20,6 @@ add_library(llama
             llama-hparams.cpp
             llama-impl.cpp
             llama-io.cpp
-            llama-kv-cache.cpp
             llama-kv-cache-unified.cpp
             llama-kv-cache-unified-iswa.cpp
             llama-kv-cache-recurrent.cpp