Commit 847eedb

ggerganov, akx, and cebtenzzre authored
py : add Gemma conversion from HF models (#5647)
* py : add gemma conversion from HF models

* Update convert-hf-to-gguf.py
  Co-authored-by: Aarni Koskela <[email protected]>

* Update convert-hf-to-gguf.py
  Co-authored-by: Aarni Koskela <[email protected]>

* Update convert-hf-to-gguf.py
  Co-authored-by: Jared Van Bortel <[email protected]>

---------

Co-authored-by: Aarni Koskela <[email protected]>
Co-authored-by: Jared Van Bortel <[email protected]>
1 parent 7e4f339 commit 847eedb
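
With this change, a Hugging Face Gemma checkpoint can be converted straight to GGUF. A rough usage sketch (the paths are placeholders and the flag names come from the script's argparse options, so check python convert-hf-to-gguf.py --help for the exact set in your version):

    python convert-hf-to-gguf.py /path/to/gemma-hf-model --outtype f16 --outfile gemma-f16.gguf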

File tree: 2 files changed, +63 -0 lines changed

convert-hf-to-gguf.py

Lines changed: 60 additions & 0 deletions
@@ -218,6 +218,8 @@ def from_model_architecture(model_architecture):
             return BertModel
         if model_architecture == "NomicBertModel":
             return NomicBertModel
+        if model_architecture == "GemmaForCausalLM":
+            return GemmaModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -277,6 +279,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.BERT
         if arch == "NomicBertModel":
             return gguf.MODEL_ARCH.NOMIC_BERT
+        if arch == "GemmaForCausalLM":
+            return gguf.MODEL_ARCH.GEMMA
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1786,6 +1790,62 @@ def get_tensors(self):
             yield name, data
 
 
+class GemmaModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_key_length(hparams["head_dim"])
+        self.gguf_writer.add_value_length(hparams["head_dim"])
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+            if name.endswith("norm.weight"):
+                data_torch = data_torch + 1
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######
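
A note on the data_torch = data_torch + 1 step above: per the modeling_gemma.py reference in the diff, the Hugging Face implementation of Gemma's RMSNorm multiplies the normalized activations by (1 + weight), while llama.cpp's RMSNorm multiplies by the stored weight directly, so the converter bakes the +1 offset into the exported tensor. A minimal sketch of the equivalence, using made-up tensors (none of these names come from the converter):

import torch

def rms_norm(x, eps=1e-6):
    # plain RMS normalization, no learned scale
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(4, 8)
w = torch.randn(8) * 0.02            # Gemma stores the norm scale as an offset from 1

hf_out   = rms_norm(x) * (1.0 + w)   # roughly what transformers' GemmaRMSNorm computes
gguf_out = rms_norm(x) * (w + 1.0)   # llama.cpp multiplies by the converted weight (w + 1)

assert torch.allclose(hf_out, gguf_out)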

llama.cpp

Lines changed: 3 additions & 0 deletions
@@ -7450,6 +7450,7 @@ struct llm_build_context {
 
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
+
         inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
         cb(inpL, "inp_scaled", -1);
 
@@ -7491,6 +7492,7 @@ struct llm_build_context {
                         n_embd_head_k, 2, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                cb(Qcur, "Qcur", il);
+
                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
                cb(Qcur, "Qcur_scaled", il);
 
@@ -7505,6 +7507,7 @@ struct llm_build_context {
                         Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
                cb(cur, "kqv_out", il);
            }
+
            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
            cb(sa_out, "sa_out", il);
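
The three additions to llama.cpp are whitespace-only lines for readability; the Gemma-specific scaling they surround was already present in the Gemma graph builder: token embeddings are multiplied by sqrtf(n_embd) before the first block, and the query tensor is scaled by 1.0f / sqrtf(n_embd_head_k) ahead of attention, which then runs with a KQ scale of 1.0f. A rough numpy illustration of that arithmetic, with illustrative shapes rather than values taken from llama.cpp:

import numpy as np

n_head, n_embd_head_k, n_tokens = 8, 256, 4
n_embd = n_head * n_embd_head_k

inp_embd   = np.random.randn(n_tokens, n_embd).astype(np.float32)
inp_scaled = inp_embd * np.sqrt(n_embd)            # inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd))

q = np.random.randn(n_tokens, n_head, n_embd_head_k).astype(np.float32)
k = np.random.randn(n_tokens, n_head, n_embd_head_k).astype(np.float32)

q_scaled = q / np.sqrt(n_embd_head_k)              # Qcur = ggml_scale(..., 1.0f / sqrtf(n_embd_head_k))
kq       = np.einsum("qhd,khd->hqk", q_scaled, k)  # raw attention scores; the KQ scale stays at 1.0f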
