Commit ec087e9

feat: Support GRANITE_MOE_HYBRID in llama-model
This re-uses the Bamba code paths heavily and simply adds the missing parts for loading MoE and the shared expert.

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: f483313


src/llama-model.cpp

Lines changed: 40 additions & 9 deletions
@@ -1473,6 +1473,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
             } break;
         case LLM_ARCH_BAMBA:
+        case LLM_ARCH_GRANITE_MOE_HYBRID:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false);
@@ -1514,6 +1515,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     // TODO: Add llm type label (not sure this is useful)
                     default: type = LLM_TYPE_UNKNOWN;
                 }
+
+                // For Granite MoE Shared
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
             } break;
         case LLM_ARCH_CHAMELEON:
             {
@@ -3167,6 +3171,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 } break;
             case LLM_ARCH_BAMBA:
+            case LLM_ARCH_GRANITE_MOE_HYBRID:
                 {
                     // mamba2 Mixer SSM params
                     // NOTE: int64_t for tensor dimensions
@@ -3233,14 +3238,31 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }

                     // feed forward (w/ optional biases)
-                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                    layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
-                    layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
-                    layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    if (n_expert > 0) {
+                        // MoE FFN
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+
+                        // For Granite MoE Shared
+                        if (hparams.n_ff_shexp > 0) {
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                            layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
+                        }
+                    } else {
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    }
                 }
             } break;
         case LLM_ARCH_XVERSE:
@@ -4781,7 +4803,9 @@ void llama_model::print_info() const {

     if (arch == LLM_ARCH_MINICPM ||
         arch == LLM_ARCH_GRANITE ||
-        arch == LLM_ARCH_GRANITE_MOE) {
+        arch == LLM_ARCH_GRANITE_MOE ||
+        arch == LLM_ARCH_GRANITE_MOE_HYBRID ||
+        arch == LLM_ARCH_BAMBA) {
         LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
         LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
@@ -14577,6 +14601,12 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_granite>(*this, params, gf);
             } break;
+        case LLM_ARCH_GRANITE_MOE_HYBRID:
+            {
+                llm = std::make_unique<llm_build_hybrid_mamba>(*this, params, gf,
+                    /* use_mamba2 */ true,
+                    /* use_rope */ false);
+            } break;
         case LLM_ARCH_BAMBA:
             {
                 llm = std::make_unique<llm_build_hybrid_mamba>(
@@ -14756,6 +14786,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GLM4:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
+        case LLM_ARCH_GRANITE_MOE_HYBRID:
         case LLM_ARCH_BAMBA:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
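
Note on where these tensors get used: this commit only adds hyperparameter and tensor loading; at graph-build time LLM_ARCH_GRANITE_MOE_HYBRID is routed through the existing llm_build_hybrid_mamba builder (see the build_graph hunk above). The sketch below is not code from this commit: it is a minimal illustration, using plain ggml ops, of how shared-expert tensors like the ffn_gate_shexp / ffn_up_shexp / ffn_down_shexp weights loaded here are typically combined with the routed MoE output. The helper name build_shared_expert_ffn and its signature are assumptions made for illustration only.

// Illustrative sketch (not from this commit): combining a Granite-style shared
// expert with the routed-MoE output at graph-build time. Only the ggml ops and
// the tensor shapes follow the create_tensor calls in the diff above; the
// function name and parameter list are hypothetical.
#include "ggml.h"

static struct ggml_tensor * build_shared_expert_ffn(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,             // layer input,          {n_embd, n_tokens}
        struct ggml_tensor  * moe_out,         // routed-expert output, {n_embd, n_tokens}
        struct ggml_tensor  * ffn_up_shexp,    // {n_embd, n_ff_shexp}
        struct ggml_tensor  * ffn_gate_shexp,  // {n_embd, n_ff_shexp}
        struct ggml_tensor  * ffn_down_shexp)  // {n_ff_shexp, n_embd}
{
    // SwiGLU over the shared expert: silu(gate(x)) * up(x), projected back down
    struct ggml_tensor * up   = ggml_mul_mat(ctx, ffn_up_shexp,   cur);  // {n_ff_shexp, n_tokens}
    struct ggml_tensor * gate = ggml_mul_mat(ctx, ffn_gate_shexp, cur);  // {n_ff_shexp, n_tokens}
    struct ggml_tensor * act  = ggml_mul(ctx, ggml_silu(ctx, gate), up);
    struct ggml_tensor * out  = ggml_mul_mat(ctx, ffn_down_shexp, act);  // {n_embd, n_tokens}

    // the shared expert runs for every token; its output is simply summed
    // with the routed-expert result
    return ggml_add(ctx, out, moe_out);
}

The shapes follow the create_tensor calls above: the up/gate projections take {n_embd, n_tokens} activations to {n_ff_shexp, n_tokens}, and ffn_down_shexp maps back to {n_embd, n_tokens}, so the result can be added directly to the routed-expert output. The actual combination in llama.cpp lives in the graph builders this commit reuses, not in a new helper.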
