
Commit 8e2566a

manyoso and cebtenzzre committed
Nomic Embed Text V2 with Mixture-of-Experts (MoE) architecture
- Adds MoE-based embedding model supporting multilingual embeddings.
- Selects architecture variant based on hyperparameter detection (MoE layers).
- Removes unnecessary subclass initialization checks for clarity.

https://www.nomic.ai/blog/posts/nomic-embed-text-v2

Co-authored-by: Jared Van Bortel <[email protected]>
1 parent ecda2ec commit 8e2566a

File tree

9 files changed: +150 -37 lines changed


convert_hf_to_gguf.py

Lines changed: 36 additions & 17 deletions
@@ -78,7 +78,7 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
@@ -454,13 +454,6 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type
 
 
 class TextModel(ModelBase):
-    @classmethod
-    def __init_subclass__(cls):
-        # can't use an abstract property, because overriding it without type errors
-        # would require using decorated functions instead of simply defining the property
-        if "model_arch" not in cls.__dict__:
-            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
-
     def set_vocab(self):
         self._set_vocab_gpt2()
 
@@ -3420,32 +3413,58 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
 @ModelBase.register("NomicBertModel")
 class NomicBertModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.NOMIC_BERT
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            hparams = ModelBase.load_hparams(dir_model)
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        self.is_moe = bool(hparams.get("moe_every_n_layers"))
+        self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT
+
+        super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
 
         # the HF config claims n_ctx=8192, but it uses RoPE scaling
        self.hparams["n_ctx"] = 2048
 
-        # SwigLU activation
-        assert self.hparams["activation_function"] == "swiglu"
+        assert self.hparams["activation_function"] == "gelu" if self.is_moe else "swiglu"
+
         # this doesn't do anything in the HF version
         assert self.hparams["causal"] is False
-        # no bias tensors
-        assert self.hparams["qkv_proj_bias"] is False
-        assert self.hparams["mlp_fc1_bias"] is False
-        assert self.hparams["mlp_fc2_bias"] is False
+        # no bias tensors unless MoE
+        assert self.hparams["qkv_proj_bias"] == self.is_moe
+        assert self.hparams["mlp_fc1_bias"] == self.is_moe
+        assert self.hparams["mlp_fc2_bias"] == self.is_moe
+
         # norm at end of layer
         assert self.hparams["prenorm"] is False
         # standard RoPE
         assert self.hparams["rotary_emb_fraction"] == 1.0
         assert self.hparams["rotary_emb_interleaved"] is False
         assert self.hparams["rotary_emb_scale_base"] is None
 
+    def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]:
+        # If the tensor is an experts bias tensor, skip it by returning an empty list.
+        if "mlp.experts.bias" in name:
+            return []  # Explicitly return an empty list.
+
+        if "mlp.experts.mlp.w1" in name:
+            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            name += ".weight"
+
+        if "mlp.experts.mlp.w2" in name:
+            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            data_torch = data_torch.transpose(1, 2)
+            name += ".weight"
+
+        return [(self.map_tensor_name(name), data_torch)]
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
+        if self.is_moe:
+            self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
+            self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+            self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
 
 
 @ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
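For readers who want to sanity-check the new modify_tensors logic: below is a standalone sketch (not part of the commit) of what the view/transpose does to the merged expert weights, using made-up hyperparameters (num_experts=8, n_embd=768, n_inner=3072) and random tensors in place of the real checkpoint.

import torch

# Hypothetical hyperparameters, for illustration only.
num_experts, n_embd, n_inner = 8, 768, 3072

# The HF checkpoint stores all experts of a layer merged into single 2D tensors
# (layout assumed here so that the view below is shape-compatible).
w1 = torch.randn(num_experts * n_inner, n_embd)  # encoder.layers.{bid}.mlp.experts.mlp.w1
w2 = torch.randn(num_experts * n_inner, n_embd)  # encoder.layers.{bid}.mlp.experts.mlp.w2

# ffn_up_exps: split the merged tensor into one [n_ff, n_embd] matrix per expert.
up_exps = w1.view(num_experts, n_inner, n_embd)

# ffn_down_exps: same split, then swap the last two dims so the per-expert
# matrix maps n_ff back to n_embd.
down_exps = w2.view(num_experts, n_inner, n_embd).transpose(1, 2)

print(up_exps.shape)    # torch.Size([8, 3072, 768])
print(down_exps.shape)  # torch.Size([8, 768, 3072])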

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions
@@ -104,6 +104,7 @@ class LLM:
         EXPERT_WEIGHTS_SCALE   = "{arch}.expert_weights_scale"
         EXPERT_WEIGHTS_NORM    = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC     = "{arch}.expert_gating_func"
+        MOE_EVERY_N_LAYERS     = "{arch}.moe_every_n_layers"
         POOLING_TYPE           = "{arch}.pooling_type"
         LOGIT_SCALE            = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
@@ -267,6 +268,7 @@ class MODEL_ARCH(IntEnum):
     REFACT         = auto()
     BERT           = auto()
     NOMIC_BERT     = auto()
+    NOMIC_BERT_MOE = auto()
     JINA_BERT_V2   = auto()
     BLOOM          = auto()
     STABLELM       = auto()
@@ -521,6 +523,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.REFACT:         "refact",
     MODEL_ARCH.BERT:           "bert",
     MODEL_ARCH.NOMIC_BERT:     "nomic-bert",
+    MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
     MODEL_ARCH.JINA_BERT_V2:   "jina-bert-v2",
     MODEL_ARCH.BLOOM:          "bloom",
     MODEL_ARCH.STABLELM:       "stablelm",
@@ -960,6 +963,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.NOMIC_BERT_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.JINA_BERT_V2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
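As a quick illustration (not from the diff), the new key template expands per architecture exactly like the neighbouring expert keys, so the metadata for this model ends up under the nomic-bert-moe prefix:

# Illustration of how the new key template expands; the template string and the
# architecture name both come from this commit.
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"

print(MOE_EVERY_N_LAYERS.format(arch="nomic-bert-moe"))
# -> nomic-bert-moe.moe_every_n_layers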

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -728,6 +728,9 @@ def add_expert_weights_norm(self, value: bool) -> None:
     def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
         self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
 
+    def add_moe_every_n_layers(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
+
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
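A hedged usage sketch of the new writer method, mirroring the set_gguf_parameters call in the converter above; the writer constructor signature is assumed to be the usual GGUFWriter(path, arch), and the numeric values are placeholders rather than the real model's hyperparameters.

import gguf

# Placeholder values; the converter reads the real ones from the HF config.
writer = gguf.GGUFWriter("nomic-embed-text-v2-moe.gguf", arch="nomic-bert-moe")
writer.add_moe_every_n_layers(2)  # writes nomic-bert-moe.moe_every_n_layers as uint32
writer.add_expert_count(8)        # {arch}.expert_count
writer.add_expert_used_count(2)   # {arch}.expert_used_count
# ... remaining metadata, tensors, and the actual file writes are omitted here.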

gguf-py/gguf/tensor_mapping.py

Lines changed: 4 additions & 0 deletions
@@ -290,6 +290,7 @@ class TensorNameMap:
            "transformer.blocks.{bid}.ffn.router.layer",              # dbrx
            "model.layers.{bid}.block_sparse_moe.router.layer",       # granitemoe
            "language_model.model.layers.{bid}.feed_forward.router",  # llama4
+           "encoder.layers.{bid}.mlp.router.layer",                  # nomic-bert-moe
        ),
 
        MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -322,6 +323,7 @@
            "model.layers.layers.{bid}.mlp.up_proj",    # plamo
            "model.layers.{bid}.feed_forward.w3",       # internlm2
            "encoder.layers.{bid}.mlp.fc11",            # nomic-bert
+           "encoder.layers.{bid}.mlp.fc1",             # nomic-bert-moe
            "model.layers.{bid}.mlp.c_fc",              # starcoder2
            "encoder.layer.{bid}.mlp.gated_layers_v",   # jina-bert-v2
            "model.layers.{bid}.residual_mlp.w3",       # arctic
@@ -337,6 +339,7 @@
            "model.layers.{bid}.mlp.experts.up_proj",                          # qwen2moe olmoe (merged)
            "model.layers.{bid}.block_sparse_moe.experts.w3",                  # phimoe (merged)
            "language_model.model.layers.{bid}.feed_forward.experts.up_proj",  # llama4
+           "encoder.layers.{bid}.mlp.experts.mlp.w1",                         # nomic-bert-moe
        ),
 
        MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -418,6 +421,7 @@
            "model.layers.{bid}.block_sparse_moe.output_linear",                 # granitemoe
            "model.layers.{bid}.block_sparse_moe.experts.w2",                    # phimoe (merged)
            "language_model.model.layers.{bid}.feed_forward.experts.down_proj",  # llama4
+           "encoder.layers.{bid}.mlp.experts.mlp.w2",                           # nomic-bert-moe
        ),
 
        MODEL_TENSOR.FFN_DOWN_SHEXP: (
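To see what these entries buy the converter, here is a small sketch of resolving the new HF tensor names through gguf-py's name map, assuming the usual get_tensor_name_map / TensorNameMap.get_name helpers behave as in the converter above (block count and layer index are arbitrary).

import gguf

# Build the per-architecture name map (12 blocks is arbitrary for the example).
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.NOMIC_BERT_MOE, 12)

# Merged expert tensors and the router map onto the MoE GGUF tensors.
print(tmap.get_name("encoder.layers.3.mlp.experts.mlp.w1.weight", try_suffixes=(".weight",)))
# -> blk.3.ffn_up_exps.weight
print(tmap.get_name("encoder.layers.3.mlp.experts.mlp.w2.weight", try_suffixes=(".weight",)))
# -> blk.3.ffn_down_exps.weight
print(tmap.get_name("encoder.layers.3.mlp.router.layer.weight", try_suffixes=(".weight",)))
# -> blk.3.ffn_gate_inp.weight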

src/llama-arch.cpp

Lines changed: 20 additions & 0 deletions
@@ -19,6 +19,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_REFACT,         "refact"         },
     { LLM_ARCH_BERT,           "bert"           },
     { LLM_ARCH_NOMIC_BERT,     "nomic-bert"     },
+    { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
     { LLM_ARCH_JINA_BERT_V2,   "jina-bert-v2"   },
     { LLM_ARCH_BLOOM,          "bloom"          },
     { LLM_ARCH_STABLELM,       "stablelm"       },
@@ -106,6 +107,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_SCALE,   "%s.expert_weights_scale"   },
     { LLM_KV_EXPERT_WEIGHTS_NORM,    "%s.expert_weights_norm"    },
     { LLM_KV_EXPERT_GATING_FUNC,     "%s.expert_gating_func"     },
+    { LLM_KV_MOE_EVERY_N_LAYERS,     "%s.moe_every_n_layers"     },
     { LLM_KV_POOLING_TYPE,           "%s.pooling_type"           },
     { LLM_KV_LOGIT_SCALE,            "%s.logit_scale"            },
     { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -472,6 +474,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_NOMIC_BERT_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_JINA_BERT_V2,
         {

src/llama-arch.h

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@ enum llm_arch {
     LLM_ARCH_REFACT,
     LLM_ARCH_BERT,
     LLM_ARCH_NOMIC_BERT,
+    LLM_ARCH_NOMIC_BERT_MOE,
     LLM_ARCH_JINA_BERT_V2,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
@@ -110,6 +111,7 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
+    LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,

src/llama-graph.cpp

Lines changed: 18 additions & 11 deletions
@@ -907,31 +907,38 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(cur, "ffn_moe_weighted", il);
     }
 
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
+    ggml_tensor * tmp = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(tmp, "ffn_moe_up", il);
 
-    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
+    ggml_tensor * experts = nullptr;
+    if (gate_exps) {
+        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate", il);
+    } else {
+        cur = tmp;
+    }
 
     switch (type_op) {
         case LLM_FFN_SILU:
             {
-                gate = ggml_silu(ctx0, gate);
-                cb(gate, "ffn_moe_silu", il);
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
             {
-                gate = ggml_gelu(ctx0, gate);
-                cb(gate, "ffn_moe_gelu", il);
+                cur = ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_moe_gelu", il);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
 
-    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
+    if (gate_exps) {
+        cur = ggml_mul(ctx0, cur, tmp); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }
 
-    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     if (!weight_before_ffn) {
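For intuition about the control-flow change (a simplified sketch, not a transcription of the C++): the up projection always runs, while the gate projection and elementwise product are skipped when no gate experts exist, in which case the activation is applied directly to the up projection. That is the shape of nomic-bert-moe's experts, which carry only up/down weights and use GELU. The NumPy sketch below illustrates this for a single token; router weighting is simplified.

import numpy as np

def gelu(x):
    # tanh approximation of GELU, for illustration only
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def moe_ffn_token(x, up_exps, down_exps, gate_exps, router_w, top_k):
    # x:         [n_embd]                  token activation
    # up_exps:   [n_expert, n_ff, n_embd]
    # down_exps: [n_expert, n_embd, n_ff]
    # gate_exps: [n_expert, n_ff, n_embd]  or None (nomic-bert-moe style)
    # router_w:  [n_expert, n_embd]
    logits = router_w @ x
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                      # softmax over experts
    selected = np.argsort(-probs)[:top_k]     # pick the top_k experts

    out = np.zeros_like(x)
    for e in selected:
        cur = up_exps[e] @ x                      # ffn_moe_up
        if gate_exps is not None:
            cur = gelu(gate_exps[e] @ x) * cur    # gated path: activation on the gate, times up
        else:
            cur = gelu(cur)                       # gate-less path: activation on up directly
        out += probs[e] * (down_exps[e] @ cur)    # ffn_moe_down, weighted by router prob
    return out

# Example call with random data, no gate experts (gate_exps=None), top_k=2:
# moe_ffn_token(np.random.randn(768), np.random.randn(8, 3072, 768),
#               np.random.randn(8, 768, 3072), None, np.random.randn(8, 768), 2)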

src/llama-hparams.h

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ struct llama_hparams {
     float    expert_weights_scale = 0.0;
     bool     expert_weights_norm  = false;
     uint32_t expert_gating_func   = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
+    uint32_t moe_every_n_layers   = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;
