@@ -218,6 +218,8 @@ def from_model_architecture(model_architecture):
             return BertModel
         if model_architecture == "NomicBertModel":
             return NomicBertModel
+        if model_architecture == "GemmaForCausalLM":
+            return GemmaModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -277,6 +279,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.BERT
         if arch == "NomicBertModel":
             return gguf.MODEL_ARCH.NOMIC_BERT
+        if arch == "GemmaForCausalLM":
+            return gguf.MODEL_ARCH.GEMMA
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1786,6 +1790,62 @@ def get_tensors(self):
             yield name, data
 
 
+class GemmaModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_key_length(hparams["head_dim"])
+        self.gguf_writer.add_value_length(hparams["head_dim"])
+
+    def write_tensors(self):
+        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+            if name.endswith("norm.weight"):
+                data_torch = data_torch + 1
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######
 
 
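The only Gemma-specific numerical transformation in write_tensors() is the +1 applied to tensors whose names end in "norm.weight". Below is a minimal illustrative sketch, assuming the (1 + weight) RMSNorm scaling used in the referenced modeling_gemma.py and the conventional x_norm * weight scaling on the inference side; the helper names (rms_norm, hf_gemma_norm, plain_norm) are made up for illustration and are not part of either codebase.

import numpy as np

def rms_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # root-mean-square normalization over the last dimension
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def hf_gemma_norm(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # GemmaRMSNorm-style scaling: multiply the normalized input by (1 + weight),
    # where the checkpoint stores weight as an offset from 1
    return rms_norm(x) * (1.0 + w)

def plain_norm(x: np.ndarray, w_exported: np.ndarray) -> np.ndarray:
    # conventional RMSNorm scaling: multiply the normalized input by the stored weight
    return rms_norm(x) * w_exported

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
w = rng.standard_normal(8).astype(np.float32)  # weight as stored in the HF checkpoint

# Exporting w + 1, as the patch does for "norm.weight" tensors, makes the two
# formulations produce identical outputs.
assert np.allclose(hf_gemma_norm(x, w), plain_norm(x, w + 1.0), atol=1e-6)
print("norm outputs match after folding the +1 offset into the exported weight")

Folding the offset into the exported tensor at conversion time is presumably what lets the resulting GGUF file go through the standard RMSNorm path without any Gemma-specific offset handling at inference.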