
Commit b97704c

refactor: better refactor
1 parent bfa0286 · commit b97704c

1 file changed

convert_hf_to_gguf.py: 10 additions & 39 deletions
@@ -2711,7 +2711,7 @@ class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2


-@Model.register("MambaForCausalLM", "MambaLMHeadModel")
+@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA

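Adding "FalconMambaForCausalLM" to the decorator is what lets FalconMamba checkpoints reuse the existing MambaModel converter instead of a dedicated subclass. For readers unfamiliar with the pattern, the sketch below shows how a name-to-class registry decorator of this shape typically works; it is a simplified assumption for illustration, not necessarily the exact Model.register implementation in convert_hf_to_gguf.py:

# Minimal sketch of a registry decorator (assumed shape; the real
# Model.register in convert_hf_to_gguf.py may differ in detail).
from typing import Callable, Type


class Model:
    _model_classes: dict[str, Type["Model"]] = {}

    @classmethod
    def register(cls, *names: str) -> Callable[[Type["Model"]], Type["Model"]]:
        def func(modelcls: Type["Model"]) -> Type["Model"]:
            # Map every HF `architectures` string to the same converter class.
            for name in names:
                cls._model_classes[name] = modelcls
            return modelcls
        return func

    @classmethod
    def from_model_architecture(cls, arch: str) -> Type["Model"]:
        # Look up the converter class registered for this architecture string.
        try:
            return cls._model_classes[arch]
        except KeyError:
            raise NotImplementedError(f"Architecture {arch!r} not supported!") from None


# Registering one class under several names means a FalconMamba checkpoint
# resolves to the same converter as classic Mamba.
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
    pass


assert Model.from_model_architecture("FalconMambaForCausalLM") is MambaModel
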
@@ -2731,7 +2731,7 @@ def set_vocab(self):
         else:
             # Use the GPT-NeoX tokenizer when no tokenizer files are present
             self._set_vocab_builtin("gpt-neox", vocab_size)
-
+
     def set_gguf_parameters(self):
         d_model = self.find_hparam(["hidden_size", "d_model"])
         d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
@@ -2742,21 +2742,25 @@ def set_gguf_parameters(self):
         # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
         dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
-
+        num_hidden_layers = self.find_hparam(["n_layer", "num_hidden_layers"])
+        use_b_dt_norm = False
+        # For falconmamba we do apply RMS norm on B / DT and C layers
+        if self.find_hparam(["model_type"]) in ["falcon_mamba"]:
+            use_b_dt_norm = True
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model

         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_block_count(num_hidden_layers)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
         self.gguf_writer.add_ssm_inner_size(d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
         self.gguf_writer.add_ssm_time_step_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
-        self.gguf_writer.add_mamba_b_dt_rms(False) # For classic Mamba we don't apply rms norm on B / DT layers
+        self.gguf_writer.add_mamba_b_dt_rms(use_b_dt_norm) # For classic Mamba we don't apply rms norm on B / DT layers
         self.gguf_writer.add_file_type(self.ftype)

     _tok_embd = None
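The two pieces of gating added here, looking the layer count up under either config key and switching the RMS-norm flag on model_type, are what make the separate FalconMamba class removed below redundant. A standalone sketch of that fallback logic follows; find_hparam here is a simplified stand-in for the Model method of the same name, and the config dicts are illustrative values, not real checkpoints:

# Simplified stand-in for Model.find_hparam: return the first matching key.
from typing import Any


def find_hparam(hparams: dict[str, Any], keys: list[str], optional: bool = False) -> Any:
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")


# Classic Mamba configs typically name the depth `n_layer` while FalconMamba
# uses `num_hidden_layers`; one lookup with a fallback list covers both, and
# the `model_type` check flips the B / DT RMS-norm flag only for FalconMamba.
for cfg in (
    {"model_type": "mamba", "n_layer": 64},                    # illustrative values
    {"model_type": "falcon_mamba", "num_hidden_layers": 64},
):
    num_hidden_layers = find_hparam(cfg, ["n_layer", "num_hidden_layers"])
    use_b_dt_norm = find_hparam(cfg, ["model_type"]) in ["falcon_mamba"]
    print(cfg["model_type"], num_hidden_layers, use_b_dt_norm)
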
@@ -3855,43 +3859,10 @@ def prepare_tensors(self):
             self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))

         super().prepare_tensors()
-
-
-@Model.register("FalconMambaForCausalLM")
-class FalconMambaModel(MambaModel):
-    model_arch = gguf.MODEL_ARCH.MAMBA
-
-    def set_gguf_parameters(self):
-        d_model = self.find_hparam(["hidden_size", "d_model"])
-        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
-        # ceiling division
-        # ref: https://stackoverflow.com/a/17511341/22827863
-        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
-        dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
-        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
-
-        # Fail early for models which don't have a block expansion factor of 2
-        assert d_inner == 2 * d_model
-
-        self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
-        self.gguf_writer.add_embedding_length(d_model)
-        self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
-        self.gguf_writer.add_ssm_conv_kernel(d_conv)
-        self.gguf_writer.add_mamba_b_dt_rms(True) # For FalconMamba we do apply rms norm on B / DT layers
-        self.gguf_writer.add_ssm_inner_size(d_inner)
-        self.gguf_writer.add_ssm_state_size(d_state)
-        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
-        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
-        self.gguf_writer.add_file_type(self.ftype)
-
+

 ###### CONVERSION LOGIC ######

-
 # tree of lazy tensors
 class LazyTorchTensor(gguf.LazyBase):
     _tensor_type = torch.Tensor
