Skip to content

Commit a00bb06

Browse files
committed
Make convert script with pytorch files
1 parent 51b3b56 commit a00bb06

File tree

1 file changed

+43
-3
lines changed

1 file changed

+43
-3
lines changed

convert-stablelm-hf-to-gguf.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from __future__ import annotations
55

6+
import contextlib
67
import argparse
78
import json
89
import os
@@ -20,6 +21,16 @@
2021
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
2122
import gguf
2223

24+
def count_model_parts(dir_model: Path, prefix: str) -> int:
    """Count the entries in *dir_model* whose name starts with *prefix*.

    Args:
        dir_model: Directory to scan; it is handed to ``os.listdir``.
        prefix: Leading text an entry name must have to be counted.

    Returns:
        The number of matching directory entries, or 0 when none match.
    """
    num_parts = sum(
        1 for filename in os.listdir(dir_model) if filename.startswith(prefix)
    )

    # Only report when something matched; a zero count is returned silently.
    if num_parts > 0:
        print(f"gguf: found {num_parts} model parts")
    return num_parts
33+
2334

2435
def parse_args() -> argparse.Namespace:
2536
parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file")
@@ -141,16 +152,45 @@ def parse_args() -> argparse.Namespace:
141152
# tensor info
142153
print("gguf: get tensor metadata")
143154

144-
part_names = iter(("model.safetensors",))
155+
# get number of model parts
156+
num_parts = count_model_parts(dir_model, "model-00")
157+
if num_parts:
158+
is_safetensors = True
159+
from safetensors import safe_open
160+
else:
161+
if count_model_parts(dir_model, "model.safetensors") > 0:
162+
is_safetensors = True
163+
num_parts = 0
164+
else:
165+
is_safetensors = False
166+
num_parts = count_model_parts(dir_model, "pytorch_model-")
167+
168+
if is_safetensors and num_parts == 0:
169+
part_names = iter(("model.safetensors",))
170+
elif num_parts == 0:
171+
part_names = iter(("pytorch_model.bin",))
172+
elif is_safetensors:
173+
part_names = (
174+
f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
175+
)
176+
else:
177+
part_names = (
178+
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
179+
)
180+
145181

146182
for part_name in part_names:
147183
if args.vocab_only:
148184
break
149185
print("gguf: loading model part '" + part_name + "'")
150-
ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
186+
if is_safetensors:
187+
ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
188+
else:
189+
ctx = contextlib.nullcontext(torch.load(dir_model / part_name, map_location="cpu"))
190+
151191
with ctx as model_part:
152192
for name in model_part.keys():
153-
data = model_part.get_tensor(name)
193+
data = model_part.get_tensor(name) if is_safetensors else model_part[name]
154194

155195
# we don't need these
156196
if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):

0 commit comments

Comments
 (0)