import sys
from enum import IntEnum
from pathlib import Path
- from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
+ from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Sequence, cast

import numpy as np
import torch

from convert import HfVocab


- # check for any of the given keys in the dictionary and return the value of the first key found
- def get_key_opts(d, keys):
-     for k in keys:
-         if k in d:
-             return d[k]
-     print(f"Could not find any of {keys}")
-     sys.exit()
-
-
###### MODEL DEFINITIONS ######

class SentencePieceTokenTypes(IntEnum):
@@ -58,6 +49,15 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
        self.hparams = Model.load_hparams(self.dir_model)
        self.model_arch = self._get_model_architecture()
        self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
+         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
+ 
+     def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any:
+         key = next((k for k in keys if k in self.hparams), None)
+         if key is not None:
+             return self.hparams[key]
+         if optional:
+             return None
+         raise KeyError(f"could not find any of: {keys}")

    def set_vocab(self):
        self._set_vocab_gpt2()
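For context, a minimal standalone sketch (not part of the commit) of the lookup rule the new find_hparam helper implements; the hparams dict below is hypothetical, standing in for a model's loaded config.json:

# hypothetical hparams dict, standing in for a loaded config.json
hparams = {"num_hidden_layers": 32, "hidden_size": 4096}

def find_hparam(keys, optional=False):
    # return the value of the first key present in hparams, mirroring the method above
    key = next((k for k in keys if k in hparams), None)
    if key is not None:
        return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")

print(find_hparam(["n_layers", "num_hidden_layers", "n_layer"]))         # -> 32
print(find_hparam(["max_position_embeddings", "n_ctx"], optional=True))  # -> None

Unlike the removed get_key_opts, a missing required key now raises KeyError instead of printing and calling sys.exit(), and optional=True lets callers skip metadata that a given architecture simply does not define.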
@@ -79,28 +79,33 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:

    def set_gguf_parameters(self):
        self.gguf_writer.add_name(self.dir_model.name)
-         self.gguf_writer.add_block_count(self.hparams.get(
-             "n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
-         ))
-         if (n_ctx := self.hparams.get("max_position_embeddings")) is not None:
+         self.gguf_writer.add_block_count(self.block_count)
+ 
+         if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
            self.gguf_writer.add_context_length(n_ctx)
-         if (n_embd := self.hparams.get("hidden_size")) is not None:
-             self.gguf_writer.add_embedding_length(n_embd)
-         if (n_ff := self.hparams.get("intermediate_size")) is not None:
+ 
+         n_embd = self.find_hparam(["hidden_size", "n_embd"])
+         self.gguf_writer.add_embedding_length(n_embd)
+ 
+         if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
            self.gguf_writer.add_feed_forward_length(n_ff)
-         if (n_head := self.hparams.get("num_attention_heads")) is not None:
-             self.gguf_writer.add_head_count(n_head)
+ 
+         n_head = self.find_hparam(["num_attention_heads", "n_head"])
+         self.gguf_writer.add_head_count(n_head)
+ 
        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
            self.gguf_writer.add_head_count_kv(n_head_kv)

-         if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
-             self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
+         if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
+             self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
+         if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon"], optional=True)) is not None:
+             self.gguf_writer.add_layer_norm_eps(f_norm_eps)
        if (n_experts := self.hparams.get("num_local_experts")) is not None:
            self.gguf_writer.add_expert_count(n_experts)
        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
            self.gguf_writer.add_expert_used_count(n_experts_used)

-         self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
+         self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
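The reworked base-class set_gguf_parameters leans on those fallback key lists so a single code path covers both LLaMA-style and GPT-2-style config naming. A tiny illustration with two invented configs (not taken from any real model):

llama_style = {"num_hidden_layers": 32, "hidden_size": 4096, "num_attention_heads": 32}
gpt2_style  = {"n_layer": 12, "n_embd": 768, "n_head": 12}

def first_of(cfg, keys):
    # same first-match rule as Model.find_hparam, inlined for the example
    return next(cfg[k] for k in keys if k in cfg)

for cfg in (llama_style, gpt2_style):
    print(first_of(cfg, ["hidden_size", "n_embd"]),
          first_of(cfg, ["num_attention_heads", "n_head"]))
# -> 4096 32
# -> 768 12

Embedding length and head count are looked up without optional=True and are therefore mandatory, while context length, feed-forward length, and the two norm epsilons are written only when the config provides them.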
@@ -211,6 +216,8 @@ def from_model_architecture(model_architecture):
            return MiniCPMModel
        if model_architecture == "BertModel":
            return BertModel
+         if model_architecture == "NomicBertModel":
+             return NomicBertModel
        return Model

    def _is_model_safetensors(self) -> bool:
@@ -268,6 +275,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
            return gguf.MODEL_ARCH.MINICPM
        if arch == "BertModel":
            return gguf.MODEL_ARCH.BERT
+         if arch == "NomicBertModel":
+             return gguf.MODEL_ARCH.NOMIC_BERT

        raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -1297,21 +1306,21 @@ def write_tensors(self):

class Phi2Model(Model):
    def set_gguf_parameters(self):
-         block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"])
+         block_count = self.find_hparam(["num_hidden_layers", "n_layer"])

-         rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"])
-         n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"])
-         n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"])
+         rot_pct = self.find_hparam(["partial_rotary_factor"])
+         n_embd = self.find_hparam(["hidden_size", "n_embd"])
+         n_head = self.find_hparam(["num_attention_heads", "n_head"])

        self.gguf_writer.add_name("Phi2")
-         self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"]))
+         self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))

        self.gguf_writer.add_embedding_length(n_embd)
        self.gguf_writer.add_feed_forward_length(4 * n_embd)
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head)
-         self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"]))
+         self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_add_bos_token(False)
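As a sanity check on the add_rope_dimension_count line, here is the arithmetic with illustrative values in the style of Phi-2's published config (assumed for this sketch, not read from the diff): partial_rotary_factor = 0.4, hidden_size = 2560, num_attention_heads = 32.

rot_pct, n_embd, n_head = 0.4, 2560, 32     # assumed example values
head_dim = n_embd // n_head                 # 80 dimensions per head
rope_dim = int(rot_pct * n_embd) // n_head  # int(1024) // 32 = 32
print(head_dim, rope_dim)                   # -> 80 32

In other words, only the first 32 of the 80 dimensions per head receive rotary embeddings, which is exactly what the partial rotary factor encodes.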
@@ -1636,20 +1645,12 @@ def write_tensors(self):
class BertModel(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
-         self.block_count = self.hparams["num_hidden_layers"]
+         self.vocab_size = None

    def set_gguf_parameters(self):
-         # TODO(cebtenzzre): merge with parent class
-         self.gguf_writer.add_name(self.dir_model.name)
-         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
-         self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
-         self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
-         self.gguf_writer.add_block_count(self.block_count)
-         self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
-         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
+         super().set_gguf_parameters()
        self.gguf_writer.add_causal_attention(False)
        self.gguf_writer.add_pooling_layer(True)
-         self.gguf_writer.add_file_type(self.ftype)

    def set_vocab(self):
        path = self.dir_model
@@ -1659,6 +1660,7 @@ def set_vocab(self):
        vocab = HfVocab(path, added_tokens_path)
        tokens, scores, toktypes = zip(*vocab.all_tokens())
        assert len(tokens) == vocab.vocab_size
+         self.vocab_size = vocab.vocab_size

        # we need this to validate the size of the token_type embeddings
        # though currently we are passing all zeros to the token_type embeddings
@@ -1672,7 +1674,7 @@ def phantom(tok, typ):
            if tok.startswith(b"##"):
                return tok[2:]
            return b"\xe2\x96\x81" + tok
-         tokens = [phantom(t, y) for t, y in zip(tokens, toktypes)]
+         tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes))

        # set up bos and eos tokens (cls and sep)
        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
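The phantom helper rewrites WordPiece surface forms into the SentencePiece-style spelling the rest of the converter expects: continuation pieces drop their "##" prefix and word-initial pieces gain a U+2581 ("lower one eighth block") prefix. A simplified standalone sketch of just the two branches visible in this hunk; the earlier token-type checks of the real phantom() sit outside the diff context and are omitted here:

# simplified sketch of the two branches shown above (token-type handling omitted)
def phantom(tok: bytes) -> bytes:
    if tok.startswith(b"##"):
        return tok[2:]                # continuation piece: "##ing" -> "ing"
    return b"\xe2\x96\x81" + tok      # word-initial piece: "hello" -> "\u2581hello"

print(phantom(b"##ing"))  # b'ing'
print(phantom(b"hello"))  # b'\xe2\x96\x81hello'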
@@ -1724,6 +1726,43 @@ def write_tensors(self):
            self.gguf_writer.add_tensor(new_name, data)


+ class NomicBertModel(BertModel):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+ 
+         # the HF config claims n_ctx=8192, but it uses RoPE scaling
+         self.hparams["n_ctx"] = 2048
+ 
+         # SwigLU activation
+         assert self.hparams["activation_function"] == "swiglu"
+         # this doesn't do anything in the HF version
+         assert self.hparams["causal"] is False
+         # no bias tensors
+         assert self.hparams["qkv_proj_bias"] is False
+         assert self.hparams["mlp_fc1_bias"] is False
+         assert self.hparams["mlp_fc2_bias"] is False
+         # norm at end of layer
+         assert self.hparams["prenorm"] is False
+         # standard RoPE
+         assert self.hparams["rotary_emb_fraction"] == 1.0
+         assert self.hparams["rotary_emb_interleaved"] is False
+         assert self.hparams["rotary_emb_scale_base"] is None
+ 
+     def set_gguf_parameters(self):
+         super().set_gguf_parameters()
+         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
+ 
+     def get_tensors(self):
+         assert self.vocab_size is not None
+         for name, data in super().get_tensors():
+             # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
+             if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
+                 rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
+                 assert data.shape == (rounded_vocab_size, self.hparams["n_embd"])
+                 data = data[:self.vocab_size, :]
+             yield name, data
+ 
+ 

###### CONVERSION LOGIC ######
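The slicing in NomicBertModel.get_tensors assumes the padded embedding row count is the vocabulary size rounded up to a multiple of 64. A quick check of that rounding with BERT's usual 30,522-token WordPiece vocabulary (illustrative only; the converter reads the real size from the tokenizer at conversion time):

vocab_size = 30522                                  # typical BERT WordPiece vocab, for illustration
rounded_vocab_size = (vocab_size + 63) // 64 * 64   # round up to the next multiple of 64
print(rounded_vocab_size)                           # -> 30528
# the converter then keeps only the first vocab_size rows: data[:vocab_size, :]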