Skip to content

Commit 7fbb7a1

Browse files
committed
fix: QWEN2MOE support for expert_feed_forward_length
previously, expert ff was taken from n_ff (intermediate size) but it is now properly taken from LLM_KV_EXPERT_FEED_FORWARD_LENGTH n_ff_exp and n_ff_shexp are now properly calculated
1 parent 06531cb commit 7fbb7a1

File tree

4 files changed

+47
-34
lines changed

4 files changed

+47
-34
lines changed

convert-hf-to-gguf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,6 +1630,9 @@ def set_gguf_parameters(self):
16301630
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
16311631
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
16321632
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
1633+
if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None:
1634+
self.gguf_writer.add_shared_expert_feed_forward_length(shared_expert_intermediate_size)
1635+
logger.info(f"gguf: shared expert feed forward length = {shared_expert_intermediate_size}")
16331636

16341637
_experts: list[dict[str, Tensor]] | None = None
16351638

gguf-py/gguf/constants.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,22 @@ class General:
3333
FILE_TYPE = "general.file_type"
3434

3535
class LLM:
36-
VOCAB_SIZE = "{arch}.vocab_size"
37-
CONTEXT_LENGTH = "{arch}.context_length"
38-
EMBEDDING_LENGTH = "{arch}.embedding_length"
39-
BLOCK_COUNT = "{arch}.block_count"
40-
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
41-
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
42-
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
43-
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
44-
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
45-
EXPERT_COUNT = "{arch}.expert_count"
46-
EXPERT_USED_COUNT = "{arch}.expert_used_count"
47-
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
48-
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
49-
POOLING_TYPE = "{arch}.pooling_type"
50-
LOGIT_SCALE = "{arch}.logit_scale"
36+
VOCAB_SIZE = "{arch}.vocab_size"
37+
CONTEXT_LENGTH = "{arch}.context_length"
38+
EMBEDDING_LENGTH = "{arch}.embedding_length"
39+
BLOCK_COUNT = "{arch}.block_count"
40+
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
41+
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
42+
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
43+
SHARED_EXPERT_FEED_FORWARD_LENGTH = "{arch}.shared_expert_feed_forward_length"
44+
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
45+
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
46+
EXPERT_COUNT = "{arch}.expert_count"
47+
EXPERT_USED_COUNT = "{arch}.expert_used_count"
48+
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
49+
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
50+
POOLING_TYPE = "{arch}.pooling_type"
51+
LOGIT_SCALE = "{arch}.logit_scale"
5152

5253
class Attention:
5354
HEAD_COUNT = "{arch}.attention.head_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,9 @@ def add_feed_forward_length(self, length: int) -> None:
383383
def add_expert_feed_forward_length(self, length: int) -> None:
384384
self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
385385

386+
def add_shared_expert_feed_forward_length(self, length: int) -> None:
387+
self.add_uint32(Keys.LLM.SHARED_EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
388+
386389
def add_parallel_residual(self, use: bool) -> None:
387390
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
388391

llama.cpp

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ enum llm_kv {
282282
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
283283
LLM_KV_FEED_FORWARD_LENGTH,
284284
LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
285+
LLM_KV_SHARED_EXPERT_FEED_FORWARD_LENGTH,
285286
LLM_KV_USE_PARALLEL_RESIDUAL,
286287
LLM_KV_TENSOR_DATA_LAYOUT,
287288
LLM_KV_EXPERT_COUNT,
@@ -360,21 +361,22 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
360361
{ LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
361362
{ LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
362363

363-
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
364-
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
365-
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
366-
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
367-
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
368-
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
369-
{ LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
370-
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
371-
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
372-
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
373-
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
374-
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
375-
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
376-
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
377-
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
364+
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
365+
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
366+
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
367+
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
368+
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
369+
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
370+
{ LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
371+
{ LLM_KV_SHARED_EXPERT_FEED_FORWARD_LENGTH, "%s.shared_expert_feed_forward_length" },
372+
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
373+
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
374+
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
375+
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
376+
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
377+
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
378+
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
379+
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
378380

379381
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
380382
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -1840,6 +1842,7 @@ struct llama_hparams {
18401842
uint32_t n_lora_q = 0;
18411843
uint32_t n_lora_kv = 0;
18421844
uint32_t n_ff_exp = 0;
1845+
uint32_t n_ff_shexp = 0;
18431846
uint32_t n_expert_shared = 0;
18441847
float expert_weights_scale = 0.0;
18451848

@@ -1888,6 +1891,7 @@ struct llama_hparams {
18881891
if (this->n_lora_q != other.n_lora_q) return true;
18891892
if (this->n_lora_kv != other.n_lora_kv) return true;
18901893
if (this->n_ff_exp != other.n_ff_exp) return true;
1894+
if (this->n_ff_shexp != other.n_ff_shexp) return true;
18911895
if (this->n_expert_shared != other.n_expert_shared) return true;
18921896

18931897
if (this->rope_finetuned != other.rope_finetuned) return true;
@@ -4248,6 +4252,7 @@ static void llm_load_hparams(
42484252
case LLM_ARCH_QWEN2MOE:
42494253
{
42504254
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
4255+
ml.get_key(LLM_KV_SHARED_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
42514256

42524257
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
42534258
switch (hparams.n_layer) {
@@ -5024,6 +5029,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
50245029

50255030
if (model.arch == LLM_ARCH_QWEN2MOE) {
50265031
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
5032+
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
50275033
}
50285034
}
50295035

@@ -5817,11 +5823,11 @@ static bool llm_load_tensors(
58175823
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
58185824

58195825
// Shared expert branch
5820-
auto n_ff_shared_exp = hparams.n_ff_exp && hparams.n_expert_used ? hparams.n_ff_exp * hparams.n_expert_used : n_ff;
5826+
auto n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
58215827
layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
5822-
layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shared_exp});
5823-
layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shared_exp, n_embd});
5824-
layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shared_exp});
5828+
layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp});
5829+
layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd});
5830+
layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp});
58255831
}
58265832
} break;
58275833
case LLM_ARCH_PHI2:

0 commit comments

Comments
 (0)