Skip to content

Commit f483313

Browse files
committed
feat: Add GRANITE_MOE_HYBRID through llama-arch
Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
1 parent d6ebc23 commit f483313

File tree

2 files changed

+108
-73
lines changed

2 files changed

+108
-73
lines changed

src/llama-arch.cpp

Lines changed: 107 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5,79 +5,80 @@
55
#include <map>
66

77
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
8-
{ LLM_ARCH_LLAMA, "llama" },
9-
{ LLM_ARCH_LLAMA4, "llama4" },
10-
{ LLM_ARCH_DECI, "deci" },
11-
{ LLM_ARCH_FALCON, "falcon" },
12-
{ LLM_ARCH_GROK, "grok" },
13-
{ LLM_ARCH_GPT2, "gpt2" },
14-
{ LLM_ARCH_GPTJ, "gptj" },
15-
{ LLM_ARCH_GPTNEOX, "gptneox" },
16-
{ LLM_ARCH_MPT, "mpt" },
17-
{ LLM_ARCH_BAICHUAN, "baichuan" },
18-
{ LLM_ARCH_STARCODER, "starcoder" },
19-
{ LLM_ARCH_REFACT, "refact" },
20-
{ LLM_ARCH_BERT, "bert" },
21-
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
22-
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
23-
{ LLM_ARCH_NEO_BERT, "neo-bert" },
24-
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
25-
{ LLM_ARCH_BLOOM, "bloom" },
26-
{ LLM_ARCH_STABLELM, "stablelm" },
27-
{ LLM_ARCH_QWEN, "qwen" },
28-
{ LLM_ARCH_QWEN2, "qwen2" },
29-
{ LLM_ARCH_QWEN2MOE, "qwen2moe" },
30-
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
31-
{ LLM_ARCH_QWEN3, "qwen3" },
32-
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
33-
{ LLM_ARCH_PHI2, "phi2" },
34-
{ LLM_ARCH_PHI3, "phi3" },
35-
{ LLM_ARCH_PHIMOE, "phimoe" },
36-
{ LLM_ARCH_PLAMO, "plamo" },
37-
{ LLM_ARCH_CODESHELL, "codeshell" },
38-
{ LLM_ARCH_ORION, "orion" },
39-
{ LLM_ARCH_INTERNLM2, "internlm2" },
40-
{ LLM_ARCH_MINICPM, "minicpm" },
41-
{ LLM_ARCH_MINICPM3, "minicpm3" },
42-
{ LLM_ARCH_GEMMA, "gemma" },
43-
{ LLM_ARCH_GEMMA2, "gemma2" },
44-
{ LLM_ARCH_GEMMA3, "gemma3" },
45-
{ LLM_ARCH_STARCODER2, "starcoder2" },
46-
{ LLM_ARCH_MAMBA, "mamba" },
47-
{ LLM_ARCH_MAMBA2, "mamba2" },
48-
{ LLM_ARCH_BAMBA, "bamba" },
49-
{ LLM_ARCH_XVERSE, "xverse" },
50-
{ LLM_ARCH_COMMAND_R, "command-r" },
51-
{ LLM_ARCH_COHERE2, "cohere2" },
52-
{ LLM_ARCH_DBRX, "dbrx" },
53-
{ LLM_ARCH_OLMO, "olmo" },
54-
{ LLM_ARCH_OLMO2, "olmo2" },
55-
{ LLM_ARCH_OLMOE, "olmoe" },
56-
{ LLM_ARCH_OPENELM, "openelm" },
57-
{ LLM_ARCH_ARCTIC, "arctic" },
58-
{ LLM_ARCH_DEEPSEEK, "deepseek" },
59-
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
60-
{ LLM_ARCH_CHATGLM, "chatglm" },
61-
{ LLM_ARCH_GLM4, "glm4" },
62-
{ LLM_ARCH_BITNET, "bitnet" },
63-
{ LLM_ARCH_T5, "t5" },
64-
{ LLM_ARCH_T5ENCODER, "t5encoder" },
65-
{ LLM_ARCH_JAIS, "jais" },
66-
{ LLM_ARCH_NEMOTRON, "nemotron" },
67-
{ LLM_ARCH_EXAONE, "exaone" },
68-
{ LLM_ARCH_RWKV6, "rwkv6" },
69-
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
70-
{ LLM_ARCH_RWKV7, "rwkv7" },
71-
{ LLM_ARCH_ARWKV7, "arwkv7" },
72-
{ LLM_ARCH_GRANITE, "granite" },
73-
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
74-
{ LLM_ARCH_CHAMELEON, "chameleon" },
75-
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
76-
{ LLM_ARCH_PLM, "plm" },
77-
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
78-
{ LLM_ARCH_DOTS1, "dots1" },
79-
{ LLM_ARCH_ARCEE, "arcee" },
80-
{ LLM_ARCH_UNKNOWN, "(unknown)" },
8+
{ LLM_ARCH_LLAMA, "llama" },
9+
{ LLM_ARCH_LLAMA4, "llama4" },
10+
{ LLM_ARCH_DECI, "deci" },
11+
{ LLM_ARCH_FALCON, "falcon" },
12+
{ LLM_ARCH_GROK, "grok" },
13+
{ LLM_ARCH_GPT2, "gpt2" },
14+
{ LLM_ARCH_GPTJ, "gptj" },
15+
{ LLM_ARCH_GPTNEOX, "gptneox" },
16+
{ LLM_ARCH_MPT, "mpt" },
17+
{ LLM_ARCH_BAICHUAN, "baichuan" },
18+
{ LLM_ARCH_STARCODER, "starcoder" },
19+
{ LLM_ARCH_REFACT, "refact" },
20+
{ LLM_ARCH_BERT, "bert" },
21+
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
22+
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
23+
{ LLM_ARCH_NEO_BERT, "neo-bert" },
24+
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
25+
{ LLM_ARCH_BLOOM, "bloom" },
26+
{ LLM_ARCH_STABLELM, "stablelm" },
27+
{ LLM_ARCH_QWEN, "qwen" },
28+
{ LLM_ARCH_QWEN2, "qwen2" },
29+
{ LLM_ARCH_QWEN2MOE, "qwen2moe" },
30+
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
31+
{ LLM_ARCH_QWEN3, "qwen3" },
32+
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
33+
{ LLM_ARCH_PHI2, "phi2" },
34+
{ LLM_ARCH_PHI3, "phi3" },
35+
{ LLM_ARCH_PHIMOE, "phimoe" },
36+
{ LLM_ARCH_PLAMO, "plamo" },
37+
{ LLM_ARCH_CODESHELL, "codeshell" },
38+
{ LLM_ARCH_ORION, "orion" },
39+
{ LLM_ARCH_INTERNLM2, "internlm2" },
40+
{ LLM_ARCH_MINICPM, "minicpm" },
41+
{ LLM_ARCH_MINICPM3, "minicpm3" },
42+
{ LLM_ARCH_GEMMA, "gemma" },
43+
{ LLM_ARCH_GEMMA2, "gemma2" },
44+
{ LLM_ARCH_GEMMA3, "gemma3" },
45+
{ LLM_ARCH_STARCODER2, "starcoder2" },
46+
{ LLM_ARCH_MAMBA, "mamba" },
47+
{ LLM_ARCH_MAMBA2, "mamba2" },
48+
{ LLM_ARCH_BAMBA, "bamba" },
49+
{ LLM_ARCH_XVERSE, "xverse" },
50+
{ LLM_ARCH_COMMAND_R, "command-r" },
51+
{ LLM_ARCH_COHERE2, "cohere2" },
52+
{ LLM_ARCH_DBRX, "dbrx" },
53+
{ LLM_ARCH_OLMO, "olmo" },
54+
{ LLM_ARCH_OLMO2, "olmo2" },
55+
{ LLM_ARCH_OLMOE, "olmoe" },
56+
{ LLM_ARCH_OPENELM, "openelm" },
57+
{ LLM_ARCH_ARCTIC, "arctic" },
58+
{ LLM_ARCH_DEEPSEEK, "deepseek" },
59+
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
60+
{ LLM_ARCH_CHATGLM, "chatglm" },
61+
{ LLM_ARCH_GLM4, "glm4" },
62+
{ LLM_ARCH_BITNET, "bitnet" },
63+
{ LLM_ARCH_T5, "t5" },
64+
{ LLM_ARCH_T5ENCODER, "t5encoder" },
65+
{ LLM_ARCH_JAIS, "jais" },
66+
{ LLM_ARCH_NEMOTRON, "nemotron" },
67+
{ LLM_ARCH_EXAONE, "exaone" },
68+
{ LLM_ARCH_RWKV6, "rwkv6" },
69+
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
70+
{ LLM_ARCH_RWKV7, "rwkv7" },
71+
{ LLM_ARCH_ARWKV7, "arwkv7" },
72+
{ LLM_ARCH_GRANITE, "granite" },
73+
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
74+
{ LLM_ARCH_GRANITE_MOE_HYBRID, "granitemoehybrid" },
75+
{ LLM_ARCH_CHAMELEON, "chameleon" },
76+
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
77+
{ LLM_ARCH_PLM, "plm" },
78+
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
79+
{ LLM_ARCH_DOTS1, "dots1" },
80+
{ LLM_ARCH_ARCEE, "arcee" },
81+
{ LLM_ARCH_UNKNOWN, "(unknown)" },
8182
};
8283

8384
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
@@ -1576,6 +1577,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
15761577
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
15771578
},
15781579
},
1580+
{
1581+
LLM_ARCH_GRANITE_MOE_HYBRID,
1582+
{
1583+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1584+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1585+
{ LLM_TENSOR_OUTPUT, "output" },
1586+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1587+
// mamba(2) ssm layers
1588+
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
1589+
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
1590+
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
1591+
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
1592+
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
1593+
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
1594+
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
1595+
// attention layers
1596+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1597+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1598+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1599+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1600+
// moe FFN
1601+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1602+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
1603+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
1604+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
1605+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
1606+
// shared expert
1607+
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
1608+
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
1609+
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
1610+
},
1611+
},
15791612
{
15801613
LLM_ARCH_CHAMELEON,
15811614
{
@@ -1889,6 +1922,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
18891922
// the place to identify them
18901923
switch (arch) {
18911924
case LLM_ARCH_BAMBA:
1925+
case LLM_ARCH_GRANITE_MOE_HYBRID:
18921926
return true;
18931927
default:
18941928
return false;

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ enum llm_arch {
7575
LLM_ARCH_ARWKV7,
7676
LLM_ARCH_GRANITE,
7777
LLM_ARCH_GRANITE_MOE,
78+
LLM_ARCH_GRANITE_MOE_HYBRID,
7879
LLM_ARCH_CHAMELEON,
7980
LLM_ARCH_WAVTOKENIZER_DEC,
8081
LLM_ARCH_PLM,

0 commit comments

Comments
 (0)