Commit dfa2d70

apply black to ggml-to-pth
1 parent 9cbc404 commit dfa2d70

File tree

1 file changed: +41, -11 lines

convert-ggml-to-pth.py

Lines changed: 41 additions & 11 deletions
@@ -72,7 +72,12 @@ def dequantize_weights(fin, n_rows, n_cols):
 
 def read_variables(fin):
     model = {}
-    pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables")
+    pbar = tqdm(
+        total=os.path.getsize(fin.name),
+        unit="B",
+        unit_scale=True,
+        desc="Reading variables",
+    )
     while True:
         start_pos = fin.tell()
         try:
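
Note: the tqdm call above drives a byte-based progress bar: the total is the file size, and each loop iteration advances the bar by however many bytes were consumed. A minimal, self-contained sketch of the same pattern, assuming a hypothetical model.ggml file exists:

    import os
    from tqdm import tqdm

    with open("model.ggml", "rb") as fin:  # hypothetical path
        pbar = tqdm(
            total=os.path.getsize(fin.name),  # total bytes in the file
            unit="B",
            unit_scale=True,
            desc="Reading variables",
        )
        while True:
            start_pos = fin.tell()
            if not fin.read(4096):  # stand-in for parsing one variable
                break
            pbar.update(fin.tell() - start_pos)  # advance by bytes consumed
        pbar.close()
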
@@ -98,7 +103,9 @@ def read_variables(fin):
         data_size = np.prod(shape)
         data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
 
-        model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)
+        model[name] = torch.tensor(
+            data, dtype=torch.float32 if dtype == np.float32 else torch.float16
+        )
 
         pbar.update(fin.tell() - start_pos)
 
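
Note: the torch.tensor call above picks the torch dtype from the numpy dtype of the buffer that was just read, so f32 tensors stay f32 and everything else lands in f16. A standalone illustration of that mapping (array contents are made up):

    import numpy as np
    import torch

    dtype = np.float16  # as parsed from the ggml tensor header
    data = np.zeros((2, 3), dtype=dtype)
    tensor = torch.tensor(
        data, dtype=torch.float32 if dtype == np.float32 else torch.float16
    )
    assert tensor.dtype == torch.float16
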
@@ -112,11 +119,17 @@ def convert_to_hf_format(model, hparams):
     dim = hparams["dim"]
     dims_per_head = dim // n_heads
     base = 10000.0
-    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    inv_freq = 1.0 / (
+        base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
+    )
 
     # permute for sliced rotary
     def permute(w):
-        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+        return (
+            w.view(n_heads, dim // n_heads // 2, 2, dim)
+            .transpose(1, 2)
+            .reshape(dim, dim)
+        )
 
     state_dict = {}
     for layer_i in range(n_layers):
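
Note: inv_freq above is the standard rotary-embedding frequency table, 1 / base^(2i/d) for i = 0 .. d/2 - 1, and permute rearranges the interleaved rotary weight layout into the half-split layout the Hugging Face LLaMA implementation expects. A small worked check of both, with sizes chosen tiny for readability:

    import torch

    dims_per_head, base = 4, 10000.0
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
    )
    # exponents are [0/4, 2/4], so inv_freq == tensor([1.0000, 0.0100])

    n_heads, dim = 1, 4
    w = torch.arange(dim * dim).float().reshape(dim, dim)
    permuted = (
        w.view(n_heads, dim // n_heads // 2, 2, dim)
        .transpose(1, 2)
        .reshape(dim, dim)
    )
    # rows move from order (0, 1, 2, 3) to (0, 2, 1, 3): interleaved
    # (even, odd) row pairs become first-half/second-half blocks
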
@@ -164,16 +177,22 @@ def permute(w):
 
 
 def chat(model, hparams, llama_dir):
-    from transformers import (GenerationConfig, LlamaForCausalLM,
-                              LlamaTokenizer, StoppingCriteria,
-                              StoppingCriteriaList)
+    from transformers import (
+        GenerationConfig,
+        LlamaForCausalLM,
+        LlamaTokenizer,
+        StoppingCriteria,
+        StoppingCriteriaList,
+    )
     from transformers.models.llama.configuration_llama import LlamaConfig
 
     class StoppingCriteriaSub(StoppingCriteria):
         def __init__(self):
             super().__init__()
 
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]):
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]
+        ):
             print(tokenizer.decode(input_ids[0]), end="", flush=True)
             if input_ids[0][-1] == 13:
                 return True
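
Note: StoppingCriteriaSub streams each decoded token and halts generation once the last token id is 13, which decodes to a newline in the LLaMA vocabulary. A sketch of how such a criterion is typically wired into generate(); the class below mirrors the diff's logic, and the commented call assumes model and input_ids built elsewhere in chat():

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopOnNewline(StoppingCriteria):  # same idea as StoppingCriteriaSub
        def __call__(self, input_ids, scores, **kwargs):
            return bool(input_ids[0][-1] == 13)  # 13 == "\n" for LLaMA

    # output = model.generate(
    #     input_ids,
    #     max_new_tokens=256,  # illustrative limit
    #     stopping_criteria=StoppingCriteriaList([StopOnNewline()]),
    # )
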
@@ -237,7 +256,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files."
+        "--input_dir",
+        "-i",
+        type=str,
+        required=True,
+        help="The input directory containing the ggml files.",
     )
     parser.add_argument(
         "--prefix",
@@ -252,14 +275,21 @@ def main():
         help="Whether to save the model in the huggingface format. (default: False)",
     )
     parser.add_argument(
-        "--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)"
+        "--chat",
+        "-c",
+        action="store_true",
+        help="Whether to open a chat with the model. (default: False)",
    )
     args = parser.parse_args()
 
     llama_dir = os.path.abspath(f"{args.input_dir}/../")
 
     ggml_files = sorted(
-        [f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)]
+        [
+            f"{args.input_dir}/{f}"
+            for f in os.listdir(args.input_dir)
+            if f.startswith(args.prefix)
+        ]
     )
 
     fin = open(ggml_files[0], "rb")
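
Note: the sorted comprehension above collects every file in the input directory whose name starts with the given prefix, so multi-part GGML exports are picked up in lexicographic shard order and the first match is opened above. A standalone sketch with hypothetical directory contents and argument values:

    import os

    input_dir, prefix = "models/7B", "ggml-model-"  # hypothetical arguments
    ggml_files = sorted(
        f"{input_dir}/{f}" for f in os.listdir(input_dir) if f.startswith(prefix)
    )
    # e.g. [".../ggml-model-f16.bin", ".../ggml-model-f16.bin.1", ...];
    # the script then reads ggml_files[0] first
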
