|
3 | 3 |
|
4 | 4 | from __future__ import annotations
|
5 | 5 |
|
| 6 | +import contextlib |
6 | 7 | import argparse
|
7 | 8 | import json
|
8 | 9 | import os
|
|
20 | 21 | sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
|
21 | 22 | import gguf
|
22 | 23 |
|
def count_model_parts(dir_model: Path, prefix: str) -> int:
    """Count the entries in ``dir_model`` whose names start with ``prefix``.

    When at least one entry matches, a one-line summary is printed to stdout.

    Args:
        dir_model: Directory to scan; it is listed with ``os.listdir``.
        prefix: Leading part of the filename that identifies a model part.

    Returns:
        The number of matching directory entries (0 when nothing matches).
    """
    num_parts = sum(
        1 for filename in os.listdir(dir_model) if filename.startswith(prefix)
    )

    # Only report when something was actually found.
    if num_parts > 0:
        print(f"gguf: found {num_parts} model parts")
    return num_parts
| 33 | + |
23 | 34 |
|
24 | 35 | def parse_args() -> argparse.Namespace:
|
25 | 36 | parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file")
|
@@ -141,16 +152,45 @@ def parse_args() -> argparse.Namespace:
|
141 | 152 | # tensor info
|
142 | 153 | print("gguf: get tensor metadata")
|
143 | 154 |
|
144 |
| -part_names = iter(("model.safetensors",)) |
# Work out which checkpoint layout is present in dir_model. The layouts are
# probed in this order:
#   1. sharded safetensors  (files starting with "model-00")
#   2. single safetensors   (model.safetensors)
#   3. PyTorch checkpoints  (files starting with "pytorch_model-", otherwise
#                            a single pytorch_model.bin is assumed)
num_parts = count_model_parts(dir_model, "model-00")
if num_parts:
    is_safetensors = True
elif count_model_parts(dir_model, "model.safetensors") > 0:
    # Single-file safetensors checkpoint; num_parts is already 0 here.
    is_safetensors = True
else:
    is_safetensors = False
    num_parts = count_model_parts(dir_model, "pytorch_model-")

if is_safetensors:
    # safe_open is used for every safetensors layout (sharded or single
    # file), so import it whenever safetensors is in use rather than only
    # for the sharded case. Harmless if the file already imports it.
    from safetensors import safe_open

# Build the iterable of part filenames to load, in order.
if is_safetensors and num_parts == 0:
    part_names = iter(("model.safetensors",))
elif num_parts == 0:
    part_names = iter(("pytorch_model.bin",))
elif is_safetensors:
    part_names = (
        f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
    )
else:
    part_names = (
        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
    )
| 180 | + |
145 | 181 |
|
146 | 182 | for part_name in part_names:
|
147 | 183 | if args.vocab_only:
|
148 | 184 | break
|
149 | 185 | print("gguf: loading model part '" + part_name + "'")
|
150 |
| - ctx = safe_open(dir_model / part_name, framework="pt", device="cpu") |
| 186 | + if is_safetensors: |
| 187 | + ctx = safe_open(dir_model / part_name, framework="pt", device="cpu") |
| 188 | + else: |
| 189 | + ctx = contextlib.nullcontext(torch.load(dir_model / part_name, map_location="cpu")) |
| 190 | + |
151 | 191 | with ctx as model_part:
|
152 | 192 | for name in model_part.keys():
|
153 |
| - data = model_part.get_tensor(name) |
| 193 | + data = model_part.get_tensor(name) if is_safetensors else model_part[name] |
154 | 194 |
|
155 | 195 | # we don't need these
|
156 | 196 | if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
|
|
0 commit comments