Commit c9c8952

Use only one list of weight names, with values from the gguf module.

This saves weights in the order in which they appear in the Grok-1 files. Since we operate weight by weight now, we no longer need caches and name-to-key translations. Per reviewer request, I also moved to using the keys in gguf.TENSOR_NAMES.
1 parent 72920b1 commit c9c8952
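
For context on that last point: in the gguf Python package, TENSOR_NAMES maps MODEL_TENSOR enum members to name templates, where per-layer names carry a "{bid}" placeholder for the block index. A minimal sketch of the lookup this commit relies on (values shown match gguf-py at the time of this commit; check your installed version):

    import gguf

    print(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD])  # token_embd
    print(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q])      # blk.{bid}.attn_q
    # Filling in the block index yields the concrete tensor name:
    print(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=3))  # blk.3.attn_q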

File tree: 1 file changed, +63 −88 lines

convert_grok.py

Lines changed: 63 additions & 88 deletions
@@ -196,55 +196,47 @@ def get_dtype_and_ggml_type(tensor, ggml_type):
         return np.float32, gguf.GGMLQuantizationType.F32


-def dump_state_dict(f, weight_names, model_files, ggml_type, config):
-    keys2names = {}
-    meta_tensors = {}
-    weight_scales = {}
+def dump_state_dict(f, ggml_type, input_dir, config):
+    weight_names = get_weight_names(config.num_hidden_layers)
+    weights = {}

     # First operate on meta tensors to find shapes and dtypes for GGUF header.
-    for name, fn in model_files:
-        weight, scales = get_weights(fn)
-        meta_state_dict = convert_weight(name, weight, scales, config.experts, device="meta")
-        weight_scales[name] = (weight, scales)
-        for key in meta_state_dict.keys():
-            keys2names[key] = name
-
-        meta_tensors.update(meta_state_dict)
-
-    for key in weight_names:
-        meta_tensor = meta_tensors[key]
+    for idx, name in enumerate(weight_names):
+        weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
+        meta_tensor = convert_weight(name, weight, scales, config.experts, device="meta")
         dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
         quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
         f.add_tensor_info(
-            key, list(meta_tensor.shape), dtype, quantized_meta_tensor.nbytes, tensor_ggml_type
+            f"{name}.weight",
+            list(meta_tensor.shape),
+            dtype,
+            quantized_meta_tensor.nbytes,
+            tensor_ggml_type,
         )
-    print("Loaded", len(meta_tensors), "files")
+        weights[name] = weight, scales
+    print("Loaded", len(weight_names), "files")

     f.write_header_to_file()
     f.write_kv_data_to_file()
     f.write_ti_data_to_file()

-    cache = {}
+    # Now write actual tensor data.
     tensor_info = []

-    for key in weight_names:
-        if key not in cache:
-            name = keys2names[key]
-            weight, scales = weight_scales.pop(name)
-            state_dict = convert_weight(name, weight, scales, config.experts)
-            permute_tensors(state_dict, config)
-            cache.update(state_dict)
-        tensor = cache.pop(key)
+    for name in weight_names:
+        weight, scales = weights.pop(name)
+        tensor = convert_weight(name, weight, scales, config.experts)
+        tensor = maybe_permute_tensor(name, tensor, config)
         _, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
         array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()

         print(
-            f"dumping {key}:",
+            f"dumping {name}:",
             f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes",
         )
         f.write_tensor_data(array)

-        tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name))
+        tensor_info.append((name, list(tensor.shape), tensor_ggml_type.name))

     try:
         print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
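
The restructured dump_state_dict keeps GGUFWriter's two-pass contract: every tensor's metadata is registered via add_tensor_info() before the header, KV data, and tensor-info table are written, and the payloads are then written in the same order they were declared. A stripped-down sketch of that pattern (the tensors dict here is hypothetical; the real code streams Grok-1 shard files and quantizes them):

    import gguf
    import numpy as np

    # Hypothetical toy model; convert_grok.py loads Grok-1 shards instead.
    tensors = {"token_embd.weight": np.zeros((8, 4), dtype=np.float32)}

    writer = gguf.GGUFWriter("tiny.gguf", "example", endianess=gguf.GGUFEndian.LITTLE)

    # Pass 1: declare every tensor so the header and tensor-info table can be laid out.
    for tensor_name, array in tensors.items():
        writer.add_tensor_info(
            tensor_name, list(array.shape), array.dtype, array.nbytes,
            gguf.GGMLQuantizationType.F32,
        )

    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_ti_data_to_file()

    # Pass 2: payloads must follow in the declaration order.
    for array in tensors.values():
        writer.write_tensor_data(array)

    writer.close()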
@@ -263,10 +255,8 @@ def from_numpy(array):
     return torch.from_numpy(array)


-def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, device=None):
+def convert_weight(name, weight, scales, experts, dtype=torch.float32, device=None):
     # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
-    result = {}
-
     weight = from_numpy(weight).to(device=device, dtype=dtype)
     if scales is not None:
         scale = from_numpy(scales).to(device=device, dtype=dtype)

@@ -279,30 +269,29 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, de
         weight = weight * scale

     # Transpose linear matrix
-    if len(weight.shape) >= 2 and "token_embd" not in tensor_name:
+    if len(weight.shape) >= 2 and "token_embd" not in name:
         weight = weight.transpose(-1, -2)

-    if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"):
-        result[tensor_name] = weight[experts]  # gather.
-    elif "experts" not in tensor_name:
-        result[tensor_name] = weight
+    if name.endswith("ffn_gate_inp") or name.endswith("_exps"):
+        weight = weight[experts]  # gather.

-    return result
+    return weight


-def permute_tensors(state_dict, config):
+def maybe_permute_tensor(name, tensor, config):
     def permute(weights, n_head):
         return (
             weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
             .swapaxes(1, 2)
             .reshape(weights.shape)
         )

-    for name, tensor in state_dict.items():
-        if name.endswith("attn_k.weight"):
-            state_dict[name] = permute(tensor, config.num_key_value_heads)
-        elif name.endswith("attn_q.weight"):
-            state_dict[name] = permute(tensor, config.num_attention_heads)
+    if name.endswith("attn_k"):
+        return permute(tensor, config.num_key_value_heads)
+    elif name.endswith("attn_q"):
+        return permute(tensor, config.num_attention_heads)
+
+    return tensor


 def extract_vocabulary_from_model(vocab):
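
The inner permute here is the same row shuffle that appears in llama.cpp's other convert scripts for attn_q/attn_k: within each head it interleaves rows from the two halves so the rotary-embedding pairing matches the layout the runtime expects. A small standalone illustration (toy 4x2 weight, one head; the numbers are illustrative only):

    import numpy as np

    def permute(weights, n_head):
        # Same reshape/swapaxes trick as in convert_grok.py.
        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    w = np.arange(8).reshape(4, 2)  # rows 0..3 of a single head
    print(permute(w, n_head=1))     # row order becomes 0, 2, 1, 3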
@@ -320,25 +309,32 @@ def extract_vocabulary_from_model(vocab):
     return tokens, scores, toktypes


-def get_weight_names(config):
-    weight_names = ["token_embd.weight"]
-    for i in range(config.num_hidden_layers):
-        weight_names += [
-            f"blk.{i}.ffn_gate_exps.weight",
-            f"blk.{i}.ffn_down_exps.weight",
-            f"blk.{i}.ffn_up_exps.weight",
-            f"blk.{i}.attn_k.weight",
-            f"blk.{i}.attn_output.weight",
-            f"blk.{i}.attn_q.weight",
-            f"blk.{i}.attn_v.weight",
-            f"blk.{i}.attn_norm.weight",
-            f"blk.{i}.attn_output_norm.weight",
-            f"blk.{i}.ffn_norm.weight",
-            f"blk.{i}.layer_output_norm.weight",
-            f"blk.{i}.ffn_gate_inp.weight",
-        ]
-
-    weight_names += ["output_norm.weight"]
+def get_weight_names(num_hidden_layers=64):
+    """Return Grok-1 weight names, in the order in which they are in the tensor#####_000 files."""
+
+    weight_names = [
+        gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD],
+        gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT_NORM],
+    ]
+
+    layer = (
+        gguf.MODEL_TENSOR.FFN_GATE_EXP,
+        gguf.MODEL_TENSOR.FFN_DOWN_EXP,
+        gguf.MODEL_TENSOR.FFN_UP_EXP,
+        gguf.MODEL_TENSOR.ATTN_K,
+        gguf.MODEL_TENSOR.ATTN_OUT,
+        gguf.MODEL_TENSOR.ATTN_Q,
+        gguf.MODEL_TENSOR.ATTN_V,
+        gguf.MODEL_TENSOR.ATTN_NORM,
+        gguf.MODEL_TENSOR.ATTN_OUT_NORM,
+        gguf.MODEL_TENSOR.FFN_NORM,
+        gguf.MODEL_TENSOR.LAYER_OUT_NORM,
+        gguf.MODEL_TENSOR.FFN_GATE_INP,
+    )
+
+    for bid in range(num_hidden_layers):
+        for key in layer:
+            weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid))

     return weight_names
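
Since dump_state_dict now pairs enumerate(weight_names) with the tensor{idx:05}_000 files, the list order is the file order. A quick hypothetical check of the mapping for the first few indices (expected output derived from the gguf-py templates shown earlier):

    names = get_weight_names(num_hidden_layers=2)
    for idx, name in enumerate(names[:5]):
        print(f"tensor{idx:05}_000 -> {name}")
    # tensor00000_000 -> token_embd
    # tensor00001_000 -> output_norm
    # tensor00002_000 -> blk.0.ffn_gate_exps
    # tensor00003_000 -> blk.0.ffn_down_exps
    # tensor00004_000 -> blk.0.ffn_up_exps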

@@ -383,28 +379,6 @@ def ffn_size(emb_size, widening_factor):
     assert config.num_experts >= 2, "need at least 2 experts"
     print("experts to export:", config.experts)

-    # Contents of in Grok-1 pickle files, in order. Weights with "experts" will be split later.
-    tensor_names = [
-        "token_embd.weight",
-        "output_norm.weight",
-    ]
-    for i in range(config.num_hidden_layers):
-        tensor_names += [
-            f"blk.{i}.ffn_gate_exps.weight",
-            f"blk.{i}.ffn_down_exps.weight",
-            f"blk.{i}.ffn_up_exps.weight",
-            f"blk.{i}.attn_k.weight",
-            f"blk.{i}.attn_output.weight",
-            f"blk.{i}.attn_q.weight",
-            f"blk.{i}.attn_v.weight",
-            f"blk.{i}.attn_norm.weight",
-            f"blk.{i}.attn_output_norm.weight",
-            f"blk.{i}.ffn_norm.weight",
-            f"blk.{i}.layer_output_norm.weight",
-            f"blk.{i}.ffn_gate_inp.weight",
-        ]
-
-    tensor_map = [(name, f"{args.input}/tensor{i:05}_000") for i, name in enumerate(tensor_names)]
     f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)

     f.add_name("grok")

@@ -430,8 +404,7 @@ def ffn_size(emb_size, widening_factor):
     f.add_token_scores(scores)
     f.add_token_types(toktypes)

-    weight_names = get_weight_names(config)
-    dump_state_dict(f, weight_names, tensor_map, ggml_type, config)
+    dump_state_dict(f, ggml_type, args.input_dir, config)
     f.close()

     delta = time.time() - start

@@ -465,7 +438,7 @@ def load_spm(p):

 def main():
     parser = argparse.ArgumentParser("convert_grok")
-    parser.add_argument("-i", "--input", type=str)
+    parser.add_argument("-i", "--input_dir", type=str)
     parser.add_argument("-o", "--save_path", type=pathlib.Path)
     parser.add_argument(
         "-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"]

@@ -474,7 +447,9 @@ def main():
     parser.add_argument("--experts", type=str, default="")
     args = parser.parse_args()

-    vocab = load_vocab(pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input))
+    vocab = load_vocab(
+        pathlib.Path(args.vocab_dir) if args.vocab_dir else pathlib.Path(args.input_dir)
+    )
     ggml_type = gguf.GGMLQuantizationType[args.type.upper()]
     convert_grok(args, vocab, ggml_type)
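
With the flag renamed from --input to --input_dir, a conversion run against an unpacked Grok-1 checkpoint directory would look roughly like this (paths illustrative; -t defaults to q8_0):

    python convert_grok.py --input_dir /path/to/grok-1 -o grok-q8_0.gguf -t q8_0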