@@ -72,7 +72,12 @@ def dequantize_weights(fin, n_rows, n_cols):
 
 def read_variables(fin):
     model = {}
-    pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables")
+    pbar = tqdm(
+        total=os.path.getsize(fin.name),
+        unit="B",
+        unit_scale=True,
+        desc="Reading variables",
+    )
     while True:
         start_pos = fin.tell()
         try:
@@ -98,7 +103,9 @@ def read_variables(fin):
         data_size = np.prod(shape)
         data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
 
-        model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)
+        model[name] = torch.tensor(
+            data, dtype=torch.float32 if dtype == np.float32 else torch.float16
+        )
 
         pbar.update(fin.tell() - start_pos)
 
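Note (not part of the diff): a minimal, self-contained sketch of the np.fromfile -> torch.tensor pattern the hunk above reformats; the file path, dtype, and shape here are made up for illustration.

    import numpy as np
    import torch

    shape = (2, 3)  # hypothetical tensor shape
    np.arange(6, dtype=np.float16).tofile("/tmp/toy.bin")  # toy weights file

    with open("/tmp/toy.bin", "rb") as fin:
        dtype = np.float16
        data = np.fromfile(fin, dtype=dtype, count=int(np.prod(shape))).reshape(shape)

    # Same dtype dispatch as in read_variables: keep f32 as f32, else f16.
    t = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)
    print(t.dtype, t.shape)  # torch.float16 torch.Size([2, 3])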
@@ -112,11 +119,17 @@ def convert_to_hf_format(model, hparams):
     dim = hparams["dim"]
     dims_per_head = dim // n_heads
     base = 10000.0
-    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    inv_freq = 1.0 / (
+        base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
+    )
 
     # permute for sliced rotary
     def permute(w):
-        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+        return (
+            w.view(n_heads, dim // n_heads // 2, 2, dim)
+            .transpose(1, 2)
+            .reshape(dim, dim)
+        )
 
     state_dict = {}
     for layer_i in range(n_layers):
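Note (not part of the diff): the permute helper being reformatted above reorders rotary (Q/K) weight rows into the sliced-rotary layout HF's LLaMA attention expects. A standalone sketch with hypothetical dimensions (assumes 2 divides dim // n_heads):

    import torch

    n_heads, dim = 4, 16  # hypothetical; real values come from hparams

    def permute(w):
        # (head, pair, 2, dim) -> (head, 2, pair, dim) -> (dim, dim)
        return (
            w.view(n_heads, dim // n_heads // 2, 2, dim)
            .transpose(1, 2)
            .reshape(dim, dim)
        )

    w = torch.arange(dim * dim, dtype=torch.float32).reshape(dim, dim)
    assert permute(w).shape == (dim, dim)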
@@ -164,16 +177,22 @@ def permute(w):
 
 
 def chat(model, hparams, llama_dir):
-    from transformers import (GenerationConfig, LlamaForCausalLM,
-                              LlamaTokenizer, StoppingCriteria,
-                              StoppingCriteriaList)
+    from transformers import (
+        GenerationConfig,
+        LlamaForCausalLM,
+        LlamaTokenizer,
+        StoppingCriteria,
+        StoppingCriteriaList,
+    )
     from transformers.models.llama.configuration_llama import LlamaConfig
 
     class StoppingCriteriaSub(StoppingCriteria):
         def __init__(self):
             super().__init__()
 
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]
+        ):
             print(tokenizer.decode(input_ids[0]), end="", flush=True)
             if input_ids[0][-1] == 13:
                 return True
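Note (not part of the diff): custom criteria like StoppingCriteriaSub are passed to generate via a StoppingCriteriaList. A rough sketch of the wiring; the model/tokenizer setup is omitted here, so the generate call is left commented out:

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnNewline(StoppingCriteria):
        # Token id 13 is "\n" in the LLaMA vocabulary.
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            return bool(input_ids[0][-1] == 13)

    stopping = StoppingCriteriaList([StopOnNewline()])
    # output_ids = model.generate(input_ids, stopping_criteria=stopping)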
@@ -237,7 +256,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files."
+        "--input_dir",
+        "-i",
+        type=str,
+        required=True,
+        help="The input directory containing the ggml files.",
     )
     parser.add_argument(
         "--prefix",
@@ -252,14 +275,21 @@ def main():
         help="Whether to save the model in the huggingface format. (default: False)",
     )
     parser.add_argument(
-        "--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)"
+        "--chat",
+        "-c",
+        action="store_true",
+        help="Whether to open a chat with the model. (default: False)",
     )
     args = parser.parse_args()
 
     llama_dir = os.path.abspath(f"{args.input_dir}/../")
 
     ggml_files = sorted(
-        [f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)]
+        [
+            f"{args.input_dir}/{f}"
+            for f in os.listdir(args.input_dir)
+            if f.startswith(args.prefix)
+        ]
     )
 
     fin = open(ggml_files[0], "rb")
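Note (not part of the diff): a hypothetical invocation given the flags defined in main(); the script filename and model path are assumed, not taken from this diff:

    python convert.py --input_dir ./models/7B --chat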