Commit 6f64b6c

Create convert-llama-7b-pth-to-gguf.py
1 parent 62490f1 commit 6f64b6c

File tree

1 file changed: +302 -0 lines changed

convert-llama-7b-pth-to-gguf.py

@@ -0,0 +1,302 @@
# 7b pth llama --> gguf conversion, GQA/70b not supported
# Only models with a single datafile are supported, like 7B
# HF files required in the model dir: config.json tokenizer_config.json tokenizer.json tokenizer.model
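#
# Example invocation (paths here are hypothetical):
#   python convert-llama-7b-pth-to-gguf.py models/7B 1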

import gguf
import gguf_namemap as tmap
import os
import sys
import struct
import json
import numpy as np
import torch
from typing import Any, List
from pathlib import Path
from sentencepiece import SentencePieceProcessor


#NDArray = np.ndarray[Any, Any]
# compatible with python < 3.9
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'

def count_model_parts(dir_model: str) -> int:
    num_parts = 0
    for filename in os.listdir(dir_model):
        if filename.startswith("consolidated."):
            num_parts += 1

    if num_parts > 0:
        print("gguf: found " + str(num_parts) + " model parts")
    return num_parts

if len(sys.argv) < 3:
    print("Usage: convert-llama-7b-pth-to-gguf.py dir-model ftype\n")
    print("  ftype == 0 -> float32")
    print("  ftype == 1 -> float16")
    sys.exit(1)


# output in the same directory as the model
dir_model = sys.argv[1]
last_dir = os.path.basename(os.path.normpath(dir_model))


# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
    ftype = int(sys.argv[2])
    if ftype < 0 or ftype > 1:
        print("Invalid ftype: " + str(ftype))
        sys.exit(1)

fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
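# e.g. <dir-model>/ggml-model-f32.gguf or <dir-model>/ggml-model-f16.gguf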

print("gguf: loading model " + last_dir)

with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

if hparams["architectures"][0] != "LlamaForCausalLM":
    print("Model architecture not supported: " + hparams["architectures"][0])
    sys.exit()

# get number of model parts
num_parts = count_model_parts(dir_model)

if num_parts > 1:
    print("gguf: Only models with a single datafile are supported.")
    sys.exit()

gguf_writer = gguf.GGUFWriter.open(fname_out)


print("gguf: get model metadata")

llm_arch = "llama"
block_count = hparams["num_hidden_layers"]
head_count = hparams["num_attention_heads"]

if "num_key_value_heads" in hparams:
    head_count_kv = hparams["num_key_value_heads"]
else:
    head_count_kv = head_count

if "_name_or_path" in hparams:
    hf_repo = hparams["_name_or_path"]
else:
    hf_repo = ""

gguf_writer.add_architecture(llm_arch)
gguf_writer.add_name(last_dir)
gguf_writer.add_file_type("All tensors F32" if ftype == 0 else "Most tensors F16, some F32")
gguf_writer.add_source_hf_repo(hf_repo)
gguf_writer.add_context_length(llm_arch, hparams["max_position_embeddings"])
gguf_writer.add_embedding_length(llm_arch, hparams["hidden_size"])
gguf_writer.add_block_count(llm_arch, block_count)
gguf_writer.add_feed_forward_length(llm_arch, hparams["intermediate_size"])
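# RoPE is applied per attention head, so the rope dimension count below is
# the per-head size: hidden_size // num_attention_heads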
gguf_writer.add_rope_dimension_count(llm_arch, hparams["hidden_size"] // hparams["num_attention_heads"])
gguf_writer.add_head_count(llm_arch, head_count)
gguf_writer.add_head_count_kv(llm_arch, head_count_kv)
gguf_writer.add_layer_norm_rms_eps(llm_arch, hparams["rms_norm_eps"])


# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: List[bytes] = []
scores: List[float] = []

if Path(dir_model + "/tokenizer.model").is_file():
    # vocab type sentencepiece
    print("gguf: get sentencepiece tokenizer vocab and scores")

    tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model")

    for i in range(tokenizer.vocab_size()):
        text: bytes
        if tokenizer.is_unknown(i):
            text = " \u2047 ".encode("utf-8")
        elif tokenizer.is_control(i):
            text = b""
        elif tokenizer.is_byte(i):
            piece = tokenizer.id_to_piece(i)
            if len(piece) != 6:
                raise Exception(f"Invalid token: {piece}")
            byte_value = int(piece[3:-1], 16)
            text = struct.pack("B", byte_value)
        else:
            text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
        score: float = tokenizer.get_score(i)

        tokens.append(text)
        scores.append(score)

    gguf_writer.add_tokenizer_model("llama")
    gguf_writer.add_token_list(tokens)
    gguf_writer.add_token_scores(scores)

if Path(dir_model + "/tokenizer.json").is_file():
    with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
        tokenizer = json.load(f)

    if "added_tokens" in tokenizer and Path(dir_model + "/tokenizer_config.json").is_file():
        print("gguf: get special token ids")

        with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
            tokenizer_config = json.load(f)

        # find special token ids

        if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None:
            for key in tokenizer["added_tokens"]:
                if key["content"] == tokenizer_config["bos_token"]["content"]:
                    gguf_writer.add_bos_token_id(key["id"])

        if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None:
            for key in tokenizer["added_tokens"]:
                if key["content"] == tokenizer_config["eos_token"]["content"]:
                    gguf_writer.add_eos_token_id(key["id"])

        if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None:
            for key in tokenizer["added_tokens"]:
                if key["content"] == tokenizer_config["unk_token"]["content"]:
                    gguf_writer.add_unk_token_id(key["id"])

        if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None:
            for key in tokenizer["added_tokens"]:
                if key["content"] == tokenizer_config["sep_token"]["content"]:
                    gguf_writer.add_sep_token_id(key["id"])

        if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None:
            for key in tokenizer["added_tokens"]:
                if key["content"] == tokenizer_config["pad_token"]["content"]:
                    gguf_writer.add_pad_token_id(key["id"])


# TENSORS
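# GGUF stores all tensor metadata ahead of the tensor data, so the model
# part(s) are walked twice: a first pass records each tensor's info, and a
# second pass below converts and writes the actual data.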

tensor_map = tmap.get_tensor_namemap(block_count)

# tensor info
print("gguf: get tensor metadata")

part_names = (f"consolidated.{n:02}.pth" for n in range(0, num_parts))

for part_name in part_names:
    print("gguf: loading model part '" + part_name + "'")
    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

    for name in model_part.keys():
        data = model_part[name]

        # we don't need these
        if name == "rope.freqs":
            continue

        # convert any unsupported data types to float32
        if data.dtype != torch.float16 and data.dtype != torch.float32:
            data = data.to(torch.float32)

        data = data.squeeze().numpy()

        # map tensor names
        if name.endswith(".weight") and name[:-7] in tensor_map:
            name = tensor_map[name[:-7]] + ".weight"
        elif name.endswith(".bias") and name[:-5] in tensor_map:
            name = tensor_map[name[:-5]] + ".bias"
        else:
            print("Cannot map tensor '" + name + "'")
            sys.exit()

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if ftype == 0 and data.dtype == np.float16:
            data_dtype = np.float32

        # TODO: why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data_dtype = np.float32

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if ftype == 1 and data.dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data_dtype = np.float16

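        # size on disk: 2 bytes per element for f16, 4 for f32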
        data_nbytes = data.size * 2 if data_dtype == np.float16 else data.size * 4

        gguf_writer.add_tensor_info(name, data.shape, data_dtype, data_nbytes)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
print("gguf: write tensor metadata")
gguf_writer.write_ti_data_to_file()

# tensor data
print("gguf: convert and write tensor data")

part_names = (f"consolidated.{n:02}.pth" for n in range(0, num_parts))

for part_name in part_names:
    print("gguf: loading model part '" + part_name + "'")
    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

    for name in model_part.keys():
        data = model_part[name]

        old_dtype = data.dtype

        # we don't need these
        if name == "rope.freqs":
            continue

        # convert any unsupported data types to float32
        if data.dtype != torch.float16 and data.dtype != torch.float32:
            data = data.to(torch.float32)

        data = data.squeeze().numpy()

        # map tensor names
        if name.endswith(".weight") and name[:-7] in tensor_map:
            name = tensor_map[name[:-7]] + ".weight"
        elif name.endswith(".bias") and name[:-5] in tensor_map:
            name = tensor_map[name[:-5]] + ".bias"
        else:
            print("Cannot map tensor '" + name + "'")
            sys.exit()

        n_dims = len(data.shape)
        data_dtype = data.dtype

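        # mirror the dtype decisions from the tensor info pass so the bytes
        # written match the sizes declared there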
        # if f32 desired, convert any float16 to float32
        if ftype == 0 and data.dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data = data.astype(np.float16)

        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

        gguf_writer.write_tensor_to_file(data)

gguf_writer.close()


print("gguf: model successfully exported to '" + fname_out + "'")
print("")
