@@ -4660,6 +4660,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
4660
4660
with open (dir_model / "config.json" , "r" , encoding = "utf-8" ) as f :
4661
4661
hparams = json .load (f )
4662
4662
super ().__init__ (dir_model , * args , hparams = hparams , ** kwargs )
4663
+ self .d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
4664
+ self .d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * self .d_model
4665
+ self .n_group = self .hparams .get ("n_groups" , 1 )
4663
4666
4664
4667
def set_vocab (self ):
4665
4668
vocab_size = self .hparams ["vocab_size" ]
@@ -4730,10 +4733,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
4730
4733
# (D is also unsqueezed, but for more straightforward broadcast internally)
4731
4734
data_torch = data_torch .reshape ((* data_torch .shape , 1 ))
4732
4735
elif self .match_model_tensor_name (new_name , gguf .MODEL_TENSOR .SSM_NORM , bid ):
4733
- d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
4734
- d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
4735
- n_group = self .hparams .get ("n_groups" , 1 )
4736
- data_torch = data_torch .reshape ((n_group , d_inner // n_group ))
4736
+ data_torch = data_torch .reshape ((self .n_group , self .d_inner // self .n_group ))
4737
4737
4738
4738
if name .endswith (".A_log" ):
4739
4739
logger .debug ("A_log --> A ==> " + new_name )
@@ -4742,6 +4742,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
4742
4742
yield (new_name , data_torch )
4743
4743
4744
4744
4745
+ @ModelBase .register ("BambaForCausalLM" )
4746
+ class BambaModel (Mamba2Model ):
4747
+ """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
4748
+ model_arch = gguf .MODEL_ARCH .BAMBA
4749
+ undo_permute = True
4750
+
4751
+ def __init__ (self , * args , ** kwargs ):
4752
+
4753
+ # Hybrid mamba models use a prefix for the mamba-specific params.
4754
+ # TODO: Extend this if the prefix(es) need to be configurable
4755
+ self .hparam_prefixes = ["mamba" ]
4756
+
4757
+ super ().__init__ (* args , ** kwargs )
4758
+
4759
+ # Use Llama conversion for attention
4760
+ self ._transformer_model_class : type [TextModel ] = LlamaModel
4761
+
4762
+ # Lists of which layers use ssm vs attention
4763
+ self ._attn_layers = self .hparams .get ("attn_layer_indices" , [])
4764
+ if not self ._attn_layers :
4765
+ attn_period = self .hparams .get ("attn_layer_period" )
4766
+ assert attn_period , "Didn't find attn_layer_indices or attn_layer_period"
4767
+ attn_offset = self .hparams .get ("attn_layer_offset" )
4768
+ assert attn_offset is not None , "No attention layer offset set with attn_layer_period"
4769
+ self ._attn_layers = [
4770
+ i for i in range (self .block_count )
4771
+ if i % attn_period == attn_offset
4772
+ ]
4773
+ self ._ssm_layers = [
4774
+ i for i in range (self .block_count )
4775
+ if i not in self ._attn_layers
4776
+ ]
4777
+
4778
+ # n_group and d_inner are used during reshape_tensors for mamaba2
4779
+ self .d_model = self .find_hparam (["hidden_size" , "d_model" ])
4780
+ self .n_group = self .find_hparam (["n_groups" ])
4781
+ self .d_inner = self .find_hparam (["expand" ]) * self .d_model
4782
+
4783
+ def find_hparam (self , keys : Iterable [str ], * args , ** kwargs ) -> Any :
4784
+ prefixed = []
4785
+ for pfx in self .hparam_prefixes :
4786
+ prefixed .extend (
4787
+ "_" .join ([pfx , k ])
4788
+ for k in keys
4789
+ )
4790
+ keys = list (keys ) + prefixed
4791
+ return super ().find_hparam (keys , * args , ** kwargs )
4792
+
4793
+ def set_gguf_parameters (self ):
4794
+
4795
+ ## General Params ##
4796
+ self .gguf_writer .add_embedding_length (self .d_model )
4797
+ self .gguf_writer .add_block_count (self .block_count )
4798
+ self .gguf_writer .add_context_length (self .hparams .get ("max_position_embeddings" , 0 ))
4799
+ self .gguf_writer .add_vocab_size (self .hparams ["vocab_size" ])
4800
+ self .gguf_writer .add_feed_forward_length (self .hparams ["intermediate_size" ])
4801
+
4802
+ ## Mamba mixer params ##
4803
+ self .gguf_writer .add_ssm_conv_kernel (self .find_hparam (["conv_kernel" , "d_conv" ]))
4804
+ self .gguf_writer .add_ssm_state_size (self .find_hparam (["state_size" , "d_state" ]))
4805
+ self .gguf_writer .add_ssm_group_count (self .n_group )
4806
+ self .gguf_writer .add_ssm_inner_size (self .d_inner )
4807
+ # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
4808
+ # in llama.cpp
4809
+ self .gguf_writer .add_ssm_time_step_rank (self .find_hparam (["n_heads" ]))
4810
+
4811
+ ## Attention params ##
4812
+ self .gguf_writer .add_attn_layer_indices (self ._attn_layers )
4813
+ self .gguf_writer .add_rope_dimension_count (self .hparams ["attn_rotary_emb" ])
4814
+ self .gguf_writer .add_head_count (self .hparams ["num_attention_heads" ])
4815
+ self .gguf_writer .add_head_count_kv (self .find_hparam (["num_key_value_heads" , "n_head_kv" ]))
4816
+
4817
+ ## Feed Forward Params ##
4818
+ self .gguf_writer .add_layer_norm_rms_eps (
4819
+ self .find_hparam (["layer_norm_epsilon" , "rms_norm_eps" ], optional = True ) or 1e-5
4820
+ )
4821
+
4822
+ ## Validation ##
4823
+ d_head = self .find_hparam (["d_head" ], optional = True ) or 64
4824
+ assert self .hparams .get ("hidden_act" ) in [None , "silu" ], "Only SILU activation supported"
4825
+ assert self .d_inner % d_head == 0 , f"SSM inner size { self .d_inner } not a multiple of head dim { d_head } "
4826
+
4827
+ def modify_tensors (
4828
+ self , data_torch : Tensor , name : str , bid : int | None
4829
+ ) -> Iterable [tuple [str , Tensor ]]:
4830
+
4831
+ # Determine whether this is a mamaba layer or an attention layer
4832
+ if bid in self ._ssm_layers :
4833
+ for mamba_new_name , data_torch in super ().modify_tensors (
4834
+ data_torch , name , bid
4835
+ ):
4836
+ yield mamba_new_name , data_torch
4837
+ elif bid in self ._attn_layers :
4838
+ for llama_new_name , data_torch in self ._transformer_model_class .modify_tensors (
4839
+ self , data_torch , name , bid
4840
+ ):
4841
+ yield llama_new_name , data_torch
4842
+ else :
4843
+ yield self .map_tensor_name (name ), data_torch
4844
+
4845
+
4745
4846
@ModelBase .register ("CohereForCausalLM" )
4746
4847
class CommandR2Model (TextModel ):
4747
4848
model_arch = gguf .MODEL_ARCH .COMMAND_R
0 commit comments