Skip to content

Commit c704442

Browse files
committed
convert-hf-to-gguf.py: add metadata override
1 parent 53096ef commit c704442

File tree

1 file changed

+76
-3
lines changed

1 file changed

+76
-3
lines changed

convert-hf-to-gguf.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pathlib import Path
1414
from hashlib import sha256
1515
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
16+
from dataclasses import dataclass
1617

1718
import math
1819
import numpy as np
@@ -30,6 +31,42 @@
3031
logger = logging.getLogger("hf-to-gguf")
3132

3233

34+
@dataclass
35+
class Metadata:
36+
name: Optional[str] = None
37+
author: Optional[str] = None
38+
version: Optional[str] = None
39+
url: Optional[str] = None
40+
description: Optional[str] = None
41+
licence: Optional[str] = None
42+
source_url: Optional[str] = None
43+
source_hf_repo: Optional[str] = None
44+
45+
@staticmethod
46+
def load(metadata_path: Path) -> Metadata:
47+
if metadata_path is None or not metadata_path.exists():
48+
return Metadata()
49+
50+
with open(metadata_path, 'r') as file:
51+
data = json.load(file)
52+
53+
# Create a new Metadata instance
54+
metadata = Metadata()
55+
56+
# Assigning values to Metadata attributes if they exist in the JSON file
57+
# This is based on LLM_KV_NAMES mapping in llama.cpp
58+
metadata.name = data.get("general.name")
59+
metadata.author = data.get("general.author")
60+
metadata.version = data.get("general.version")
61+
metadata.url = data.get("general.url")
62+
metadata.description = data.get("general.description")
63+
metadata.license = data.get("general.license")
64+
metadata.source_url = data.get("general.source.url")
65+
metadata.source_hf_repo = data.get("general.source.huggingface.repository")
66+
67+
return metadata
68+
69+
3370
###### MODEL DEFINITIONS ######
3471

3572
class SentencePieceTokenTypes(IntEnum):
@@ -62,11 +99,12 @@ class Model:
6299
fname_out: Path
63100
fname_default: Path
64101
gguf_writer: gguf.GGUFWriter
102+
metadata: Metadata
65103

66104
# subclasses should define this!
67105
model_arch: gguf.MODEL_ARCH
68106

69-
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
107+
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, metadata: Metadata):
70108
if type(self) is Model:
71109
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
72110
self.dir_model = dir_model
@@ -83,6 +121,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
83121
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
84122
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
85123
self.tensor_names = None
124+
self.metadata = metadata
86125
if self.ftype == gguf.LlamaFileType.GUESSED:
87126
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
88127
_, first_tensor = next(self.get_tensors())
@@ -200,8 +239,34 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
200239
raise ValueError(f"Can not map tensor {name!r}")
201240
return new_name
202241

242+
def set_gguf_meta_model(self):
243+
244+
# Metadata About The Model And Its Provenence
245+
name = "LLaMA"
246+
if self.metadata is not None and self.metadata.name is not None:
247+
name = metadata.name
248+
elif self.dir_model is not None:
249+
name = self.dir_model.name
250+
251+
self.gguf_writer.add_name(name)
252+
253+
if self.metadata is not None:
254+
if self.metadata.author is not None:
255+
self.gguf_writer.add_author(self.metadata.author)
256+
if self.metadata.version is not None:
257+
self.gguf_writer.add_version(self.metadata.version)
258+
if self.metadata.url is not None:
259+
self.gguf_writer.add_url(self.metadata.url)
260+
if self.metadata.description is not None:
261+
self.gguf_writer.add_description(self.metadata.description)
262+
if self.metadata.licence is not None:
263+
self.gguf_writer.add_licence(self.metadata.licence)
264+
if self.metadata.source_url is not None:
265+
self.gguf_writer.add_source_url(self.metadata.source_url)
266+
if self.metadata.source_hf_repo is not None:
267+
self.gguf_writer.add_source_hf_repo(self.metadata.source_hf_repo)
268+
203269
def set_gguf_parameters(self):
204-
self.gguf_writer.add_name(self.dir_model.name)
205270
self.gguf_writer.add_block_count(self.block_count)
206271

207272
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
@@ -2588,6 +2653,10 @@ def parse_args() -> argparse.Namespace:
25882653
"--verbose", action="store_true",
25892654
help="increase output verbosity",
25902655
)
2656+
parser.add_argument(
2657+
"--metadata", type=Path,
2658+
help="Specify the path for a metadata file"
2659+
)
25912660
parser.add_argument(
25922661
"--get-outfile", action="store_true",
25932662
help="get calculated default outfile name"
@@ -2607,6 +2676,7 @@ def main() -> None:
26072676
else:
26082677
logging.basicConfig(level=logging.INFO)
26092678

2679+
metadata = Metadata.load(args.metadata)
26102680
dir_model = args.model
26112681

26122682
if args.awq_path:
@@ -2642,12 +2712,15 @@ def main() -> None:
26422712
encodingScheme = ftype_map[args.outtype]
26432713
model_architecture = hparams["architectures"][0]
26442714
model_class = Model.from_model_architecture(model_architecture)
2645-
model_instance = model_class(dir_model, encodingScheme, args.outfile, args.bigendian, args.use_temp_file, args.no_lazy)
2715+
model_instance = model_class(dir_model, encodingScheme, args.outfile, args.bigendian, args.use_temp_file, args.no_lazy, metadata)
26462716

26472717
if args.get_outfile:
26482718
print(f"{model_instance.fname_default}") # noqa: NP100
26492719
return
26502720

2721+
logger.info("Set meta model")
2722+
model_instance.set_gguf_meta_model()
2723+
26512724
logger.info("Set model parameters")
26522725
model_instance.set_gguf_parameters()
26532726

0 commit comments

Comments
 (0)