@@ -204,7 +204,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
204
204
# First operate on meta tensors to find shapes and dtypes for GGUF header.
205
205
for idx , name in enumerate (weight_names ):
206
206
weight , scales = get_weights (f"{ input_dir } /tensor{ idx :05} _000" )
207
- meta_tensor = convert_weight (name , weight , scales , config . experts , device = "meta" )
207
+ meta_tensor = convert_weight (name , weight , scales , config , device = "meta" )
208
208
dtype , tensor_ggml_type = get_dtype_and_ggml_type (meta_tensor , ggml_type )
209
209
quantized_meta_tensor = maybe_quantize_tensor (meta_tensor , tensor_ggml_type )
210
210
f .add_tensor_info (
@@ -226,8 +226,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
226
226
227
227
for name in weight_names :
228
228
weight , scales = weights .pop (name )
229
- tensor = convert_weight (name , weight , scales , config .experts )
230
- tensor = maybe_permute_tensor (name , tensor , config )
229
+ tensor = convert_weight (name , weight , scales , config )
231
230
_ , tensor_ggml_type = get_dtype_and_ggml_type (tensor , ggml_type )
232
231
array = maybe_quantize_tensor (tensor , tensor_ggml_type ).numpy ()
233
232
@@ -258,7 +257,7 @@ def from_numpy(array):
258
257
return torch .from_numpy (array )
259
258
260
259
261
- def convert_weight (name , weight , scales , experts , dtype = torch .float32 , device = None ):
260
+ def convert_weight (name , weight , scales , config , dtype = torch .float32 , device = None ):
262
261
# copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
263
262
weight = from_numpy (weight ).to (device = device , dtype = dtype )
264
263
if scales is not None :
@@ -271,32 +270,19 @@ def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=No
271
270
else :
272
271
weight = weight * scale
273
272
274
- # Transpose linear matrix
275
- if len (weight .shape ) >= 2 and "token_embd" not in name :
273
+ if name == "token_embd" :
274
+ weight *= config .embedding_multiplier_scale
275
+ elif len (weight .shape ) >= 2 :
276
+ # Transpose linear matrix
276
277
weight = weight .transpose (- 1 , - 2 )
277
278
279
+
278
280
if name .endswith ("ffn_gate_inp" ) or name .endswith ("_exps" ):
279
- weight = weight [experts ] # gather.
281
+ weight = weight [config . experts ] # gather.
280
282
281
283
return weight
282
284
283
285
284
- def maybe_permute_tensor (name , tensor , config ):
285
- def permute (weights , n_head ):
286
- return (
287
- weights .reshape (n_head , 2 , weights .shape [0 ] // n_head // 2 , * weights .shape [1 :])
288
- .swapaxes (1 , 2 )
289
- .reshape (weights .shape )
290
- )
291
-
292
- if name .endswith ("attn_k" ):
293
- return permute (tensor , config .num_key_value_heads )
294
- elif name .endswith ("attn_q" ):
295
- return permute (tensor , config .num_attention_heads )
296
-
297
- return tensor
298
-
299
-
300
286
def extract_vocabulary_from_model (vocab ):
301
287
tokens = []
302
288
scores = []
0 commit comments