8
8
import os
9
9
import re
10
10
import sys
11
+ from abc import ABC
11
12
from enum import IntEnum
12
13
from pathlib import Path
13
14
from hashlib import sha256
14
- from typing import TYPE_CHECKING , Any , Callable , ContextManager , Iterable , Iterator , Protocol , Sequence , TypeVar , cast
15
+ from typing import TYPE_CHECKING , Any , Callable , ContextManager , Iterable , Iterator , Sequence , TypeVar , cast
15
16
16
17
import numpy as np
17
18
import torch
@@ -40,7 +41,7 @@ class SentencePieceTokenTypes(IntEnum):
40
41
AnyModel = TypeVar ("AnyModel" , bound = "type[Model]" )
41
42
42
43
43
- class Model (Protocol ):
44
+ class Model (ABC ):
44
45
_model_classes : dict [str , type [Model ]] = {}
45
46
46
47
dir_model : Path
@@ -57,6 +58,7 @@ class Model(Protocol):
57
58
tensor_map : gguf .TensorNameMap
58
59
tensors : dict [str , Tensor ]
59
60
61
+ # subclasses should define this!
60
62
model_arch : gguf .MODEL_ARCH
61
63
62
64
def __init__ (self , dir_model : Path , ftype : int , fname_out : Path , is_big_endian : bool , use_temp_file : bool ):
@@ -163,12 +165,18 @@ def set_gguf_parameters(self):
163
165
print (f"gguf: file type = { self .ftype } " )
164
166
165
167
def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
168
+ del bid # unused
169
+
166
170
return [(self .map_tensor_name (name ), data_torch )]
167
171
168
172
def extra_f32_tensors (self , name : str , new_name : str , bid : int | None , n_dims : int ) -> bool :
173
+ del name , new_name , bid , n_dims # unused
174
+
169
175
return False
170
176
171
177
def extra_f16_tensors (self , name : str , new_name : str , bid : int | None , n_dims : int ) -> bool :
178
+ del name , new_name , bid , n_dims # unused
179
+
172
180
return False
173
181
174
182
def write_tensors (self ):
@@ -248,7 +256,7 @@ def get_model_part_names(dir_model: Path, suffix: str) -> list[str]:
248
256
return part_names
249
257
250
258
@staticmethod
251
- def load_hparams (dir_model ):
259
+ def load_hparams (dir_model : Path ):
252
260
with open (dir_model / "config.json" , "r" , encoding = "utf-8" ) as f :
253
261
return json .load (f )
254
262
@@ -258,12 +266,14 @@ def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
258
266
259
267
def func (modelcls : AnyModel ) -> AnyModel :
260
268
for name in names :
269
+ if "model_arch" not in modelcls .__dict__ :
270
+ raise TypeError (f"Missing property 'model_arch' for { modelcls .__name__ !r} " )
261
271
cls ._model_classes [name ] = modelcls
262
272
return modelcls
263
273
return func
264
274
265
275
@classmethod
266
- def from_model_architecture (cls , arch ) :
276
+ def from_model_architecture (cls , arch : str ) -> type [ Model ] :
267
277
try :
268
278
return cls ._model_classes [arch ]
269
279
except KeyError :
0 commit comments