Skip to content

Commit 1d295cf

Browse files
committed
py: let users add full base model and dataset to model_card
if they choose to. Also added more tests for this feature.
1 parent 0d7ce70 commit 1d295cf

File tree

2 files changed

+69
-24
lines changed

2 files changed

+69
-24
lines changed

gguf-py/gguf/metadata.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -367,19 +367,24 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
367367
for model_id in metadata_base_models:
368368
# NOTE: model size of base model is assumed to be similar to the size of the current model
369369
base_model = {}
370-
if isinstance(model_id, str) and (model_id.startswith("http://") or model_id.startswith("https://")):
371-
base_model["repo_url"] = model_id
370+
if isinstance(model_id, str):
371+
if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
372+
base_model["repo_url"] = model_id
373+
else:
374+
# Likely a Hugging Face ID
375+
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
376+
if model_full_name_component is not None:
377+
base_model["name"] = Metadata.id_to_title(model_full_name_component)
378+
if org_component is not None:
379+
base_model["organization"] = Metadata.id_to_title(org_component)
380+
if version is not None:
381+
base_model["version"] = version
382+
if org_component is not None and model_full_name_component is not None:
383+
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
384+
elif isinstance(model_id, dict):
385+
base_model = model_id
372386
else:
373-
# Likely a Hugging Face ID
374-
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
375-
if model_full_name_component is not None:
376-
base_model["name"] = Metadata.id_to_title(model_full_name_component)
377-
if org_component is not None:
378-
base_model["organization"] = Metadata.id_to_title(org_component)
379-
if version is not None:
380-
base_model["version"] = version
381-
if org_component is not None and model_full_name_component is not None:
382-
base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
387+
logger.error(f"base model entry '{str(model_id)}' not in a known format")
383388
metadata.base_models.append(base_model)
384389

385390
if "datasets" in model_card or "dataset" in model_card:
@@ -399,19 +404,24 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
399404
for dataset_id in metadata_datasets:
400405
# NOTE: model size of base model is assumed to be similar to the size of the current model
401406
dataset = {}
402-
if isinstance(dataset_id, str) and (dataset_id.startswith("http://") or dataset_id.startswith("https://")):
403-
dataset["repo_url"] = dataset_id
407+
if isinstance(dataset_id, str):
408+
if dataset_id.startswith("http://") or dataset_id.startswith("https://") or dataset_id.startswith("ssh://"):
409+
dataset["repo_url"] = dataset_id
410+
else:
411+
# Likely a Hugging Face ID
412+
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
413+
if dataset_name_component is not None:
414+
dataset["name"] = Metadata.id_to_title(dataset_name_component)
415+
if org_component is not None:
416+
dataset["organization"] = Metadata.id_to_title(org_component)
417+
if version is not None:
418+
dataset["version"] = version
419+
if org_component is not None and dataset_name_component is not None:
420+
dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
421+
elif isinstance(dataset_id, dict):
422+
dataset = dataset_id
404423
else:
405-
# Likely a Hugging Face ID
406-
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
407-
if dataset_name_component is not None:
408-
dataset["name"] = Metadata.id_to_title(dataset_name_component)
409-
if org_component is not None:
410-
dataset["organization"] = Metadata.id_to_title(org_component)
411-
if version is not None:
412-
dataset["version"] = version
413-
if org_component is not None and dataset_name_component is not None:
414-
dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
424+
logger.error(f"dataset entry '{str(dataset_id)}' not in a known format")
415425
metadata.datasets.append(dataset)
416426

417427
use_model_card_metadata("license", "license")

gguf-py/tests/test_metadata.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,42 @@ def test_apply_metadata_heuristic_from_model_card(self):
183183
expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
184184
expect.languages=['en']
185185
expect.datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]
186+
self.assertEqual(got, expect)
186187

188+
# Base Model spec is inferred from model id
189+
model_card = {'base_models': ['teknium/OpenHermes-2.5']}
190+
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
191+
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
192+
self.assertEqual(got, expect)
193+
194+
# Base Model spec is only url
195+
model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']}
196+
expect = gguf.Metadata(base_models=[{'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
197+
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
198+
self.assertEqual(got, expect)
199+
200+
# Base Model spec is given directly
201+
model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
202+
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
203+
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
204+
self.assertEqual(got, expect)
205+
206+
# Dataset spec is inferred from model id
207+
model_card = {'datasets': ['teknium/OpenHermes-2.5']}
208+
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
209+
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
210+
self.assertEqual(got, expect)
211+
212+
# Dataset spec is only url
213+
model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']}
214+
expect = gguf.Metadata(datasets=[{'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
215+
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
216+
self.assertEqual(got, expect)
217+
218+
# Dataset spec is given directly
219+
model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
220+
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
221+
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
187222
self.assertEqual(got, expect)
188223

189224
def test_apply_metadata_heuristic_from_hf_parameters(self):

0 commit comments

Comments
 (0)