
Commit edd2885

fix(granitemoe convert): Split the double-sized input layer into gate and up
After a lot of staring and squinting, it's clear that the standard mixtral expert implementation is equivalent to the vectorized parallel experts in granite. The difference is that in granite, w1 and w3 are concatenated into a single tensor, "input_linear". Rather than reimplementing all of the math on the llama.cpp side, the much simpler route is to split this tensor during conversion and follow the standard mixtral path.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>
1 parent cdabf89 commit edd2885
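
To make the equivalence concrete, here is a minimal sketch (not part of the commit; shapes are made up for illustration) showing that splitting the merged weight before the matmul gives the same gate/up activations as doing one big matmul and splitting the output, which is what the merged "input_linear" forward pass does:

import torch

hidden, intermediate = 8, 16
x = torch.randn(4, hidden)                            # a batch of token activations
input_linear = torch.randn(2 * intermediate, hidden)  # merged gate+up weight for one expert

# Merged path (granite/JetMoe style): one matmul, then split the output in half.
merged_out = x @ input_linear.T
gate_out_merged, up_out_merged = merged_out.chunk(2, dim=-1)

# Split path (what the conversion enables): chunk the weight first, then two matmuls.
gate_w, up_w = input_linear.chunk(2, dim=0)
gate_out_split = x @ gate_w.T
up_out_split = x @ up_w.T

assert torch.allclose(gate_out_merged, gate_out_split)
assert torch.allclose(up_out_merged, up_out_split)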

3 files changed: +29 additions, −9 deletions

convert_hf_to_gguf.py

Lines changed: 18 additions & 0 deletions
@@ -3987,8 +3987,26 @@ class GraniteMoeModel(GraniteModel):
     """Conversion for IBM's GraniteMoeForCausalLM"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        """In modeling_granitemoe, the JetMoe implementation of parallel experts
+        is used. This essentially merges w1 and w3 into a single tensor with 2x
+        the hidden size that is then split during forward. To keep compatibility
+        with existing mixtral support, we pull them apart here.
+        """
+
+        if name.endswith("block_sparse_moe.input_linear.weight"):
+            gate, up = data_torch.chunk(2, dim=-2)
+            return [
+                (self.map_tensor_name(f"model.layers.{bid}.block_sparse_moe.input_linear.gate.weight"), gate),
+                (self.map_tensor_name(f"model.layers.{bid}.block_sparse_moe.input_linear.up.weight"), up),
+            ]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 ###### CONVERSION LOGIC ######
 
+
 # tree of lazy tensors
 class LazyTorchTensor(gguf.LazyBase):
     _tensor_type = torch.Tensor
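
The dim=-2 in the chunk above is worth a note: the HF checkpoint stores the experts' input projection as a single stacked tensor, so the split has to happen on the doubled feature axis rather than the expert axis. A tiny sketch (shapes are assumptions for illustration, not taken from the model config):

import torch

n_experts, intermediate, hidden = 4, 16, 8
# Stacked per-expert weight, assumed shape: [n_experts, 2 * intermediate, hidden]
input_linear = torch.randn(n_experts, 2 * intermediate, hidden)

gate, up = input_linear.chunk(2, dim=-2)
print(gate.shape, up.shape)  # torch.Size([4, 16, 8]) torch.Size([4, 16, 8])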

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
@@ -1216,6 +1216,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],

gguf-py/gguf/tensor_mapping.py

Lines changed: 10 additions & 9 deletions
@@ -293,11 +293,11 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
-            "layers.{bid}.feed_forward.experts.w3",             # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear_v",     # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.v1",      # dbrx
-            "model.layers.{bid}.mlp.experts.up_proj",           # qwen2moe (merged)
-            "model.layers.{bid}.block_sparse_moe.input_linear", # granitemoe
+            "layers.{bid}.feed_forward.experts.w3",                # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_v",        # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.v1",         # dbrx
+            "model.layers.{bid}.mlp.experts.up_proj",              # qwen2moe (merged)
+            "model.layers.{bid}.block_sparse_moe.input_linear.up", # granitemoe
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -326,10 +326,11 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
-            "layers.{bid}.feed_forward.experts.w1",        # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear",  # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
-            "model.layers.{bid}.mlp.experts.gate_proj",    # qwen2moe (merged)
+            "layers.{bid}.feed_forward.experts.w1",                  # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear",            # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w1",           # dbrx
+            "model.layers.{bid}.mlp.experts.gate_proj",              # qwen2moe (merged)
+            "model.layers.{bid}.block_sparse_moe.input_linear.gate", # granitemoe
         ),
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
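
With the two new mapping entries in place, the split-off names produced by modify_tensors should resolve to the per-expert gate/up GGUF tensors. A rough sanity check, assuming gguf-py's get_tensor_name_map/get_name API as used by convert_hf_to_gguf.py (the block count here is arbitrary):

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.GRANITE_MOE, 32)

# Expected to resolve to something like blk.0.ffn_gate_exps.weight / blk.0.ffn_up_exps.weight
print(tmap.get_name("model.layers.0.block_sparse_moe.input_linear.gate.weight", try_suffixes=(".weight",)))
print(tmap.get_name("model.layers.0.block_sparse_moe.input_linear.up.weight", try_suffixes=(".weight",)))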
