Skip to content

Commit 2b46899

Browse files
committed
feat: Add conversion for Bamba models
This is borrowed and adapted from the original implementation ggml-org#10810 Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 291ee10 commit 2b46899

File tree

4 files changed

+155
-10
lines changed

4 files changed

+155
-10
lines changed

convert_hf_to_gguf.py

Lines changed: 105 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4685,6 +4685,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
46854685
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
46864686
hparams = json.load(f)
46874687
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
4688+
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4689+
self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
4690+
self.n_group = self.hparams.get("n_groups", 1)
46884691

46894692
def set_vocab(self):
46904693
vocab_size = self.hparams["vocab_size"]
@@ -4755,10 +4758,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
47554758
# (D is also unsqueezed, but for more straightforward broadcast internally)
47564759
data_torch = data_torch.reshape((*data_torch.shape, 1))
47574760
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
4758-
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
4759-
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
4760-
n_group = self.hparams.get("n_groups", 1)
4761-
data_torch = data_torch.reshape((n_group, d_inner // n_group))
4761+
data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))
47624762

47634763
if name.endswith(".A_log"):
47644764
logger.debug("A_log --> A ==> " + new_name)
@@ -4767,6 +4767,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
47674767
yield (new_name, data_torch)
47684768

47694769

4770+
@ModelBase.register("BambaForCausalLM")
4771+
class BambaModel(Mamba2Model):
4772+
"""Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
4773+
model_arch = gguf.MODEL_ARCH.BAMBA
4774+
undo_permute = True
4775+
4776+
def __init__(self, *args, **kwargs):
4777+
4778+
# Hybrid mamba models use a prefix for the mamba-specific params.
4779+
# TODO: Extend this if the prefix(es) need to be configurable
4780+
self.hparam_prefixes = ["mamba"]
4781+
4782+
super().__init__(*args, **kwargs)
4783+
4784+
# Use Llama conversion for attention
4785+
self._transformer_model_class: type[TextModel] = LlamaModel
4786+
4787+
# Lists of which layers use ssm vs attention
4788+
self._attn_layers = self.hparams.get("attn_layer_indices", [])
4789+
if not self._attn_layers:
4790+
attn_period = self.hparams.get("attn_layer_period")
4791+
assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
4792+
attn_offset = self.hparams.get("attn_layer_offset")
4793+
assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
4794+
self._attn_layers = [
4795+
i for i in range(self.block_count)
4796+
if i % attn_period == attn_offset
4797+
]
4798+
self._ssm_layers = [
4799+
i for i in range(self.block_count)
4800+
if i not in self._attn_layers
4801+
]
4802+
4803+
# n_group and d_inner are used during reshape_tensors for mamaba2
4804+
self.d_model = self.find_hparam(["hidden_size", "d_model"])
4805+
self.n_group = self.find_hparam(["n_groups"])
4806+
self.d_inner = self.find_hparam(["expand"]) * self.d_model
4807+
4808+
def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
4809+
prefixed = []
4810+
for pfx in self.hparam_prefixes:
4811+
prefixed.extend(
4812+
"_".join([pfx, k])
4813+
for k in keys
4814+
)
4815+
keys = list(keys) + prefixed
4816+
return super().find_hparam(keys, *args, **kwargs)
4817+
4818+
def set_gguf_parameters(self):
4819+
4820+
## General Params ##
4821+
self.gguf_writer.add_embedding_length(self.d_model)
4822+
self.gguf_writer.add_block_count(self.block_count)
4823+
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
4824+
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
4825+
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
4826+
4827+
## Mamba mixer params ##
4828+
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
4829+
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
4830+
self.gguf_writer.add_ssm_group_count(self.n_group)
4831+
self.gguf_writer.add_ssm_inner_size(self.d_inner)
4832+
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
4833+
# in llama.cpp
4834+
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
4835+
4836+
## Attention params ##
4837+
self.gguf_writer.add_attn_layer_indices(self._attn_layers)
4838+
self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
4839+
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
4840+
self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
4841+
4842+
## Feed Forward Params ##
4843+
self.gguf_writer.add_layer_norm_rms_eps(
4844+
self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
4845+
)
4846+
4847+
## Validation ##
4848+
d_head = self.find_hparam(["d_head"], optional=True) or 64
4849+
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
4850+
assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
4851+
4852+
def modify_tensors(
4853+
self, data_torch: Tensor, name: str, bid: int | None
4854+
) -> Iterable[tuple[str, Tensor]]:
4855+
4856+
# Determine whether this is a mamaba layer or an attention layer
4857+
if bid in self._ssm_layers:
4858+
for mamba_new_name, data_torch in super().modify_tensors(
4859+
data_torch, name, bid
4860+
):
4861+
yield mamba_new_name, data_torch
4862+
elif bid in self._attn_layers:
4863+
for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
4864+
self, data_torch, name, bid
4865+
):
4866+
yield llama_new_name, data_torch
4867+
else:
4868+
yield self.map_tensor_name(name), data_torch
4869+
4870+
47704871
@ModelBase.register("CohereForCausalLM")
47714872
class CommandR2Model(TextModel):
47724873
model_arch = gguf.MODEL_ARCH.COMMAND_R

gguf-py/gguf/constants.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ class SSM:
167167
GROUP_COUNT = "{arch}.ssm.group_count"
168168
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
169169

170+
class HybridAttention:
171+
ATTN_LAYER_INDICES = "{arch}.attention.layer_indices"
172+
170173
class WKV:
171174
HEAD_SIZE = "{arch}.wkv.head_size"
172175

@@ -320,6 +323,7 @@ class MODEL_ARCH(IntEnum):
320323
ARWKV7 = auto()
321324
MAMBA = auto()
322325
MAMBA2 = auto()
326+
BAMBA = auto()
323327
XVERSE = auto()
324328
COMMAND_R = auto()
325329
COHERE2 = auto()
@@ -602,6 +606,7 @@ class MODEL_TENSOR(IntEnum):
602606
MODEL_ARCH.ARWKV7: "arwkv7",
603607
MODEL_ARCH.MAMBA: "mamba",
604608
MODEL_ARCH.MAMBA2: "mamba2",
609+
MODEL_ARCH.BAMBA: "bamba",
605610
MODEL_ARCH.XVERSE: "xverse",
606611
MODEL_ARCH.COMMAND_R: "command-r",
607612
MODEL_ARCH.COHERE2: "cohere2",
@@ -1636,6 +1641,31 @@ class MODEL_TENSOR(IntEnum):
16361641
MODEL_TENSOR.SSM_NORM,
16371642
MODEL_TENSOR.SSM_OUT,
16381643
],
1644+
MODEL_ARCH.BAMBA: [
1645+
MODEL_TENSOR.TOKEN_EMBD,
1646+
MODEL_TENSOR.OUTPUT_NORM,
1647+
MODEL_TENSOR.OUTPUT,
1648+
MODEL_TENSOR.ATTN_NORM,
1649+
MODEL_TENSOR.SSM_IN,
1650+
MODEL_TENSOR.SSM_CONV1D,
1651+
MODEL_TENSOR.SSM_DT,
1652+
MODEL_TENSOR.SSM_A,
1653+
MODEL_TENSOR.SSM_D,
1654+
MODEL_TENSOR.SSM_NORM,
1655+
MODEL_TENSOR.SSM_OUT,
1656+
MODEL_TENSOR.ATTN_Q,
1657+
MODEL_TENSOR.ATTN_K,
1658+
MODEL_TENSOR.ATTN_V,
1659+
MODEL_TENSOR.ATTN_OUT,
1660+
MODEL_TENSOR.FFN_NORM,
1661+
MODEL_TENSOR.FFN_GATE,
1662+
MODEL_TENSOR.FFN_DOWN,
1663+
MODEL_TENSOR.FFN_UP,
1664+
MODEL_TENSOR.FFN_GATE_INP,
1665+
MODEL_TENSOR.FFN_GATE_EXP,
1666+
MODEL_TENSOR.FFN_DOWN_EXP,
1667+
MODEL_TENSOR.FFN_UP_EXP,
1668+
],
16391669
MODEL_ARCH.XVERSE: [
16401670
MODEL_TENSOR.TOKEN_EMBD,
16411671
MODEL_TENSOR.OUTPUT_NORM,

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,9 @@ def add_ssm_group_count(self, value: int) -> None:
849849
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
850850
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
851851

852+
def add_attn_layer_indices(self, values: list[int]) -> None:
853+
self.add_array(Keys.HybridAttention.ATTN_LAYER_INDICES.format(arch=self.arch), values)
854+
852855
def add_tokenizer_model(self, model: str) -> None:
853856
self.add_string(Keys.Tokenizer.MODEL, model)
854857

gguf-py/gguf/tensor_mapping.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TensorNameMap:
1313
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
1414
"transformer.word_embeddings", # falcon
1515
"word_embeddings", # bloom
16-
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
16+
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 bamba
1717
"tok_embeddings", # llama-pth
1818
"embeddings.word_embeddings", # bert nomic-bert
1919
"language_model.embedding.word_embeddings", # persimmon
@@ -117,7 +117,7 @@ class TensorNameMap:
117117
"transformer.h.{bid}.input_layernorm", # falcon7b
118118
"h.{bid}.input_layernorm", # bloom
119119
"transformer.h.{bid}.ln_mlp", # falcon40b
120-
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
120+
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe bamba
121121
"layers.{bid}.attention_norm", # llama-pth
122122
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
123123
"model.layers.{bid}.ln1", # yi
@@ -275,7 +275,8 @@ class TensorNameMap:
275275
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
276276
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
277277
"transformer.layers.{bid}.ffn_norm", # openelm
278-
"model.layers.{bid}.post_attention_layernorm", # llama4
278+
"language_model.model.layers.{bid}.post_attention_layernorm", # llama4
279+
"model.layers.{bid}.pre_ff_layernorm", # bamba
279280
),
280281

281282
# Post feed-forward norm
@@ -337,7 +338,8 @@ class TensorNameMap:
337338
"model.layers.{bid}.residual_mlp.w3", # arctic
338339
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
339340
"transformer.h.{bid}.mlp.c_fc_1", # exaone
340-
"model.layers.{bid}.feed_forward.up_proj", # llama4
341+
"language_model.model.layers.{bid}.feed_forward.up_proj", # llama4
342+
"model.layers.{bid}.feed_forward.up_proj", # bamba
341343
),
342344

343345
MODEL_TENSOR.FFN_UP_EXP: (
@@ -374,7 +376,8 @@ class TensorNameMap:
374376
"transformer.h.{bid}.mlp.linear_1", # refact
375377
"model.layers.{bid}.residual_mlp.w1", # arctic
376378
"transformer.h.{bid}.mlp.c_fc_0", # exaone
377-
"model.layers.{bid}.feed_forward.gate_proj", # llama4
379+
"language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4
380+
"model.layers.{bid}.feed_forward.gate_proj", # bamba
378381
),
379382

380383
MODEL_TENSOR.FFN_GATE_EXP: (
@@ -419,7 +422,8 @@ class TensorNameMap:
419422
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
420423
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
421424
"model.layers.h.{bid}.mlp.c_proj", # exaone
422-
"model.layers.{bid}.feed_forward.down_proj", # llama4
425+
"language_model.model.layers.{bid}.feed_forward.down_proj", # llama4
426+
"model.layers.{bid}.feed_forward.down_proj", # bamba
423427
),
424428

425429
MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -474,11 +478,13 @@ class TensorNameMap:
474478
MODEL_TENSOR.SSM_IN: (
475479
"model.layers.{bid}.in_proj",
476480
"backbone.layers.{bid}.mixer.in_proj",
481+
"model.layers.{bid}.mamba.in_proj", # bamba
477482
),
478483

479484
MODEL_TENSOR.SSM_CONV1D: (
480485
"model.layers.{bid}.conv1d",
481486
"backbone.layers.{bid}.mixer.conv1d",
487+
"model.layers.{bid}.mamba.conv1d", # bamba
482488
),
483489

484490
MODEL_TENSOR.SSM_X: (
@@ -489,25 +495,30 @@ class TensorNameMap:
489495
MODEL_TENSOR.SSM_DT: (
490496
"model.layers.{bid}.dt_proj",
491497
"backbone.layers.{bid}.mixer.dt_proj",
498+
"model.layers.{bid}.mamba.dt_proj", # bamba
492499
),
493500

494501
MODEL_TENSOR.SSM_A: (
495502
"model.layers.{bid}.A_log",
496503
"backbone.layers.{bid}.mixer.A_log",
504+
"model.layers.{bid}.mamba.A_log", # bamba
497505
),
498506

499507
MODEL_TENSOR.SSM_D: (
500508
"model.layers.{bid}.D",
501509
"backbone.layers.{bid}.mixer.D",
510+
"model.layers.{bid}.mamba.D", # bamba
502511
),
503512

504513
MODEL_TENSOR.SSM_NORM: (
505514
"backbone.layers.{bid}.mixer.norm", # mamba2
515+
"model.layers.{bid}.mamba.norm", # bamba
506516
),
507517

508518
MODEL_TENSOR.SSM_OUT: (
509519
"model.layers.{bid}.out_proj",
510520
"backbone.layers.{bid}.mixer.out_proj",
521+
"model.layers.{bid}.mamba.out_proj", # bamba
511522
),
512523

513524
MODEL_TENSOR.TIME_MIX_W0: (

0 commit comments

Comments
 (0)