@@ -196,55 +196,47 @@ def get_dtype_and_ggml_type(tensor, ggml_type):
         return np.float32, gguf.GGMLQuantizationType.F32


-def dump_state_dict(f, weight_names, model_files, ggml_type, config):
-    keys2names = {}
-    meta_tensors = {}
-    weight_scales = {}
+def dump_state_dict(f, ggml_type, input_dir, config):
+    weight_names = get_weight_names(config.num_hidden_layers)
+    weights = {}

     # First operate on meta tensors to find shapes and dtypes for GGUF header.
-    for name, fn in model_files:
-        weight, scales = get_weights(fn)
-        meta_state_dict = convert_weight(name, weight, scales, config.experts, device="meta")
-        weight_scales[name] = (weight, scales)
-        for key in meta_state_dict.keys():
-            keys2names[key] = name
-
-        meta_tensors.update(meta_state_dict)
-
-    for key in weight_names:
-        meta_tensor = meta_tensors[key]
+    for idx, name in enumerate(weight_names):
+        weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
+        meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
         dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
         quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
         f.add_tensor_info(
-            key, list(meta_tensor.shape), dtype, quantized_meta_tensor.nbytes, tensor_ggml_type
+            f"{name}.weight",
+            list(meta_tensor.shape),
+            dtype,
+            quantized_meta_tensor.nbytes,
+            tensor_ggml_type,
         )
-    print("Loaded", len(meta_tensors), "files")
+        weights[name] = weight, scales
+    print("Loaded", len(weight_names), "files")

     f.write_header_to_file()
     f.write_kv_data_to_file()
     f.write_ti_data_to_file()

-    cache = {}
+    # Now write actual tensor data.
     tensor_info = []

-    for key in weight_names:
-        if key not in cache:
-            name = keys2names[key]
-            weight, scales = weight_scales.pop(name)
-            state_dict = convert_weight(name, weight, scales, config.experts)
-            permute_tensors(state_dict, config)
-            cache.update(state_dict)
-        tensor = cache.pop(key)
+    for name in weight_names:
+        weight, scales = weights.pop(name)
+        tensor = convert_weight(name, weight, scales, config.experts)
+        tensor = maybe_permute_tensor(name, tensor, config)
         _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
         array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()

         print(
-            f"dumping {key}:",
+            f"dumping {name}:",
             f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes",
         )
         f.write_tensor_data(array)

-        tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name))
+        tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name))

     try:
         print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
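
Note: GGUF files are written in two passes, which is why dump_state_dict() first declares every tensor and only afterwards streams the payloads. A minimal sketch of the same pattern, reusing only the gguf-py calls that already appear above (file name, tensor name, and shape are hypothetical):

    import gguf
    import numpy as np

    writer = gguf.GGUFWriter("tiny.gguf", "grok", endianess=gguf.GGUFEndian.LITTLE)
    data = np.zeros((4, 4), dtype=np.float32)

    # Pass 1: declare tensor metadata so the header and tensor-info table go first.
    writer.add_tensor_info("blk.0.attn_q.weight", list(data.shape), data.dtype, data.nbytes)
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_ti_data_to_file()

    # Pass 2: append the raw tensor data in the same order it was declared.
    writer.write_tensor_data(data)
    writer.close()
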
@@ -263,10 +255,8 @@ def from_numpy(array):
     return torch.from_numpy(array)


-def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None):
+def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
     # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
-    result = {}
-
     weight = from_numpy(weight).to(device=device, dtype=dtype)
     if scales is not None:
         scale = from_numpy(scales).to(device=device, dtype=dtype)
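
Context for the scale handling above: the released Grok-1 checkpoint stores most large tensors as 8-bit integer weights plus a smaller float scale tensor, so multiplying the two recovers the float weights. A toy version of that dequantization step (values hypothetical):

    import torch

    q = torch.tensor([[-64, 32], [16, -8]], dtype=torch.int8)
    scale = torch.tensor([0.5])          # scale shipped separately on disk
    w = q.to(torch.float32) * scale      # dequantized: [[-32., 16.], [8., -4.]]
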
@@ -279,30 +269,29 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, de
             weight = weight * scale

     # Transpose linear matrix
-    if len(weight.shape) >= 2 and "token_embd" not in tensor_name:
+    if len(weight.shape) >= 2 and "token_embd" not in name:
         weight = weight.transpose(-1, -2)

-    if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"):
-        result[tensor_name] = weight[experts]  # gather.
-    elif "experts" not in tensor_name:
-        result[tensor_name] = weight
+    if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
+        weight = weight[experts]  # gather.

-    return result
+    return weight


-def permute_tensors(state_dict, config):
+def maybe_permute_tensor(name, tensor, config):
     def permute(weights, n_head):
         return (
             weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
             .swapaxes(1, 2)
             .reshape(weights.shape)
         )

-    for name, tensor in state_dict.items():
-        if name.endswith("attn_k.weight"):
-            state_dict[name] = permute(tensor, config.num_key_value_heads)
-        elif name.endswith("attn_q.weight"):
-            state_dict[name] = permute(tensor, config.num_attention_heads)
+    if name.endswith("attn_k"):
+        return permute(tensor, config.num_key_value_heads)
+    elif name.endswith("attn_q"):
+        return permute(tensor, config.num_attention_heads)
+
+    return tensor


 def extract_vocabulary_from_model(vocab):
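
The q/k reordering kept in maybe_permute_tensor() matches what llama.cpp-style converters do: within each attention head, rows move from interleaved rotary pairs into a first-half/second-half layout. A quick check of the inner permute() on a toy tensor (hypothetical sizes: 2 heads, 4 rows per head):

    import torch

    def permute(weights, n_head):
        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    w = torch.arange(8).unsqueeze(-1)           # 8 rows = 2 heads x 4 rows
    print(permute(w, 2).squeeze(-1).tolist())   # [0, 2, 1, 3, 4, 6, 5, 7]
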
@@ -320,25 +309,32 @@ def extract_vocabulary_from_model(vocab):
     return tokens, scores, toktypes


-def get_weight_names(config):
-    weight_names = ["token_embd.weight"]
-    for i in range(config.num_hidden_layers):
-        weight_names += [
-            f"blk.{i}.ffn_gate_exps.weight",
-            f"blk.{i}.ffn_down_exps.weight",
-            f"blk.{i}.ffn_up_exps.weight",
-            f"blk.{i}.attn_k.weight",
-            f"blk.{i}.attn_output.weight",
-            f"blk.{i}.attn_q.weight",
-            f"blk.{i}.attn_v.weight",
-            f"blk.{i}.attn_norm.weight",
-            f"blk.{i}.attn_output_norm.weight",
-            f"blk.{i}.ffn_norm.weight",
-            f"blk.{i}.layer_output_norm.weight",
-            f"blk.{i}.ffn_gate_inp.weight",
-        ]
-
-    weight_names += ["output_norm.weight"]
+def get_weight_names(num_hidden_layers=64):
+    """Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files."""
+
+    weight_names = [
+        gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD],
+        gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_NORM],
+    ]
+
+    layer = (
+        gguf.MODEL_TENSOR.FFN_GATE_EXP,
+        gguf.MODEL_TENSOR.FFN_DOWN_EXP,
+        gguf.MODEL_TENSOR.FFN_UP_EXP,
+        gguf.MODEL_TENSOR.ATTN_K,
+        gguf.MODEL_TENSOR.ATTN_OUT,
+        gguf.MODEL_TENSOR.ATTN_Q,
+        gguf.MODEL_TENSOR.ATTN_V,
+        gguf.MODEL_TENSOR.ATTN_NORM,
+        gguf.MODEL_TENSOR.ATTN_OUT_NORM,
+        gguf.MODEL_TENSOR.FFN_NORM,
+        gguf.MODEL_TENSOR.LAYER_OUT_NORM,
+        gguf.MODEL_TENSOR.FFN_GATE_INP,
+    )
+
+    for bid in range(num_hidden_layers):
+        for key in layer:
+            weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid))

     return weight_names

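
Since the names now come from gguf-py's TENSOR_NAMES table (and dump_state_dict() appends the ".weight" suffix when declaring tensors), the helper yields plain base names; assuming current gguf-py constants, for example:

    >>> get_weight_names(num_hidden_layers=1)[:4]
    ['token_embd', 'output_norm', 'blk.0.ffn_gate_exps', 'blk.0.ffn_down_exps']
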
@@ -383,28 +379,6 @@ def ffn_size(emb_size, widening_factor):
     assert config.num_experts >= 2, "need at least 2 experts"
     print("experts to export:", config.experts)

-    # Contents of in Grok-1 pickle files, in order. Weights with "experts" will be split later.
-    tensor_names = [
-        "token_embd.weight",
-        "output_norm.weight",
-    ]
-    for i in range(config.num_hidden_layers):
-        tensor_names += [
-            f"blk.{i}.ffn_gate_exps.weight",
-            f"blk.{i}.ffn_down_exps.weight",
-            f"blk.{i}.ffn_up_exps.weight",
-            f"blk.{i}.attn_k.weight",
-            f"blk.{i}.attn_output.weight",
-            f"blk.{i}.attn_q.weight",
-            f"blk.{i}.attn_v.weight",
-            f"blk.{i}.attn_norm.weight",
-            f"blk.{i}.attn_output_norm.weight",
-            f"blk.{i}.ffn_norm.weight",
-            f"blk.{i}.layer_output_norm.weight",
-            f"blk.{i}.ffn_gate_inp.weight",
-        ]
-
-    tensor_map = [(name, f"{args.input}/tensor{i:05}_000") for i, name in enumerate(tensor_names)]
     f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)

     f.add_name("grok")
@@ -430,8 +404,7 @@ def ffn_size(emb_size, widening_factor):
     f.add_token_scores(scores)
     f.add_token_types(toktypes)

-    weight_names = get_weight_names(config)
-    dump_state_dict(f, weight_names, tensor_map, ggml_type, config)
+    dump_state_dict(f, ggml_type, args.input_dir, config)
     f.close()

     delta = time.time() - start
@@ -465,7 +438,7 @@ def load_spm(p):

 def main():
     parser = argparse.ArgumentParser("convert_grok")
-    parser.add_argument("-i", "--input", type=str)
+    parser.add_argument("-i", "--input_dir", type=str)
     parser.add_argument("-o", "--save_path", type=pathlib.Path)
     parser.add_argument(
         "-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"]
@@ -474,7 +447,9 @@ def main():
     parser.add_argument("--experts", type=str, default="")
     args = parser.parse_args()

-    vocab = load_vocab(pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input))
+    vocab = load_vocab(
+        pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir)
+    )
     ggml_type = gguf.GGMLQuantizationType[args.type.upper()]
     convert_grok(args, vocab, ggml_type)

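
With the renamed flag, a typical invocation would look like this (script and directory names hypothetical; the input directory holds the tensor#####_000 shards and, absent --vocab_dir, the tokenizer model):

    python convert_grok.py -i ./grok-1 -o grok-1-q8_0.gguf -t q8_0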