Skip to content

Commit a00f3f6

Browse files
committed
feat: Add llama_model_is_hybrid API call
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybird. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
1 parent cd0dc98 commit a00f3f6

File tree

4 files changed

+33
-8
lines changed

4 files changed

+33
-8
lines changed

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,9 @@ extern "C" {
552552
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
553553
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
554554

555+
// Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.)
556+
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
557+
555558
// Returns 0 on success
556559
LLAMA_API uint32_t llama_model_quantize(
557560
const char * fname_inp,

src/llama-arch.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,3 +1745,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
17451745
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
17461746
return LLM_TENSOR_INFOS.at(tensor);
17471747
}
1748+
1749+
bool llm_arch_is_recurrent(const llm_arch & arch) {
1750+
switch (arch) {
1751+
case LLM_ARCH_MAMBA:
1752+
case LLM_ARCH_RWKV6:
1753+
case LLM_ARCH_RWKV6QWEN2:
1754+
case LLM_ARCH_RWKV7:
1755+
case LLM_ARCH_ARWKV7:
1756+
return true;
1757+
default:
1758+
return false;
1759+
}
1760+
}
1761+
1762+
bool llm_arch_is_hybrid(const llm_arch & arch) {
1763+
// TODO: There are currently no hybrid models! Once there are, this will be
1764+
// the place to identify them
1765+
switch (arch) {
1766+
default:
1767+
return false;
1768+
}
1769+
}

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,6 @@ const char * llm_arch_name(llm_arch arch);
436436
llm_arch llm_arch_from_string(const std::string & name);
437437

438438
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
439+
440+
bool llm_arch_is_recurrent(const llm_arch& arch);
441+
bool llm_arch_is_hybrid(const llm_arch& arch);

src/llama-model.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13616,14 +13616,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
1361613616
}
1361713617

1361813618
bool llama_model_is_recurrent(const llama_model * model) {
13619-
switch (model->arch) {
13620-
case LLM_ARCH_MAMBA: return true;
13621-
case LLM_ARCH_RWKV6: return true;
13622-
case LLM_ARCH_RWKV6QWEN2: return true;
13623-
case LLM_ARCH_RWKV7: return true;
13624-
case LLM_ARCH_ARWKV7: return true;
13625-
default: return false;
13626-
}
13619+
return llm_arch_is_recurrent(model->arch);
13620+
}
13621+
13622+
bool llama_model_is_hybrid(const llama_model * model) {
13623+
return llm_arch_is_hybrid(model->arch);
1362713624
}
1362813625

1362913626
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {

0 commit comments

Comments
 (0)