@@ -814,7 +814,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             int   il) const {
     const int64_t n_embd   = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
-    const bool scale_before_ffn = arch == LLM_ARCH_LLAMA4;
+    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

     ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
@@ -875,13 +875,16 @@ ggml_tensor * llm_graph_context::build_moe_ffn(

     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);

+    if (weight_before_ffn) {
+        ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
+        repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+        cur = ggml_mul(ctx0, repeated, weights);
+        cb(cur, "ffn_moe_weighted", il);
+    }
+
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);

-    if (scale_before_ffn) {
-        up = ggml_mul(ctx0, up, weights);
-    }
-
     ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(gate, "ffn_moe_gate", il);
@@ -906,7 +909,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);

-    if (!scale_before_ffn) {
+    if (!weight_before_ffn) {
         experts = ggml_mul(ctx0, experts, weights);
     }
0 commit comments