Skip to content

Commit 15dc4a9

Browse files
committed
py: more detailed dataset authorship support
1 parent df885aa commit 15dc4a9

File tree

4 files changed

+129
-28
lines changed

4 files changed

+129
-28
lines changed

examples/convert_legacy_llama.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -851,12 +851,32 @@ def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None
851851
if "repo_url" in base_model_entry:
852852
self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"])
853853

854+
if metadata.datasets is not None:
855+
self.gguf.add_dataset_count(len(metadata.datasets))
856+
for key, dataset_entry in enumerate(metadata.datasets):
857+
if "name" in dataset_entry:
858+
self.gguf.add_dataset_name(key, dataset_entry["name"])
859+
if "author" in dataset_entry:
860+
self.gguf.add_dataset_author(key, dataset_entry["author"])
861+
if "version" in dataset_entry:
862+
self.gguf.add_dataset_version(key, dataset_entry["version"])
863+
if "organization" in dataset_entry:
864+
self.gguf.add_dataset_organization(key, dataset_entry["organization"])
865+
if "description" in dataset_entry:
866+
self.gguf.add_dataset_description(key, dataset_entry["description"])
867+
if "url" in dataset_entry:
868+
self.gguf.add_dataset_url(key, dataset_entry["url"])
869+
if "doi" in dataset_entry:
870+
self.gguf.add_dataset_doi(key, dataset_entry["doi"])
871+
if "uuid" in dataset_entry:
872+
self.gguf.add_dataset_uuid(key, dataset_entry["uuid"])
873+
if "repo_url" in dataset_entry:
874+
self.gguf.add_dataset_repo_url(key, dataset_entry["repo_url"])
875+
854876
if metadata.tags is not None:
855877
self.gguf.add_tags(metadata.tags)
856878
if metadata.languages is not None:
857879
self.gguf.add_languages(metadata.languages)
858-
if metadata.datasets is not None:
859-
self.gguf.add_datasets(metadata.datasets)
860880

861881
def add_meta_arch(self, params: Params) -> None:
862882
# Metadata About The Neural Architecture Itself

gguf-py/gguf/constants.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,6 @@ class General:
7070
BASE_MODEL_UUID = "general.base_model.{id}.uuid"
7171
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
7272

73-
# Array based KV stores
74-
TAGS = "general.tags"
75-
LANGUAGES = "general.languages"
76-
DATASETS = "general.datasets"
77-
7873
# Dataset Source
7974
DATASET_COUNT = "general.dataset.count"
8075
DATASET_NAME = "general.dataset.{id}.name"
@@ -87,6 +82,10 @@ class General:
8782
DATASET_UUID = "general.dataset.{id}.uuid"
8883
DATASET_REPO_URL = "general.dataset.{id}.repo_url" # Model Source Repository (git/svn/etc...)
8984

85+
# Array based KV stores
86+
TAGS = "general.tags"
87+
LANGUAGES = "general.languages"
88+
9089
class LLM:
9190
VOCAB_SIZE = "{arch}.vocab_size"
9291
CONTEXT_LENGTH = "{arch}.context_length"

gguf-py/gguf/gguf_writer.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,15 +583,42 @@ def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
583583
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
584584
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
585585

586+
def add_dataset_count(self, source_count: int) -> None:
587+
self.add_uint32(Keys.General.DATASET_COUNT, source_count)
588+
589+
def add_dataset_name(self, source_id: int, name: str) -> None:
590+
self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name)
591+
592+
def add_dataset_author(self, source_id: int, author: str) -> None:
593+
self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author)
594+
595+
def add_dataset_version(self, source_id: int, version: str) -> None:
596+
self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version)
597+
598+
def add_dataset_organization(self, source_id: int, organization: str) -> None:
599+
self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization)
600+
601+
def add_dataset_description(self, source_id: int, description: str) -> None:
602+
self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description)
603+
604+
def add_dataset_url(self, source_id: int, url: str) -> None:
605+
self.add_string(Keys.General.DATASET_URL.format(id=source_id), url)
606+
607+
def add_dataset_doi(self, source_id: int, doi: str) -> None:
608+
self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi)
609+
610+
def add_dataset_uuid(self, source_id: int, uuid: str) -> None:
611+
self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid)
612+
613+
def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None:
614+
self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url)
615+
586616
def add_tags(self, tags: Sequence[str]) -> None:
587617
self.add_array(Keys.General.TAGS, tags)
588618

589619
def add_languages(self, languages: Sequence[str]) -> None:
590620
self.add_array(Keys.General.LANGUAGES, languages)
591621

592-
def add_datasets(self, datasets: Sequence[str]) -> None:
593-
self.add_array(Keys.General.DATASETS, datasets)
594-
595622
def add_tensor_data_layout(self, layout: str) -> None:
596623
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
597624

gguf-py/gguf/metadata.py

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ class Metadata:
3939
license_name: Optional[str] = None
4040
license_link: Optional[str] = None
4141
base_models: Optional[list[dict]] = None
42+
datasets: Optional[list[dict]] = None
4243
tags: Optional[list[str]] = None
4344
languages: Optional[list[str]] = None
44-
datasets: Optional[list[str]] = None
4545

4646
@staticmethod
4747
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
@@ -91,9 +91,11 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat
9191
# Base Models is received here as an array of models
9292
metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)
9393

94+
# Datasets is received here as an array of datasets
95+
metadata.datasets = metadata_override.get("general.datasets", metadata.datasets)
96+
9497
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
9598
metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
96-
metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)
9799

98100
# Direct Metadata Override (via direct cli argument)
99101
if model_name is not None:
@@ -346,12 +348,12 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
346348
use_model_card_metadata("author", "model_creator")
347349
use_model_card_metadata("basename", "model_type")
348350

349-
if "base_model" in model_card:
351+
if "base_model" in model_card or "base_models" in model_card:
350352
# This represents the parent models that this is based on
351353
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
352354
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
353355
metadata_base_models = []
354-
base_model_value = model_card.get("base_model", None)
356+
base_model_value = model_card.get("base_model", model_card.get("base_models", None))
355357

356358
if base_model_value is not None:
357359
if isinstance(base_model_value, str):
@@ -364,18 +366,54 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
364366

365367
for model_id in metadata_base_models:
366368
# NOTE: model size of base model is assumed to be similar to the size of the current model
367-
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
368369
base_model = {}
369-
if model_full_name_component is not None:
370-
base_model["name"] = Metadata.id_to_title(model_full_name_component)
371-
if org_component is not None:
372-
base_model["organization"] = Metadata.id_to_title(org_component)
373-
if version is not None:
374-
base_model["version"] = version
375-
if org_component is not None and model_full_name_component is not None:
376-
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
370+
if isinstance(model_id, str) and (model_id.startswith("http://") or model_id.startswith("https://")):
371+
base_model["repo_url"] = model_id
372+
else:
373+
# Likely a Hugging Face ID
374+
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
375+
if model_full_name_component is not None:
376+
base_model["name"] = Metadata.id_to_title(model_full_name_component)
377+
if org_component is not None:
378+
base_model["organization"] = Metadata.id_to_title(org_component)
379+
if version is not None:
380+
base_model["version"] = version
381+
if org_component is not None and model_full_name_component is not None:
382+
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
377383
metadata.base_models.append(base_model)
378384

385+
if "datasets" in model_card or "dataset" in model_card:
386+
# This represents the datasets that this was trained from
387+
metadata_datasets = []
388+
dataset_value = model_card.get("datasets", model_card.get("dataset", None))
389+
390+
if dataset_value is not None:
391+
if isinstance(dataset_value, str):
392+
metadata_datasets.append(dataset_value)
393+
elif isinstance(dataset_value, list):
394+
metadata_datasets.extend(dataset_value)
395+
396+
if metadata.datasets is None:
397+
metadata.datasets = []
398+
399+
for dataset_id in metadata_datasets:
400+
# NOTE: model size of base model is assumed to be similar to the size of the current model
401+
dataset = {}
402+
if isinstance(dataset_id, str) and (dataset_id.startswith("http://") or dataset_id.startswith("https://")):
403+
dataset["repo_url"] = dataset_id
404+
else:
405+
# Likely a Hugging Face ID
406+
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
407+
if dataset_name_component is not None:
408+
dataset["name"] = Metadata.id_to_title(dataset_name_component)
409+
if org_component is not None:
410+
dataset["organization"] = Metadata.id_to_title(org_component)
411+
if version is not None:
412+
dataset["version"] = version
413+
if org_component is not None and dataset_name_component is not None:
414+
dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
415+
metadata.datasets.append(dataset)
416+
379417
use_model_card_metadata("license", "license")
380418
use_model_card_metadata("license_name", "license_name")
381419
use_model_card_metadata("license_link", "license_link")
@@ -386,9 +424,6 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
386424
use_array_model_card_metadata("languages", "languages")
387425
use_array_model_card_metadata("languages", "language")
388426

389-
use_array_model_card_metadata("datasets", "datasets")
390-
use_array_model_card_metadata("datasets", "dataset")
391-
392427
# Hugging Face Parameter Heuristics
393428
####################################
394429

@@ -504,9 +539,29 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
504539
if "repo_url" in base_model_entry:
505540
gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])
506541

542+
if self.datasets is not None:
543+
gguf_writer.add_dataset_count(len(self.datasets))
544+
for key, dataset_entry in enumerate(self.datasets):
545+
if "name" in dataset_entry:
546+
gguf_writer.add_dataset_name(key, dataset_entry["name"])
547+
if "author" in dataset_entry:
548+
gguf_writer.add_dataset_author(key, dataset_entry["author"])
549+
if "version" in dataset_entry:
550+
gguf_writer.add_dataset_version(key, dataset_entry["version"])
551+
if "organization" in dataset_entry:
552+
gguf_writer.add_dataset_organization(key, dataset_entry["organization"])
553+
if "description" in dataset_entry:
554+
gguf_writer.add_dataset_description(key, dataset_entry["description"])
555+
if "url" in dataset_entry:
556+
gguf_writer.add_dataset_url(key, dataset_entry["url"])
557+
if "doi" in dataset_entry:
558+
gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
559+
if "uuid" in dataset_entry:
560+
gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
561+
if "repo_url" in dataset_entry:
562+
gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])
563+
507564
if self.tags is not None:
508565
gguf_writer.add_tags(self.tags)
509566
if self.languages is not None:
510567
gguf_writer.add_languages(self.languages)
511-
if self.datasets is not None:
512-
gguf_writer.add_datasets(self.datasets)

0 commit comments

Comments
 (0)