
Commit 0ad970b

2015arorasarthw authored and committed
llama : support OLMoE (ggml-org#9462)
1 parent 4ef14f0 commit 0ad970b

File tree

5 files changed: +298 −15 lines changed


README.md

Lines changed: 1 addition & 0 deletions

@@ -111,6 +111,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [SEA-LION](https://huggingface.co/models?search=sea-lion)
 - [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B)
 - [x] [OLMo](https://allenai.org/olmo)
+- [x] [OLMoE](https://huggingface.co/allenai/OLMoE-1B-7B-0924)
 - [x] [Granite models](https://huggingface.co/collections/ibm-granite/granite-code-models-6624c5cec322e4c148c8b330)
 - [x] [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) + [Pythia](https://github.com/EleutherAI/pythia)
 - [x] [Snowflake-Arctic MoE](https://huggingface.co/collections/Snowflake/arctic-66290090abe542894a5ac520)

convert_hf_to_gguf.py

Lines changed: 60 additions & 0 deletions

@@ -2998,6 +2998,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]


+@Model.register("OlmoeForCausalLM")
+class OlmoeModel(Model):
+    model_arch = gguf.MODEL_ARCH.OLMOE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_layer_norm_rms_eps(1e-5)
+        if (n_experts := self.hparams.get("num_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    # Copied from: Qwen2MoeModel
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("experts") != -1:
+            n_experts = self.hparams["num_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    # Copied from: Qwen2MoeModel
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @Model.register("JinaBertModel", "JinaBertForMaskedLM")
 class JinaBertV2Model(BertModel):
     model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
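For reference, the expert-merging step in OlmoeModel.modify_tensors above (copied from Qwen2MoeModel) just stacks the per-expert 2D projection matrices into one 3D tensor per projection before mapping the name. A minimal standalone sketch, using made-up shapes (n_experts=4, hidden=8, ffn=16) rather than the real OLMoE dimensions:

```python
import torch

n_experts, hidden, ffn = 4, 8, 16   # toy sizes, not the real OLMoE config
bid = 0                             # block (layer) index, as in modify_tensors

# Per-expert up_proj weights as they arrive from the HF checkpoint: one [ffn, hidden] matrix each.
experts = {
    f"model.layers.{bid}.mlp.experts.{xid}.up_proj.weight": torch.randn(ffn, hidden)
    for xid in range(n_experts)
}

# Stack along a new leading dimension, like torch.stack(datas, dim=0) above:
# the result is a single [n_experts, ffn, hidden] tensor, which then gets the
# merged name "model.layers.0.mlp.experts.up_proj.weight" and is mapped to GGUF.
datas = [
    experts[f"model.layers.{bid}.mlp.experts.{xid}.up_proj.weight"]
    for xid in range(n_experts)
]
merged = torch.stack(datas, dim=0)
print(merged.shape)  # torch.Size([4, 16, 8])
```

The same stacking is applied to gate_proj and down_proj, so the converter writes one merged tensor per projection per layer instead of one per expert.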

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions

@@ -220,6 +220,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R = auto()
     DBRX = auto()
     OLMO = auto()
+    OLMOE = auto()
     OPENELM = auto()
     ARCTIC = auto()
     DEEPSEEK2 = auto()

@@ -375,6 +376,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
     MODEL_ARCH.OLMO: "olmo",
+    MODEL_ARCH.OLMOE: "olmoe",
     MODEL_ARCH.OPENELM: "openelm",
     MODEL_ARCH.ARCTIC: "arctic",
     MODEL_ARCH.DEEPSEEK2: "deepseek2",

@@ -1027,6 +1029,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.OLMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+    ],
     MODEL_ARCH.OPENELM: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
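The new enum entry and the MODEL_TENSORS list above are what the rest of gguf-py keys off. A quick read-back sketch (assuming the gguf-py package from this repo is importable; the base names printed come from the existing TENSOR_NAMES table and may differ between versions):

```python
import gguf

# Architecture string written into the GGUF header for this model family.
print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.OLMOE])  # "olmoe"

# Logical tensors the OLMOE graph declares, with their GGUF base names.
for t in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.OLMOE]:
    print(t.name, "->", gguf.TENSOR_NAMES[t])  # e.g. FFN_UP_EXP -> blk.{bid}.ffn_up_exps
```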

gguf-py/gguf/tensor_mapping.py

Lines changed: 15 additions & 15 deletions

@@ -13,7 +13,7 @@ class TensorNameMap:
         "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
         "transformer.word_embeddings", # falcon
         "word_embeddings", # bloom
-        "model.embed_tokens", # llama-hf nemotron
+        "model.embed_tokens", # llama-hf nemotron olmoe
         "tok_embeddings", # llama-pth
         "embeddings.word_embeddings", # bert nomic-bert
         "language_model.embedding.word_embeddings", # persimmon

@@ -54,7 +54,7 @@ class TensorNameMap:
     # Output
     MODEL_TENSOR.OUTPUT: (
         "embed_out", # gptneox
-        "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
+        "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
         "output", # llama-pth bloom internlm2
         "word_embeddings_for_head", # persimmon
         "lm_head.linear", # phi2

@@ -66,7 +66,7 @@ class TensorNameMap:
     MODEL_TENSOR.OUTPUT_NORM: (
         "gpt_neox.final_layer_norm", # gptneox
         "transformer.ln_f", # gpt2 gpt-j falcon jais exaone
-        "model.norm", # llama-hf baichuan internlm2
+        "model.norm", # llama-hf baichuan internlm2 olmoe
         "norm", # llama-pth
         "transformer.norm_f", # mpt dbrx
         "ln_f", # refact bloom qwen gpt2

@@ -98,7 +98,7 @@ class TensorNameMap:
         "transformer.h.{bid}.input_layernorm", # falcon7b
         "h.{bid}.input_layernorm", # bloom
         "transformer.h.{bid}.ln_mlp", # falcon40b
-        "model.layers.{bid}.input_layernorm", # llama-hf nemotron
+        "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
         "layers.{bid}.attention_norm", # llama-pth
         "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
         "model.layers.{bid}.ln1", # yi

@@ -142,7 +142,7 @@ class TensorNameMap:

     # Attention query
     MODEL_TENSOR.ATTN_Q: (
-        "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron
+        "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe
         "layers.{bid}.attention.wq", # llama-pth
         "encoder.layer.{bid}.attention.self.query", # bert
         "transformer.h.{bid}.attn.q_proj", # gpt-j

@@ -154,7 +154,7 @@ class TensorNameMap:

     # Attention key
     MODEL_TENSOR.ATTN_K: (
-        "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron
+        "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe
         "layers.{bid}.attention.wk", # llama-pth
         "encoder.layer.{bid}.attention.self.key", # bert
         "transformer.h.{bid}.attn.k_proj", # gpt-j

@@ -167,7 +167,7 @@ class TensorNameMap:

     # Attention value
     MODEL_TENSOR.ATTN_V: (
-        "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron
+        "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe
         "layers.{bid}.attention.wv", # llama-pth
         "encoder.layer.{bid}.attention.self.value", # bert
         "transformer.h.{bid}.attn.v_proj", # gpt-j

@@ -185,7 +185,7 @@ class TensorNameMap:
         "transformer.blocks.{bid}.attn.out_proj", # mpt
         "transformer.h.{bid}.self_attention.dense", # falcon
         "h.{bid}.self_attention.dense", # bloom
-        "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron
+        "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe
         "layers.{bid}.attention.wo", # llama-pth
         "encoder.layer.{bid}.attention.output.dense", # bert
         "transformer.h.{bid}.attn.out_proj", # gpt-j

@@ -229,7 +229,7 @@ class TensorNameMap:
         "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
         "h.{bid}.post_attention_layernorm", # bloom
         "transformer.blocks.{bid}.norm_2", # mpt
-        "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron
+        "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
         "layers.{bid}.ffn_norm", # llama-pth
         "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
         "model.layers.{bid}.ln2", # yi

@@ -253,7 +253,7 @@ class TensorNameMap:
     MODEL_TENSOR.FFN_GATE_INP: (
         "layers.{bid}.feed_forward.gate", # mixtral
         "model.layers.{bid}.block_sparse_moe.gate", # mixtral
-        "model.layers.{bid}.mlp.gate", # qwen2moe
+        "model.layers.{bid}.mlp.gate", # qwen2moe olmoe
         "transformer.decoder_layer.{bid}.router", # Grok
         "transformer.blocks.{bid}.ffn.router.layer", # dbrx
     ),

@@ -295,7 +295,7 @@ class TensorNameMap:
         "layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
         "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
         "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
-        "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe (merged)
+        "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
     ),

     MODEL_TENSOR.FFN_UP_SHEXP: (

@@ -327,7 +327,7 @@ class TensorNameMap:
         "layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
         "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
         "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
-        "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe (merged)
+        "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
     ),

     MODEL_TENSOR.FFN_GATE_SHEXP: (

@@ -367,7 +367,7 @@ class TensorNameMap:
         "layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
         "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
         "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
-        "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe (merged)
+        "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
     ),

     MODEL_TENSOR.FFN_DOWN_SHEXP: (

@@ -378,7 +378,7 @@ class TensorNameMap:
     MODEL_TENSOR.ATTN_Q_NORM: (
         "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
         "model.layers.{bid}.self_attn.q_layernorm", # persimmon
-        "model.layers.{bid}.self_attn.q_norm", # cohere
+        "model.layers.{bid}.self_attn.q_norm", # cohere olmoe
         "transformer.blocks.{bid}.attn.q_ln", # sea-lion
         "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
         "transformer.layers.{bid}.attn.q_norm", # openelm

@@ -387,7 +387,7 @@ class TensorNameMap:
     MODEL_TENSOR.ATTN_K_NORM: (
         "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
         "model.layers.{bid}.self_attn.k_layernorm", # persimmon
-        "model.layers.{bid}.self_attn.k_norm", # cohere
+        "model.layers.{bid}.self_attn.k_norm", # cohere olmoe
         "transformer.blocks.{bid}.attn.k_ln", # sea-lion
         "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
         "transformer.layers.{bid}.attn.k_norm", # openelm
