@@ -52,13 +52,14 @@ class Model:
     is_big_endian: bool
     endianess: gguf.GGUFEndian
     use_temp_file: bool
+    lazy: bool
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
     gguf_writer: gguf.GGUFWriter
     block_count: int
     tensor_map: gguf.TensorNameMap
-    tensors: dict[str, Tensor]
+    tensor_names: set[str] | None

     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
@@ -72,6 +73,7 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
         self.is_big_endian = is_big_endian
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
+        self.lazy = not eager
         self.part_names = Model.get_model_part_names(self.dir_model, ".safetensors")
         self.is_safetensors = len(self.part_names) > 0
         if not self.is_safetensors:
@@ -80,10 +82,7 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
         self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
-        self.tensors = dict(self.get_tensors())
-        if not eager:
-            for k, v in self.tensors.items():
-                self.tensors[k] = LazyTorchTensor.from_eager(v)
+        self.tensor_names = None

     @classmethod
     def __init_subclass__(cls):
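Note: the constructor no longer materializes a `self.tensors` dict up front; it only records whether lazy conversion was requested (`self.lazy = not eager`) and leaves all tensor loading to `get_tensors()`. The snippet below is a minimal sketch of the deferral idea, not the PR's actual `LazyTorchTensor` (class and method names are illustrative):

    from typing import Callable

    import torch

    class LazyTensorSketch:
        """Stores a recipe for producing a tensor instead of the tensor itself."""

        def __init__(self, thunk: Callable[[], torch.Tensor]):
            self._thunk = thunk                        # how to load the tensor later
            self._cached: torch.Tensor | None = None

        def materialize(self) -> torch.Tensor:
            # the real data is only read into memory on first use
            if self._cached is None:
                self._cached = self._thunk()
            return self._cached

    # e.g. defer reading a tensor from a shard until write time:
    # lazy = LazyTensorSketch(lambda: model_part.get_tensor(name))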
@@ -104,6 +103,22 @@ def set_vocab(self):
         self._set_vocab_gpt2()

     def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        tensor_names_from_parts: set[str] = set()
+
+        if len(self.part_names) > 1:
+            self.tensor_names = set()
+            index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
+            index_name += ".index.json"
+            logger.info(f"gguf: loading model weight map from '{index_name}'")
+            with open(self.dir_model / index_name, "r", encoding="utf-8") as f:
+                index: dict[str, Any] = json.load(f)
+                weight_map = index.get("weight_map")
+                if weight_map is None or not isinstance(weight_map, dict):
+                    raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
+                self.tensor_names.update(weight_map.keys())
+        else:
+            self.tensor_names = tensor_names_from_parts
+
         for part_name in self.part_names:
             logger.info(f"gguf: loading model part '{part_name}'")
             ctx: ContextManager[Any]
@@ -114,10 +129,18 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))

             with ctx as model_part:
+                tensor_names_from_parts.update(model_part.keys())
+
                 for name in model_part.keys():
                     data = model_part.get_tensor(name) if self.is_safetensors else model_part[name]
+                    if self.lazy:
+                        data = LazyTorchTensor.from_eager(data)
                     yield name, data

+        # only verify tensor name presence; it doesn't matter if they are not in the right files
+        if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
+            raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}")
+
     def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
         name: str = gguf.TENSOR_NAMES[key]
         if key not in gguf.MODEL_TENSORS[self.model_arch]:
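Note: for sharded checkpoints, `get_tensors()` now reads the Hugging Face index file (`model.safetensors.index.json` or `pytorch_model.bin.index.json`) and, after streaming every part, requires the set of names it actually found to match the index's `weight_map` exactly. A minimal sketch of that check against an assumed, abbreviated index file (tensor and shard file names are illustrative):

    import json

    index_json = """
    {
      "metadata": {"total_size": 0},
      "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors"
      }
    }
    """

    names_from_index = set(json.loads(index_json)["weight_map"].keys())

    # union of the names actually seen while iterating the part files
    names_from_parts = {"model.embed_tokens.weight", "lm_head.weight"}

    # same check as above: only presence matters, not which file holds what
    sym_diff = names_from_index.symmetric_difference(names_from_parts)
    assert not sym_diff, f"Mismatch between weight map and model parts: {sym_diff}"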
@@ -194,7 +217,7 @@ def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: i
     def write_tensors(self):
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")

-        for name, data_torch in self.tensors.items():
+        for name, data_torch in self.get_tensors():
             # we don't need these
             if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
                 continue
@@ -248,8 +271,6 @@ def write_tensors(self):

     def write(self):
         self.write_tensors()
-        self.tensors.clear()  # save memory by not keeping references to the tensors
-
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.write_tensors_to_file(progress=True)
@@ -621,8 +642,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         tensors.append((self.map_tensor_name(name), data_torch))

         if name == "word_embeddings.weight":
+            assert self.tensor_names is not None
+
             # TODO: tie them at runtime, don't duplicate in the model file
-            if "lm_head.weight" not in self.tensors and "output.weight" not in self.tensors:
+            if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
                 tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))

         return tensors
@@ -1759,7 +1782,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         tensors: list[tuple[str, Tensor]] = [(new_name, data_torch)]

         if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
-            if "lm_head.weight" not in self.tensors and "output.weight" not in self.tensors:
+            assert self.tensor_names is not None
+
+            if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
+                # copy tok_embd.weight to output.weight
                 tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))

         return tensors
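Note: the two hunks above implement the same fallback in different model classes: when a checkpoint ships no separate output projection, the token-embedding weights are duplicated as the output tensor (until runtime tying is supported). A standalone sketch of that decision, with illustrative tensor names:

    # names collected from the checkpoint's weight map / part files
    tensor_names = {"word_embeddings.weight", "transformer.h.0.attn.qkv.weight"}

    # no dedicated lm_head/output tensor -> reuse the embedding weights as output
    if all(s not in tensor_names for s in ("lm_head.weight", "output.weight")):
        output_source = "word_embeddings.weight"
        print(f"output.weight will be copied from {output_source}")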