Commit 1b235d0

feat(granitemoe): Implement granitemoe
GraniteMoE follows the Mixtral architecture (once the input_linear layers are split into gate_exps/up_exps). The main delta is the addition of the same four multipliers used in Granite.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent df08e22 commit 1b235d0
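
As an illustration of the conversion step the commit message refers to, here is a minimal, self-contained C++ sketch of splitting a fused per-expert input_linear weight (stored row-major as [n_expert, 2*n_ff, n_embd]) into gate_exps and up_exps halves. The function name, the shape convention, and the assumption that the gate rows come first are illustrative guesses, not code from this commit or from the converter.

#include <cstddef>
#include <vector>

// Hypothetical helper: split a fused input_linear weight, stored row-major as
// [n_expert, 2*n_ff, n_embd], into gate_exps and up_exps, each [n_expert, n_ff, n_embd].
// Assumption (not taken from the converter): the gate rows come first in each expert.
struct split_result {
    std::vector<float> gate_exps;
    std::vector<float> up_exps;
};

static split_result split_input_linear(const std::vector<float> & fused,
                                       size_t n_expert, size_t n_ff, size_t n_embd) {
    split_result out;
    out.gate_exps.reserve(n_expert * n_ff * n_embd);
    out.up_exps.reserve(n_expert * n_ff * n_embd);
    for (size_t e = 0; e < n_expert; ++e) {
        const float * base = fused.data() + e * 2 * n_ff * n_embd;
        // first n_ff rows -> gate half, next n_ff rows -> up half (assumed order)
        out.gate_exps.insert(out.gate_exps.end(), base, base + n_ff * n_embd);
        out.up_exps.insert(out.up_exps.end(), base + n_ff * n_embd, base + 2 * n_ff * n_embd);
    }
    return out;
}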


src/llama.cpp

Lines changed: 26 additions & 2 deletions
@@ -215,6 +215,7 @@ enum llm_arch {
     LLM_ARCH_EXAONE,
     LLM_ARCH_RWKV6,
     LLM_ARCH_GRANITE,
+    LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_UNKNOWN,
 };

@@ -266,6 +267,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_EXAONE,      "exaone"     },
     { LLM_ARCH_RWKV6,       "rwkv6"      },
     { LLM_ARCH_GRANITE,     "granite"    },
+    { LLM_ARCH_GRANITE_MOE, "granitemoe" },
     { LLM_ARCH_UNKNOWN,     "(unknown)"  },
 };

@@ -1478,6 +1480,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
         { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
     },
 },
+{
+    LLM_ARCH_GRANITE_MOE,
+    {
+        { LLM_TENSOR_TOKEN_EMBD,    "token_embd" },
+        { LLM_TENSOR_OUTPUT_NORM,   "output_norm" },
+        { LLM_TENSOR_ATTN_NORM,     "blk.%d.attn_norm" },
+        { LLM_TENSOR_ATTN_Q,        "blk.%d.attn_q" },
+        { LLM_TENSOR_ATTN_K,        "blk.%d.attn_k" },
+        { LLM_TENSOR_ATTN_V,        "blk.%d.attn_v" },
+        { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
+        { LLM_TENSOR_FFN_NORM,      "blk.%d.ffn_norm" },
+        { LLM_TENSOR_FFN_GATE_INP,  "blk.%d.ffn_gate_inp" },
+        { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+        { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+        { LLM_TENSOR_FFN_UP_EXPS,   "blk.%d.ffn_up_exps" },
+    },
+},
 {
     LLM_ARCH_UNKNOWN,
     {
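
A note on the "%d" placeholders above: they are filled with the block index when a tensor name is resolved. The following standalone sketch (using plain snprintf rather than llama.cpp's own name helper, and with an assumed ".weight" suffix) shows the intended expansion:

#include <cstdio>
#include <string>

// Illustrative only: expand a per-block tensor name pattern such as
// "blk.%d.ffn_gate_exps" for a given layer index and append a suffix.
static std::string blk_tensor_name(const char * pattern, int i_layer, const char * suffix = "weight") {
    char buf[256];
    std::snprintf(buf, sizeof(buf), pattern, i_layer);
    return std::string(buf) + "." + suffix;
}

// e.g. blk_tensor_name("blk.%d.ffn_gate_exps", 7) yields "blk.7.ffn_gate_exps.weight"
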
@@ -2396,7 +2415,7 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale    = 0.0f;

-    // Additional scale factors (Granite)
+    // Additional scale factors (Granite/Granite MoE)
     float f_residual_scale  = 0.0f;
     float f_embedding_scale = 0.0f;
     float f_attention_scale = 0.0f;
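
For orientation, the following sketch (plain floats standing in for tensors; based on how the dense Granite path is described, not on the actual graph-building code) shows where each of the four multipliers is meant to act:

// Illustrative sketch only: where the four Granite/Granite MoE multipliers
// enter a decoder forward pass, with plain floats in place of tensors.
struct granite_scales {
    float f_embedding_scale; // multiplies the token embeddings
    float f_attention_scale; // used in place of 1/sqrt(head_dim) on the QK^T scores
    float f_residual_scale;  // multiplies each branch output before the residual add
    float f_logit_scale;     // the final logits are divided by this
};

// embd: one embedding value, q_dot_k: one raw attention score,
// branch_out: one attention/FFN branch output, logit: one output logit.
float scale_embedding(float embd, const granite_scales & s)                  { return embd * s.f_embedding_scale; }
float scale_attn_score(float q_dot_k, const granite_scales & s)              { return q_dot_k * s.f_attention_scale; }
float add_residual(float x, float branch_out, const granite_scales & s)      { return x + s.f_residual_scale * branch_out; }
float scale_logits(float logit, const granite_scales & s)                    { return logit / s.f_logit_scale; }
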
@@ -6052,6 +6071,7 @@ static void llm_load_hparams(
                 }
             } break;
         case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_LOGIT_SCALE,                 hparams.f_logit_scale);
@@ -6060,6 +6080,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);

                 switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_3B; break;
                     case 40: model.type = e_model::MODEL_3B; break;
                     // Add additional layer/vocab/etc checks here for other model sizes
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -6764,7 +6785,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
     }

-    if (model.arch == LLM_ARCH_GRANITE) {
+    if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
         LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
         LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
@@ -6938,6 +6959,7 @@ static bool llm_load_tensors(
         case LLM_ARCH_REFACT:
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
            {
                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

@@ -15865,6 +15887,7 @@ static struct ggml_cgraph * llama_build_graph(
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
            {
                result = llm.build_llama();
            } break;
@@ -19162,6 +19185,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
        case LLM_ARCH_DEEPSEEK2:
        case LLM_ARCH_CHATGLM:
        case LLM_ARCH_GRANITE:
+       case LLM_ARCH_GRANITE_MOE:
            return LLAMA_ROPE_TYPE_NORM;

        // the pairs of head values are offset by n_rot/2
