13
13
from pathlib import Path
14
14
from hashlib import sha256
15
15
from typing import TYPE_CHECKING , Any , Callable , ContextManager , Iterable , Iterator , Sequence , TypeVar , cast
16
+ from dataclasses import dataclass
16
17
17
18
import math
18
19
import numpy as np
30
31
logger = logging .getLogger ("hf-to-gguf" )
31
32
32
33
34
+ @dataclass
35
+ class Metadata :
36
+ name : Optional [str ] = None
37
+ author : Optional [str ] = None
38
+ version : Optional [str ] = None
39
+ url : Optional [str ] = None
40
+ description : Optional [str ] = None
41
+ licence : Optional [str ] = None
42
+ source_url : Optional [str ] = None
43
+ source_hf_repo : Optional [str ] = None
44
+
45
+ @staticmethod
46
+ def load (metadata_path : Path ) -> Metadata :
47
+ if metadata_path is None or not metadata_path .exists ():
48
+ return Metadata ()
49
+
50
+ with open (metadata_path , 'r' ) as file :
51
+ data = json .load (file )
52
+
53
+ # Create a new Metadata instance
54
+ metadata = Metadata ()
55
+
56
+ # Assigning values to Metadata attributes if they exist in the JSON file
57
+ # This is based on LLM_KV_NAMES mapping in llama.cpp
58
+ metadata .name = data .get ("general.name" )
59
+ metadata .author = data .get ("general.author" )
60
+ metadata .version = data .get ("general.version" )
61
+ metadata .url = data .get ("general.url" )
62
+ metadata .description = data .get ("general.description" )
63
+ metadata .license = data .get ("general.license" )
64
+ metadata .source_url = data .get ("general.source.url" )
65
+ metadata .source_hf_repo = data .get ("general.source.huggingface.repository" )
66
+
67
+ return metadata
68
+
69
+
33
70
###### MODEL DEFINITIONS ######
34
71
35
72
class SentencePieceTokenTypes (IntEnum ):
@@ -62,11 +99,12 @@ class Model:
62
99
fname_out : Path
63
100
fname_default : Path
64
101
gguf_writer : gguf .GGUFWriter
102
+ metadata : Metadata
65
103
66
104
# subclasses should define this!
67
105
model_arch : gguf .MODEL_ARCH
68
106
69
- def __init__ (self , dir_model : Path , ftype : gguf .LlamaFileType , fname_out : Path , is_big_endian : bool , use_temp_file : bool , eager : bool ):
107
+ def __init__ (self , dir_model : Path , ftype : gguf .LlamaFileType , fname_out : Path , is_big_endian : bool , use_temp_file : bool , eager : bool , metadata : Metadata ):
70
108
if type (self ) is Model :
71
109
raise TypeError (f"{ type (self ).__name__ !r} should not be directly instantiated" )
72
110
self .dir_model = dir_model
@@ -83,6 +121,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
83
121
self .block_count = self .find_hparam (["n_layers" , "num_hidden_layers" , "n_layer" ])
84
122
self .tensor_map = gguf .get_tensor_name_map (self .model_arch , self .block_count )
85
123
self .tensor_names = None
124
+ self .metadata = metadata
86
125
if self .ftype == gguf .LlamaFileType .GUESSED :
87
126
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
88
127
_ , first_tensor = next (self .get_tensors ())
@@ -200,8 +239,34 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
200
239
raise ValueError (f"Can not map tensor { name !r} " )
201
240
return new_name
202
241
242
+ def set_gguf_meta_model (self ):
243
+
244
+ # Metadata About The Model And Its Provenence
245
+ name = "LLaMA"
246
+ if self .metadata is not None and self .metadata .name is not None :
247
+ name = metadata .name
248
+ elif self .dir_model is not None :
249
+ name = self .dir_model .name
250
+
251
+ self .gguf_writer .add_name (name )
252
+
253
+ if self .metadata is not None :
254
+ if self .metadata .author is not None :
255
+ self .gguf_writer .add_author (self .metadata .author )
256
+ if self .metadata .version is not None :
257
+ self .gguf_writer .add_version (self .metadata .version )
258
+ if self .metadata .url is not None :
259
+ self .gguf_writer .add_url (self .metadata .url )
260
+ if self .metadata .description is not None :
261
+ self .gguf_writer .add_description (self .metadata .description )
262
+ if self .metadata .licence is not None :
263
+ self .gguf_writer .add_licence (self .metadata .licence )
264
+ if self .metadata .source_url is not None :
265
+ self .gguf_writer .add_source_url (self .metadata .source_url )
266
+ if self .metadata .source_hf_repo is not None :
267
+ self .gguf_writer .add_source_hf_repo (self .metadata .source_hf_repo )
268
+
203
269
def set_gguf_parameters (self ):
204
- self .gguf_writer .add_name (self .dir_model .name )
205
270
self .gguf_writer .add_block_count (self .block_count )
206
271
207
272
if (n_ctx := self .find_hparam (["max_position_embeddings" , "n_ctx" ], optional = True )) is not None :
@@ -2588,6 +2653,10 @@ def parse_args() -> argparse.Namespace:
2588
2653
"--verbose" , action = "store_true" ,
2589
2654
help = "increase output verbosity" ,
2590
2655
)
2656
+ parser .add_argument (
2657
+ "--metadata" , type = Path ,
2658
+ help = "Specify the path for a metadata file"
2659
+ )
2591
2660
parser .add_argument (
2592
2661
"--get-outfile" , action = "store_true" ,
2593
2662
help = "get calculated default outfile name"
@@ -2607,6 +2676,7 @@ def main() -> None:
2607
2676
else :
2608
2677
logging .basicConfig (level = logging .INFO )
2609
2678
2679
+ metadata = Metadata .load (args .metadata )
2610
2680
dir_model = args .model
2611
2681
2612
2682
if args .awq_path :
@@ -2642,12 +2712,15 @@ def main() -> None:
2642
2712
encodingScheme = ftype_map [args .outtype ]
2643
2713
model_architecture = hparams ["architectures" ][0 ]
2644
2714
model_class = Model .from_model_architecture (model_architecture )
2645
- model_instance = model_class (dir_model , encodingScheme , args .outfile , args .bigendian , args .use_temp_file , args .no_lazy )
2715
+ model_instance = model_class (dir_model , encodingScheme , args .outfile , args .bigendian , args .use_temp_file , args .no_lazy , metadata )
2646
2716
2647
2717
if args .get_outfile :
2648
2718
print (f"{ model_instance .fname_default } " ) # noqa: NP100
2649
2719
return
2650
2720
2721
+ logger .info ("Set meta model" )
2722
+ model_instance .set_gguf_meta_model ()
2723
+
2651
2724
logger .info ("Set model parameters" )
2652
2725
model_instance .set_gguf_parameters ()
2653
2726
0 commit comments