Skip to content

Commit ce067af

Browse files
committed
convert-hf : use an ABC for Model again
It seems Protocol can't be used as a statically type-checked ABC, because its subclasses also can't be instantiated. (why did it seem to work?) At least there's still a way to raise an error when a registered Model subclass forgets to define the `model_arch` property.
1 parent 644c269 commit ce067af

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

convert-hf-to-gguf.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
import os
99
import re
1010
import sys
11+
from abc import ABC
1112
from enum import IntEnum
1213
from pathlib import Path
1314
from hashlib import sha256
14-
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Protocol, Sequence, TypeVar, cast
15+
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
1516

1617
import numpy as np
1718
import torch
@@ -40,7 +41,7 @@ class SentencePieceTokenTypes(IntEnum):
4041
AnyModel = TypeVar("AnyModel", bound="type[Model]")
4142

4243

43-
class Model(Protocol):
44+
class Model(ABC):
4445
_model_classes: dict[str, type[Model]] = {}
4546

4647
dir_model: Path
@@ -57,6 +58,7 @@ class Model(Protocol):
5758
tensor_map: gguf.TensorNameMap
5859
tensors: dict[str, Tensor]
5960

61+
# subclasses should define this!
6062
model_arch: gguf.MODEL_ARCH
6163

6264
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
@@ -163,12 +165,18 @@ def set_gguf_parameters(self):
163165
print(f"gguf: file type = {self.ftype}")
164166

165167
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
168+
del bid # unused
169+
166170
return [(self.map_tensor_name(name), data_torch)]
167171

168172
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
173+
del name, new_name, bid, n_dims # unused
174+
169175
return False
170176

171177
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
178+
del name, new_name, bid, n_dims # unused
179+
172180
return False
173181

174182
def write_tensors(self):
@@ -248,7 +256,7 @@ def get_model_part_names(dir_model: Path, suffix: str) -> list[str]:
248256
return part_names
249257

250258
@staticmethod
251-
def load_hparams(dir_model):
259+
def load_hparams(dir_model: Path):
252260
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
253261
return json.load(f)
254262

@@ -258,12 +266,14 @@ def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
258266

259267
def func(modelcls: AnyModel) -> AnyModel:
260268
for name in names:
269+
if "model_arch" not in modelcls.__dict__:
270+
raise TypeError(f"Missing property 'model_arch' for {modelcls.__name__!r}")
261271
cls._model_classes[name] = modelcls
262272
return modelcls
263273
return func
264274

265275
@classmethod
266-
def from_model_architecture(cls, arch):
276+
def from_model_architecture(cls, arch: str) -> type[Model]:
267277
try:
268278
return cls._model_classes[arch]
269279
except KeyError:

0 commit comments

Comments (0)