
Commit 3219f58

feat(granitemoe): Implement granitemoe
GraniteMoE follows the Mixtral architecture (once the input_linear layers are split into gate_exps/up_exps). The main delta is the addition of the same four multipliers used in Granite.

Branch: GraniteMoE
Signed-off-by: Gabe Goodhart <[email protected]>
Parent: edd2885
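
As background for the "split" mentioned in the commit message: the fused MoE input_linear weight can be pictured as two stacked halves along its output dimension, one becoming ffn_gate_exps and the other ffn_up_exps. A minimal C++ sketch of that idea, assuming a row-major [2*n_ff, n_embd] per-expert layout with the gate half stored first (the layout and the helper name are illustrative assumptions, not code from this commit):

#include <cstddef>
#include <vector>

// Hypothetical illustration (not the actual conversion code): split one
// expert's fused input_linear weight, stored row-major as [2*n_ff, n_embd],
// into the separate halves that back the ffn_gate_exps / ffn_up_exps tensors
// named in this commit. Which half is "gate" and which is "up" is an
// assumption here; the real split happens in the GGUF conversion step.
static void split_input_linear(const std::vector<float> & fused,
                               std::size_t n_ff, std::size_t n_embd,
                               std::vector<float> & gate_exps,
                               std::vector<float> & up_exps) {
    const std::size_t half = n_ff * n_embd;
    gate_exps.assign(fused.begin(),        fused.begin() + half);
    up_exps.assign  (fused.begin() + half, fused.end());
}

int main() {
    const std::size_t n_ff = 4, n_embd = 3;
    std::vector<float> fused(2 * n_ff * n_embd, 1.0f); // toy fused weight
    std::vector<float> gate_exps, up_exps;
    split_input_linear(fused, n_ff, n_embd, gate_exps, up_exps);
    return (gate_exps.size() == n_ff * n_embd && up_exps.size() == n_ff * n_embd) ? 0 : 1;
}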

src/llama.cpp

Lines changed: 27 additions & 4 deletions
@@ -213,6 +213,7 @@ enum llm_arch {
     LLM_ARCH_EXAONE,
     LLM_ARCH_RWKV6,
     LLM_ARCH_GRANITE,
+    LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -262,6 +263,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_EXAONE, "exaone" },
     { LLM_ARCH_RWKV6, "rwkv6" },
     { LLM_ARCH_GRANITE, "granite" },
+    { LLM_ARCH_GRANITE_MOE, "granitemoe" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -1431,6 +1433,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_GRANITE_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2344,7 +2363,7 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale = 0.0f;
 
-    // For Granite architecture
+    // For Granite architectures
     float f_residual_multiplier = 0.0f;
     float f_embedding_multiplier = 0.0f;
     float f_attention_multiplier = 0.0f;
@@ -5385,6 +5404,7 @@ static void llm_load_hparams(
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -5408,8 +5428,8 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             }
-            // Extra multipliers for Granite architecture
-            if (model.arch == LLM_ARCH_GRANITE) {
+            // Extra multipliers for Granite architectures
+            if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
                 ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
                 ml.get_key(LLM_KV_RESIDUAL_MULTIPLIER, hparams.f_residual_multiplier);
                 ml.get_key(LLM_KV_EMBEDDING_MULTIPLIER, hparams.f_embedding_multiplier);
@@ -6685,7 +6705,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
     }
 
-    if (model.arch == LLM_ARCH_GRANITE) {
+    if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
         LLAMA_LOG_INFO("%s: f_embedding_multiplier = %f\n", __func__, hparams.f_embedding_multiplier);
         LLAMA_LOG_INFO("%s: f_residual_multiplier = %f\n", __func__, hparams.f_residual_multiplier);
         LLAMA_LOG_INFO("%s: f_attention_multiplier = %f\n", __func__, hparams.f_attention_multiplier);
@@ -6861,6 +6881,7 @@ static bool llm_load_tensors(
         case LLM_ARCH_REFACT:
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
@@ -15362,6 +15383,7 @@ static struct ggml_cgraph * llama_build_graph(
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
             {
                 result = llm.build_llama();
             } break;
@@ -18649,6 +18671,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_CHATGLM:
         case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
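
For orientation, the four multipliers loaded in llm_load_hparams above (f_embedding_multiplier, f_attention_multiplier, f_residual_multiplier, f_logit_scale) are the "four multipliers used in Granite" from the commit message; since GraniteMoE routes through build_llama(), they presumably act the same way as in the dense Granite path. A schematic C++ sketch of that assumed placement, using plain float vectors rather than the actual graph code:

#include <cstddef>
#include <vector>

// Schematic sketch only (plain vectors, not ggml tensors): where the four
// Granite multipliers read in llm_load_hparams are assumed to act in the
// shared build_llama() path. Values below are toy numbers, not real hparams.
struct granite_scales {
    float f_embedding_multiplier; // scales the token embeddings
    float f_attention_multiplier; // assumed to replace 1/sqrt(n_embd_head) as the KQ scale
    float f_residual_multiplier;  // scales each branch before the residual add
    float f_logit_scale;          // divides the final logits
};

static void scale_inplace(std::vector<float> & v, float s) {
    for (float & x : v) x *= s;
}

int main() {
    const granite_scales hp = {12.0f, 0.015f, 0.22f, 8.0f}; // toy values
    std::vector<float> hidden(8, 1.0f), branch_out(8, 1.0f), logits(8, 1.0f);

    scale_inplace(hidden, hp.f_embedding_multiplier);    // after token embedding lookup
    const float kq_scale = hp.f_attention_multiplier;    // passed as the attention score scale
    (void) kq_scale;
    scale_inplace(branch_out, hp.f_residual_multiplier); // attention/FFN output before the add
    for (std::size_t i = 0; i < hidden.size(); ++i) {
        hidden[i] += branch_out[i];                      // residual connection
    }
    scale_inplace(logits, 1.0f / hp.f_logit_scale);      // final logit scaling
    return 0;
}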
