 # At the start of the ggml file we write the model parameters
 # and vocabulary.
 #
-import os
+import argparse
 import sys
 import json
 import struct
 import numpy as np
 import torch
 from sentencepiece import SentencePieceProcessor
 
-if len(sys.argv) < 3:
-    print("Usage: convert-ckpt-to-ggml.py dir-model ftype\n")
-    print("  ftype == 0 -> float32")
-    print("  ftype == 1 -> float16")
-    sys.exit(1)
+def parse_args():
 
-# output in the same directory as the model
-dir_model = sys.argv[1]
-
-fname_hparams = sys.argv[1] + "/params.json"
-fname_tokenizer = sys.argv[1] + "/../tokenizer.model"
+    parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
+    parser.add_argument('dir_model', help='directory containing the model checkpoint')
+    parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)')
+    return parser.parse_args()
 
 def get_n_parts(dim):
-    if dim == 4096:
-        return 1
-    elif dim == 5120:
-        return 2
-    elif dim == 6656:
-        return 4
-    elif dim == 8192:
-        return 8
-    else:
-        print("Invalid dim: " + str(dim))
+
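+    # number of checkpoint shards per model size: 7B=1, 13B=2, 30B=4, 65B=8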
+    mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
+    n_parts = mappings.get(dim)
+    if n_parts is None:
+        print(f"Invalid dim: {dim}")
         sys.exit(1)
 
-# possible data types
-#   ftype == 0 -> float32
-#   ftype == 1 -> float16
-#
-# map from ftype to string
-ftype_str = ["f32", "f16"]
-
-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-        sys.exit(1)
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-
-if os.path.exists(fname_out):
-    print(f"Skip conversion, it already exists: {fname_out}")
-    sys.exit(0)
-
-with open(fname_hparams, "r") as f:
-    hparams = json.load(f)
+    print(f"n_parts = {n_parts}\n")
+    return n_parts
 
-tokenizer = SentencePieceProcessor(fname_tokenizer)
+def load_hparams_and_tokenizer(dir_model):
+
+    fname_hparams = f"{dir_model}/params.json"
+    fname_tokenizer = f"{dir_model}/../tokenizer.model"
 
-hparams.update({"vocab_size": tokenizer.vocab_size()})
+    with open(fname_hparams, "r") as f:
+        hparams = json.load(f)
+    print(hparams)
 
-n_parts = get_n_parts(hparams["dim"])
+    tokenizer = SentencePieceProcessor(fname_tokenizer)
+    hparams.update({"vocab_size": tokenizer.vocab_size()})
 
-print(hparams)
-print('n_parts = ', n_parts)
+    return hparams, tokenizer
 
-for p in range(n_parts):
-    print('Processing part ', p)
+def write_header(fout, hparams, ftype):
+
+    keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
+    values = [
+        0x67676d6c,  # magic: ggml in hex
+        *[hparams[key] for key in keys],
+        hparams["dim"] // hparams["n_heads"],  # rot (obsolete)
+        ftype
+    ]
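+    # eight native int32s in total: magic, the five hparams above, rot, ftype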
+    fout.write(struct.pack("i" * len(values), *values))
 
-    #fname_model = sys.argv[1] + "/consolidated.00.pth"
-    fname_model = sys.argv[1] + "/consolidated.0" + str(p) + ".pth"
-    fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin"
-    if (p > 0):
-        fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".bin" + "." + str(p)
+def write_tokens(fout, tokenizer):
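+    # each vocab entry is written as an int32 byte length followed by the token's raw bytes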
 
-    model = torch.load(fname_model, map_location="cpu")
-
-    fout = open(fname_out, "wb")
-
-    fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
-    fout.write(struct.pack("i", hparams["vocab_size"]))
-    fout.write(struct.pack("i", hparams["dim"]))
-    fout.write(struct.pack("i", hparams["multiple_of"]))
-    fout.write(struct.pack("i", hparams["n_heads"]))
-    fout.write(struct.pack("i", hparams["n_layers"]))
-    fout.write(struct.pack("i", hparams["dim"] // hparams["n_heads"])) # rot (obsolete)
-    fout.write(struct.pack("i", ftype))
-
-    # Is this correct??
     for i in range(tokenizer.vocab_size()):
         if tokenizer.is_unknown(i):
-            # "<unk>" token (translated as ??)
             text = " \u2047 ".encode("utf-8")
-            fout.write(struct.pack("i", len(text)))
-            fout.write(text)
         elif tokenizer.is_control(i):
-            # "<s>"/"</s>" tokens
-            fout.write(struct.pack("i", 0))
+            text = b""
         elif tokenizer.is_byte(i):
-            # "<U+XX>" tokens (which may be invalid UTF-8)
             piece = tokenizer.id_to_piece(i)
             if len(piece) != 6:
-                print("Invalid token: " + piece)
+                print(f"Invalid token: {piece}")
                 sys.exit(1)
             byte_value = int(piece[3:-1], 16)
-            fout.write(struct.pack("i", 1))
-            fout.write(struct.pack("B", byte_value))
+            text = struct.pack("B", byte_value)
         else:
-            # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
             text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
-            fout.write(struct.pack("i", len(text)))
-            fout.write(text)
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
 
-    for k, v in model.items():
-        name = k
-        shape = v.shape
+def process_and_write_variables(fout, model, ftype):
 
-        # skip layers.X.attention.inner_attention.rope.freqs
-        if name[-5:] == "freqs":
+    for name, data in model.items():
+
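+        # skip layers.X.attention.inner_attention.rope.freqs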
+        if name.endswith("freqs"):
             continue
-
-        print("Processing variable: " + name + " with shape: ", shape, " and type: ", v.dtype)
-
-        #data = tf.train.load_variable(dir_model, name).squeeze()
-        data = v.numpy().squeeze()
-        n_dims = len(data.shape);
+
+        shape = data.shape
+
+        print(f"Processing variable: {name} with shape: {shape} and type: {data.dtype}\n")
+
+        data = data.numpy().squeeze()
+        n_dims = len(shape)
 
         # for efficiency - transpose some matrices
         # "model/h.*/attn/c_attn/w"
         # "model/h.*/attn/c_proj/w"
         # "model/h.*/mlp/c_fc/w"
         # "model/h.*/mlp/c_proj/w"
-        #if name[-14:] == "/attn/c_attn/w" or \
-        #   name[-14:] == "/attn/c_proj/w" or \
-        #   name[-11:] == "/mlp/c_fc/w" or \
-        #   name[-13:] == "/mlp/c_proj/w":
-        #    print("  Transposing")
+        #if name.endswith(("/attn/c_attn/w", "/attn/c_proj/w", "/mlp/c_fc/w", "/mlp/c_proj/w")):
+        #    print("Transposing")
         #    data = data.transpose()
 
-        dshape = data.shape
-
         # default type is fp16
         ftype_cur = 1
         if ftype == 0 or n_dims == 1:
@@ -164,18 +118,40 @@ def get_n_parts(dim):
 
         # header
         sname = name.encode('utf-8')
-        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
-        for i in range(n_dims):
-            fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
-        fout.write(sname);
-
+        fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
+        for dim in reversed(data.shape):
+            fout.write(struct.pack("i", dim))
+        fout.write(sname)
+
         # data
         data.tofile(fout)
 
-    # I hope this deallocates the memory ..
-    model = None
+def main():
+
+    args = parse_args()
+    dir_model = args.dir_model
+    ftype = args.ftype
+    ftype_str = ["f32", "f16"]
+
+    hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
+    n_parts = get_n_parts(hparams["dim"])
+
+    for p in range(n_parts):
+
+        print(f"Processing part {p}\n")
+
+        fname_model = f"{dir_model}/consolidated.0{p}.pth"
+        fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
+
+        model = torch.load(fname_model, map_location="cpu")
+
+        with open(fname_out, "wb") as fout:
+            write_header(fout, hparams, ftype)
+            write_tokens(fout, tokenizer)
+            process_and_write_variables(fout, model, ftype)
 
-    fout.close()
+        del model
+        print(f"Done. Output file: {fname_out}, (part {p})\n")
 
-    print("Done. Output file: " + fname_out + ", (part ", p, ")")
-    print("")
+if __name__ == "__main__":
+    main()
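
Usage is unchanged, e.g. `python convert-ckpt-to-ggml.py models/7B 1` (the path is only an example). For a quick sanity check of a converted file, here is a minimal read-back sketch — not part of this PR — that parses just the header and vocabulary in the order `write_header` and `write_tokens` emit them; the tensor records that follow are left unparsed:

```python
import struct

# read back the header and vocabulary emitted by write_header/write_tokens
with open("models/7B/ggml-model-f16.bin", "rb") as f:  # example path
    # header: eight native int32s, in the order written by write_header
    magic, vocab_size, dim, multiple_of, n_heads, n_layers, rot, ftype = \
        struct.unpack("8i", f.read(32))
    assert magic == 0x67676d6c, "not a ggml file"
    print(f"vocab_size={vocab_size} dim={dim} n_heads={n_heads} n_layers={n_layers} ftype={ftype}")

    # vocabulary: vocab_size records of (int32 byte length, raw bytes)
    for _ in range(vocab_size):
        (length,) = struct.unpack("i", f.read(4))
        token = f.read(length)
```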