Commit 467b149

qunash and ggerganov authored

Refactoring convert-pth-to-ggml.py: more concise and readable (#109)

* Refactor get_n_parts function to simplify code and improve readability
* Use f-strings instead of concatenation
* Refactoring: more concise and readable
* modularize

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 70f01cb commit 467b149
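
The main interface change is the switch from hand-rolled sys.argv checks to argparse. Below is a minimal sketch of the resulting CLI, mirroring parse_args() in the diff; parsing an explicit argv list stands in for running `python convert-pth-to-ggml.py models/7B 1`, where the model path is an assumption for illustration:

    import argparse

    # Mirrors parse_args() from the refactored script; the argv list below is
    # a hypothetical example, not a real checkpoint path.
    parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
    parser.add_argument('dir_model', help='directory containing the model checkpoint')
    parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')

    args = parser.parse_args(['models/7B', '1'])
    print(args.dir_model, args.ftype)  # -> models/7B 1

Note that choices=[0, 1] makes argparse reject out-of-range ftype values itself, replacing the manual bounds check that the old script carried.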

1 file changed: +84 -108 lines changed

convert-pth-to-ggml.py

Lines changed: 84 additions & 108 deletions
@@ -16,145 +16,99 @@
 # At the start of the ggml file we write the model parameters
 # and vocabulary.
 #
-import os
+import argparse
 import sys
 import json
 import struct
 import numpy as np
 import torch
 from sentencepiece import SentencePieceProcessor

-if len(sys.argv) < 3:
-    print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
-    print("  ftype == 0 -> float32")
-    print("  ftype == 1 -> float16")
-    sys.exit(1)
+def parse_args():

-# output in the same directory as the model
-dir_model = sys.argv[1]
-
-fname_hparams = sys.argv[1] + "/params.json"
-fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
+    parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
+    parser.add_argument('dir_model', help='directory containing the model checkpoint')
+    parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
+    return parser.parse_args()

 def get_n_parts(dim):
-    if dim == 4096:
-        return 1
-    elif dim == 5120:
-        return 2
-    elif dim == 6656:
-        return 4
-    elif dim == 8192:
-        return 8
-    else:
-        print("Invalid dim: " + str(dim))
+
+    mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
+    n_parts = mappings.get(dim)
+    if n_parts is None:
+        print(f"Invalid dim: {dim}")
         sys.exit(1)

-# possible data types
-#   ftype == 0 -> float32
-#   ftype == 1 -> float16
-#
-# map from ftype to string
-ftype_str = ["f32", "f16"]
-
-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-
-if os.path.exists(fname_out):
-    print(f"Skip conversion, it already exists: {fname_out}")
-    sys.exit(0)
-
-with open(fname_hparams, "r") as f:
-    hparams = json.load(f)
+    print(f"n_parts = {n_parts}\n")
+    return n_parts

-tokenizer = SentencePieceProcessor(fname_tokenizer)
+def load_hparams_and_tokenizer(dir_model):
+
+    fname_hparams = f"{dir_model}/params.json"
+    fname_tokenizer = f"{dir_model}/../tokenizer.model"

-hparams.update({"vocab_size": tokenizer.vocab_size()})
+    with open(fname_hparams, "r") as f:
+        hparams = json.load(f)
+        print(hparams)

-n_parts = get_n_parts(hparams["dim"])
+    tokenizer = SentencePieceProcessor(fname_tokenizer)
+    hparams.update({"vocab_size": tokenizer.vocab_size()})

-print(hparams)
-print('n_parts = ', n_parts)
+    return hparams, tokenizer

-for p in range(n_parts):
-    print('Processing part ', p)
+def write_header(fout, hparams, ftype):
+
+    keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
+    values = [
+        0x67676d6c,  # magic: ggml in hex
+        *[hparams[key] for key in keys],
+        hparams["dim"] // hparams["n_heads"],  # rot (obsolete)
+        ftype
+    ]
+    fout.write(struct.pack("i" * len(values), *values))

-    #fname_model = sys.argv[1] + "/consolidated.00.pth"
-    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-    if (p > 0):
-        fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
+def write_tokens(fout, tokenizer):

-    model = torch.load(fname_model, map_location="cpu")
-
-    fout = open(fname_out, "wb")
-
-    fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
-    fout.write(struct.pack("i", hparams["vocab_size"]))
-    fout.write(struct.pack("i", hparams["dim"]))
-    fout.write(struct.pack("i", hparams["multiple_of"]))
-    fout.write(struct.pack("i", hparams["n_heads"]))
-    fout.write(struct.pack("i", hparams["n_layers"]))
-    fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
-    fout.write(struct.pack("i", ftype))
-
-    # Is this correct??
     for i in range(tokenizer.vocab_size()):
         if tokenizer.is_unknown(i):
-            # "<unk>" token (translated as ??)
             text = " \u2047 ".encode("utf-8")
-            fout.write(struct.pack("i", len(text)))
-            fout.write(text)
         elif tokenizer.is_control(i):
-            # "<s>"/"</s>" tokens
-            fout.write(struct.pack("i", 0))
+            text = b""
         elif tokenizer.is_byte(i):
-            # "<U+XX>" tokens (which may be invalid UTF-8)
             piece = tokenizer.id_to_piece(i)
             if len(piece) != 6:
-                print("Invalid token: " + piece)
+                print(f"Invalid token: {piece}")
                 sys.exit(1)
             byte_value = int(piece[3:-1], 16)
-            fout.write(struct.pack("i", 1))
-            fout.write(struct.pack("B", byte_value))
+            text = struct.pack("B", byte_value)
         else:
-            # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
             text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-            fout.write(struct.pack("i", len(text)))
-            fout.write(text)
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)

-    for k, v in model.items():
-        name = k
-        shape = v.shape
+def process_and_write_variables(fout, model, ftype):

-        # skip layers.X.attention.inner_attention.rope.freqs
-        if name[-5:] == "freqs":
+    for name, data in model.items():
+
+        if name.endswith("freqs"):
             continue
-
-        print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
-
-        #data = tf.train.load_variable(dir_model, name).squeeze()
-        data = v.numpy().squeeze()
-        n_dims = len(data.shape);
+
+        shape = data.shape
+
+        print(f"Processing variable: {name} with shape: {shape} and type: {data.dtype}\n")
+
+        data = np.squeeze(data)
+        n_dims = len(shape)

         # for efficiency - transpose some matrices
         # "model/h.*/attn/c_attn/w"
         # "model/h.*/attn/c_proj/w"
         # "model/h.*/mlp/c_fc/w"
         # "model/h.*/mlp/c_proj/w"
-        #if name[-14:] == "/attn/c_attn/w" or \
-        #   name[-14:] == "/attn/c_proj/w" or \
-        #   name[-11:] == "/mlp/c_fc/w" or \
-        #   name[-13:] == "/mlp/c_proj/w":
-        #    print("  Transposing")
+        #if name.endswith(("/attn/c_attn/w", "/attn/c_proj/w", "/mlp/c_fc/w", "/mlp/c_proj/w")):
+        #    print("Transposing")
         #    data = data.transpose()

-        dshape = data.shape
-
         # default type is fp16
         ftype_cur = 1
         if ftype == 0 or n_dims == 1:
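
One equivalence worth noting before the second hunk: the new write_header() packs all eight header fields with a single struct.pack call, which is byte-for-byte identical to the eight separate writes it replaces, since all fields share the "i" (int32) format and need no padding. A self-contained check, with toy hparams values assumed for illustration:

    import struct

    # Toy values standing in for params.json contents (assumptions, 7B-like).
    hparams = {"vocab_size": 32000, "dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32}
    ftype = 1

    keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
    values = [
        0x67676d6c,                            # magic: ggml in hex
        *[hparams[key] for key in keys],
        hparams["dim"] // hparams["n_heads"],  # rot (obsolete)
        ftype
    ]

    new_style = struct.pack("i" * len(values), *values)        # one call (this commit)
    old_style = b"".join(struct.pack("i", v) for v in values)  # one write per field (before)
    assert new_style == old_style and len(new_style) == 32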
@@ -164,18 +118,40 @@ def get_n_parts(dim):

         # header
         sname = name.encode('utf-8')
-        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
-        for i in range(n_dims):
-            fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-        fout.write(sname);
-
+        fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
+        for dim in reversed(data.shape):
+            fout.write(struct.pack("i", dim))
+        fout.write(sname)
+
         # data
         data.tofile(fout)

-    # I hope this deallocates the memory ..
-    model = None
+def main():
+
+    args = parse_args()
+    dir_model = args.dir_model
+    ftype = args.ftype
+    ftype_str = ["f32", "f16"]
+
+    hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
+    n_parts = get_n_parts(hparams["dim"])
+
+    for p in range(n_parts):
+
+        print(f"Processing part {p}\n")
+
+        fname_model = f"{dir_model}/consolidated.0{p}.pth"
+        fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
+
+        model = torch.load(fname_model, map_location="cpu")
+
+        with open(fname_out, "wb") as fout:
+            write_header(fout, hparams, ftype)
+            write_tokens(fout, tokenizer)
+            process_and_write_variables(fout, model, ftype)

-    fout.close()
+        del model
+        print(f"Done. Output file: {fname_out}, (part {p})\n")

-    print("Done. Output file: " + fname_out + ", (part ", p, ")")
-    print("")
+if __name__ == "__main__":
+    main()
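
Since write_header() puts those eight int32 values at the very start of the output file, a converted model can be sanity-checked by reading them back. A minimal sketch, assuming a converted file already exists at the hypothetical path models/7B/ggml-model-f16.bin:

    import struct

    # The path is an assumption for illustration; any file produced by this
    # script starts with the same 32-byte header.
    with open("models/7B/ggml-model-f16.bin", "rb") as f:
        fields = struct.unpack("8i", f.read(struct.calcsize("8i")))

    magic, vocab_size, dim, multiple_of, n_heads, n_layers, rot, ftype = fields
    assert magic == 0x67676d6c  # "ggml" magic written by write_header()
    print(f"vocab_size={vocab_size} dim={dim} n_heads={n_heads} n_layers={n_layers} ftype={ftype}")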
