Skip to content

Commit ee06e9b

Browse files
committed
weight_before_ffn
1 parent 2a9b29a commit ee06e9b

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/llama-graph.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
814814
int il) const {
815815
const int64_t n_embd = cur->ne[0];
816816
const int64_t n_tokens = cur->ne[1];
817-
const bool scale_before_ffn = arch == LLM_ARCH_LLAMA4;
817+
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
818818

819819
ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
820820
cb(logits, "ffn_moe_logits", il);
@@ -875,13 +875,16 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
875875

876876
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
877877

878+
if (weight_before_ffn) {
879+
ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
880+
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
881+
cur = ggml_mul(ctx0, repeated, weights);
882+
cb(cur, "ffn_moe_weighted", il);
883+
}
884+
878885
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
879886
cb(up, "ffn_moe_up", il);
880887

881-
if (scale_before_ffn) {
882-
up = ggml_mul(ctx0, up, weights);
883-
}
884-
885888
ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
886889
cb(gate, "ffn_moe_gate", il);
887890

@@ -906,7 +909,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
906909
ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
907910
cb(experts, "ffn_moe_down", il);
908911

909-
if (!scale_before_ffn) {
912+
if (!weight_before_ffn) {
910913
experts = ggml_mul(ctx0, experts, weights);
911914
}
912915

0 commit comments

Comments
 (0)