Skip to content

Commit c1c7026

Browse files
committed
Fix python stuff (#109)
1 parent 467b149 commit c1c7026

File tree

1 file changed

+14
-23
lines changed

1 file changed

+14
-23
lines changed

convert-pth-to-ggml.py

Lines changed: 14 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def parse_args():
3232
return parser.parse_args()
3333

3434
def get_n_parts(dim):
35-
35+
3636
mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
3737
n_parts = mappings.get(dim)
3838
if n_parts is None:
@@ -43,7 +43,7 @@ def get_n_parts(dim):
4343
return n_parts
4444

4545
def load_hparams_and_tokenizer(dir_model):
46-
46+
4747
fname_hparams = f"{dir_model}/params.json"
4848
fname_tokenizer = f"{dir_model}/../tokenizer.model"
4949

@@ -57,7 +57,7 @@ def load_hparams_and_tokenizer(dir_model):
5757
return hparams, tokenizer
5858

5959
def write_header(fout, hparams, ftype):
60-
60+
6161
keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
6262
values = [
6363
0x67676d6c, # magic: ggml in hex
@@ -88,26 +88,17 @@ def write_tokens(fout, tokenizer):
8888

8989
def process_and_write_variables(fout, model, ftype):
9090

91-
for name, data in model.items():
92-
91+
for name, datao in model.items():
92+
9393
if name.endswith("freqs"):
9494
continue
95-
96-
shape = data.shape
97-
98-
print(f"Processing variable: {name} with shape: {shape} and type: {data.dtype}\n")
99-
100-
data = np.squeeze(data)
101-
n_dims = len(shape)
10295

103-
# for efficiency - transpose some matrices
104-
# "model/h.*/attn/c_attn/w"
105-
# "model/h.*/attn/c_proj/w"
106-
# "model/h.*/mlp/c_fc/w"
107-
# "model/h.*/mlp/c_proj/w"
108-
#if name.endswith(("/attn/c_attn/w", "/attn/c_proj/w", "/mlp/c_fc/w", "/mlp/c_proj/w")):
109-
# print("Transposing")
110-
# data = data.transpose()
96+
shape = datao.shape
97+
98+
print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
99+
100+
data = datao.numpy().squeeze()
101+
n_dims = len(shape)
111102

112103
# default type is fp16
113104
ftype_cur = 1
@@ -122,8 +113,8 @@ def process_and_write_variables(fout, model, ftype):
122113
for dim in reversed(data.shape):
123114
fout.write(struct.pack("i", dim))
124115
fout.write(sname)
125-
126-
# data
116+
117+
# data output to file
127118
data.tofile(fout)
128119

129120
def main():
@@ -139,7 +130,7 @@ def main():
139130
for p in range(n_parts):
140131

141132
print(f"Processing part {p}\n")
142-
133+
143134
fname_model = f"{dir_model}/consolidated.0{p}.pth"
144135
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
145136

0 commit comments

Comments
 (0)