
Hybrid recurrent cache #13979


Merged
merged 39 commits on Jun 19, 2025
Changes from all commits (39 commits)
ec8fe17
feat: Add llama_model_is_hybrid API call
gabe-l-hart May 9, 2025
5e2f2c3
feat: Add c++ side constants for attention layer indices hparam
gabe-l-hart May 9, 2025
05f1958
feat: Add support for distinguishing recurrent vs non-recurrent layer…
gabe-l-hart May 9, 2025
fc9e0b5
feat: Auto-fill hparams.recurrent_layer_arr based on whether the mode…
gabe-l-hart May 9, 2025
fb26e95
refactor: rename *_is_hybrid -> *_is_hybrid_recurrent
gabe-l-hart May 28, 2025
40e9187
feat: Add layer filter to recurrent cache
gabe-l-hart May 20, 2025
13332a7
fix: Use per-layer sizing everywhere in kv caches
gabe-l-hart May 14, 2025
c71eaa3
feat: First pass at llama_kv_cache_hybrid_recurrent
gabe-l-hart May 30, 2025
423c894
feat: Construct hybrid recurrent cache for hybrid recurrent models
gabe-l-hart May 28, 2025
6c6ec00
fix: Fix wrong bool condition for split equal in hybrid cache
gabe-l-hart May 28, 2025
cf03d4a
fix: Fix shift logic to defer to unified cache
gabe-l-hart Jun 3, 2025
e3c1631
feat: Support hybrid recurrent in llama-graph
gabe-l-hart Jun 4, 2025
a9b5fe9
fix: Fix logic for initializing inputs and attn layers for hybrid caches
gabe-l-hart Jun 4, 2025
d369936
fix: Update recurrent cache for changes to remove intermediate kv_cac…
gabe-l-hart Jun 5, 2025
911e694
fix: Fix status for init_update sig for recurrent cache state
gabe-l-hart Jun 5, 2025
de9297f
fix: Add missing padding to n_ctx for hybrid cache construction
gabe-l-hart Jun 5, 2025
9c1a604
fix: Update clear signature for data argument after rebase
gabe-l-hart Jun 6, 2025
f6d5f05
fix: Remove errant virtual destructor leftover from previous impl att…
gabe-l-hart Jun 10, 2025
833dfb5
fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers
gabe-l-hart Jun 10, 2025
1dd1213
refactor: Remove n_embd_k/v_s from unified cache
gabe-l-hart Jun 11, 2025
b42c8b4
refactor: Remove layer index from n_embd_k/v_s
gabe-l-hart Jun 11, 2025
d5d7628
refactor: Remove n_embd_k/v_gqa from recurrent cache
gabe-l-hart Jun 11, 2025
d8c929f
feat: Allow custom layer filters for hybrid recurrent
gabe-l-hart Jun 11, 2025
1510016
fix: Remove logits_all after rebase
gabe-l-hart Jun 12, 2025
7ba463b
fix: Remove llama_model_is_hybrid_Recurrent public API
gabe-l-hart Jun 12, 2025
4ec4e6a
refactor: Use llama_memory_state_ptr for child states in hybrid memor…
gabe-l-hart Jun 12, 2025
11cd80d
feat: Overhaul build_recurrent_state / build_inp_s_copy to match atte…
gabe-l-hart Jun 12, 2025
9db44a2
fix: Fix resize vs reserve and skip null tensors in size computation
gabe-l-hart Jun 16, 2025
5046d41
fix: Fix initialization of child states
gabe-l-hart Jun 16, 2025
faf4119
refactor: Use a common build_recurrent_state method that is cache-agn…
gabe-l-hart Jun 16, 2025
59fee24
recurrent : rework graph inputs + add TODOs
ggerganov Jun 18, 2025
c80e68c
Merge pull request #2 from ggml-org/gabe-l-hart/HybridRecurrentCache
gabe-l-hart Jun 18, 2025
8488f5e
refactor: Make status and child states const in hybrid and iswa
gabe-l-hart Jun 18, 2025
88213a9
refactor: Rename llama_kv_cache_[recurrent|hybrid_recurrent] to remov…
gabe-l-hart Jun 18, 2025
8e39e04
refactor!: Rename all k/v related values for recurrent/hybrid to r/s
gabe-l-hart Jun 18, 2025
6403f19
refacor: _recurrent -> _recr for brevity
gabe-l-hart Jun 18, 2025
d0565e8
style: Fix spacing for ref
gabe-l-hart Jun 18, 2025
35c0233
refactor: recurrent_layer() -> is_recurrent()
gabe-l-hart Jun 18, 2025
304f86e
style: Fix spacing for size_s_bytes declaration
gabe-l-hart Jun 18, 2025

src/CMakeLists.txt (3 changes: 2 additions, 1 deletion)

@@ -22,8 +22,9 @@ add_library(llama
             llama-io.cpp
             llama-kv-cache-unified.cpp
             llama-kv-cache-unified-iswa.cpp
-            llama-kv-cache-recurrent.cpp
             llama-memory.cpp
+            llama-memory-hybrid.cpp
+            llama-memory-recurrent.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
             llama-model-saver.cpp

src/llama-arch.cpp (23 changes: 23 additions, 0 deletions)

@@ -147,6 +147,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE,                "%s.attention.scale"            },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,       "%s.attention.key_length_mla"   },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,     "%s.attention.value_length_mla" },
+    { LLM_KV_ATTENTION_LAYER_INDICES,        "%s.attention.layer_indices"    },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,           "%s.rope.dimension_count"       },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,        "%s.rope.dimension_sections"    },
@@ -1816,3 +1817,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    // the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
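
Note: taken together, these two helpers give model-loading code one place to branch on architecture class when choosing a memory implementation: hybrid architectures get the new hybrid recurrent cache, purely recurrent ones get the recurrent cache, and everything else keeps the unified KV cache. The standalone sketch below only illustrates that dispatch idea; the demo_* names and demo_choose_memory are hypothetical and not part of the llama.cpp API.

#include <cstdio>

// Demo-only stand-ins: the real llm_arch enum and the llm_arch_is_recurrent()/
// llm_arch_is_hybrid() helpers live in src/llama-arch.{h,cpp}.
enum demo_arch   { DEMO_ARCH_LLAMA, DEMO_ARCH_MAMBA, DEMO_ARCH_RWKV6, DEMO_ARCH_SOME_HYBRID };
enum demo_memory { DEMO_MEMORY_UNIFIED, DEMO_MEMORY_RECURRENT, DEMO_MEMORY_HYBRID };

static bool demo_is_recurrent(demo_arch arch) {
    switch (arch) {
        case DEMO_ARCH_MAMBA:
        case DEMO_ARCH_RWKV6:
            return true;
        default:
            return false;
    }
}

static bool demo_is_hybrid(demo_arch arch) {
    // mirrors llm_arch_is_hybrid(): empty upstream until hybrid models land
    return arch == DEMO_ARCH_SOME_HYBRID;
}

// Hypothetical selection order: hybrid first, then pure recurrent, else unified KV.
static demo_memory demo_choose_memory(demo_arch arch) {
    if (demo_is_hybrid(arch))    return DEMO_MEMORY_HYBRID;
    if (demo_is_recurrent(arch)) return DEMO_MEMORY_RECURRENT;
    return DEMO_MEMORY_UNIFIED;
}

int main() {
    printf("mamba -> %d\n", demo_choose_memory(DEMO_ARCH_MAMBA)); // 1 (recurrent)
    printf("llama -> %d\n", demo_choose_memory(DEMO_ARCH_LLAMA)); // 0 (unified)
    return 0;
}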

src/llama-arch.h (4 changes: 4 additions, 0 deletions)

@@ -151,6 +151,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_LAYER_INDICES,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -439,3 +440,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);
 
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid   (const llm_arch & arch);
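
Note: the new "%s.attention.layer_indices" key pairs with the per-layer plumbing in the commit list (hparams.recurrent_layer_arr, the layer filters on the recurrent and hybrid caches): for a hybrid model it can mark which layers run attention, with the remaining layers treated as recurrent. The sketch below is only an illustration of deriving such a mask and a layer filter from a list of attention-layer indices; build_recurrent_layer_mask and the lambda are made-up names, not the PR's implementation.

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// Illustration only: given the total layer count and the indices that use
// attention (e.g. the values behind "%s.attention.layer_indices"), mark every
// other layer as recurrent.
static std::vector<bool> build_recurrent_layer_mask(
        uint32_t n_layer, const std::vector<uint32_t> & attn_layer_indices) {
    std::vector<bool> is_recurrent(n_layer, true);
    for (uint32_t il : attn_layer_indices) {
        if (il < n_layer) {
            is_recurrent[il] = false;
        }
    }
    return is_recurrent;
}

int main() {
    // Hypothetical hybrid model: 8 layers, attention at layers 3 and 7.
    const auto mask = build_recurrent_layer_mask(8, {3, 7});

    // A per-layer predicate in the spirit of the cache-level layer filters.
    std::function<bool(int32_t)> filter_recr = [&](int32_t il) {
        return static_cast<bool>(mask[il]);
    };

    for (int32_t il = 0; il < 8; ++il) {
        printf("layer %d: %s\n", il, filter_recr(il) ? "recurrent" : "attention");
    }
    return 0;
}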