Skip to content

Commit 13f4cf7

Browse files
committed
convert-hf : use a plain class for Model, and forbid direct instantiation
There are no abstract methods used anyway, so using ABC isn't really necessary.
1 parent ce067af commit 13f4cf7

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

convert-hf-to-gguf.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import os
99
import re
1010
import sys
11-
from abc import ABC
1211
from enum import IntEnum
1312
from pathlib import Path
1413
from hashlib import sha256
@@ -41,7 +40,7 @@ class SentencePieceTokenTypes(IntEnum):
4140
AnyModel = TypeVar("AnyModel", bound="type[Model]")
4241

4342

44-
class Model(ABC):
43+
class Model:
4544
_model_classes: dict[str, type[Model]] = {}
4645

4746
dir_model: Path
@@ -62,6 +61,8 @@ class Model(ABC):
6261
model_arch: gguf.MODEL_ARCH
6362

6463
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
64+
if self.__class__ == Model:
65+
raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
6566
self.dir_model = dir_model
6667
self.ftype = ftype
6768
self.fname_out = fname_out
@@ -78,6 +79,13 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
7879
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
7980
self.tensors = dict(self.get_tensors())
8081

82+
@classmethod
83+
def __init_subclass__(cls):
84+
# can't use an abstract property, because overriding it without type errors
85+
# would require using decorated functions instead of simply defining the property
86+
if "model_arch" not in cls.__dict__:
87+
raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
88+
8189
def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
8290
key = next((k for k in keys if k in self.hparams), None)
8391
if key is not None:
@@ -266,8 +274,6 @@ def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
266274

267275
def func(modelcls: AnyModel) -> AnyModel:
268276
for name in names:
269-
if "model_arch" not in modelcls.__dict__:
270-
raise TypeError(f"Missing property 'model_arch' for {modelcls.__name__!r}")
271277
cls._model_classes[name] = modelcls
272278
return modelcls
273279
return func

0 commit comments

Comments
 (0)