
Commit 8be49fd

convert-new.py : add gguf key-value pairs
1 parent 250cf83 commit 8be49fd

File tree

1 file changed: convert-new.py (+99, -63 lines)

convert-new.py

Lines changed: 99 additions & 63 deletions
@@ -114,13 +114,15 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
 
 @dataclass
 class Params:
-    n_vocab:   int
-    n_embd:    int
-    n_mult:    int
-    n_head:    int
-    n_layer:   int
-    n_ctx:     int
-    n_kv_head: Optional[int] # This parameter is only used for Llama 2
+    n_vocab:    int
+    n_embd:     int
+    n_mult:     int
+    n_layer:    int
+    n_ctx:      int
+    n_ff:       int
+    n_head:     int
+    n_head_kv:  int
+    f_norm_eps: float
 
     @staticmethod
     def guessed(model: 'LazyModel') -> 'Params':
@@ -139,28 +141,36 @@ def guessed(model: 'LazyModel') -> 'Params':
             raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
                             "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
 
-        n_head=n_embd // 128 # guessed
+        n_head = n_embd // 128 # guessed
+        n_mult = 255 # guessed
+
+        # TODO: verify this
+        n_ff = int(2 * (4 * n_embd) / 3)
+        n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult)
 
         return Params(
-            n_vocab   = n_vocab,
-            n_embd    = n_embd,
-            n_mult    = 256,
-            n_head    = n_head,
-            n_layer   = n_layer,
-            n_ctx     = -1,
-            n_kv_head = None,
+            n_vocab    = n_vocab,
+            n_embd     = n_embd,
+            n_mult     = 256,
+            n_layer    = n_layer,
+            n_ctx      = -1,
+            n_ff       = n_ff,
+            n_head     = n_head,
+            n_head_kv  = n_head,
+            f_norm_eps = 1e-5,
         )
 
     @staticmethod
     def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
        config = json.load(open(config_path))
 
-        n_vocab = config["vocab_size"];
-        n_embd  = config["hidden_size"];
-        n_head  = config["num_attention_heads"];
-        n_layer = config["num_hidden_layers"];
-        n_ff    = config["intermediate_size"];
-        n_kv_head = config.get("num_key_value_heads")
+        n_vocab    = config["vocab_size"];
+        n_embd     = config["hidden_size"];
+        n_layer    = config["num_hidden_layers"];
+        n_ff       = config["intermediate_size"];
+        n_head     = config["num_attention_heads"];
+        n_head_kv  = config["num_key_value_heads"];
+        f_norm_eps = config["rms_norm_eps"];
 
         n_mult = find_n_mult(n_ff, n_embd);
 
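Note on the new n_ff guess in Params.guessed(): it follows LLaMA's SwiGLU feed-forward sizing, roughly 2/3 of 4 * n_embd rounded up to a multiple of "multiple_of". A quick illustrative check, not part of the commit:

```python
# Illustrative check of the n_ff guess above (not part of the commit).
# With n_embd = 4096 and a multiple of 256 the formula reproduces the
# published LLaMA-7B intermediate size of 11008.
def guess_n_ff(n_embd: int, n_mult: int) -> int:
    n_ff = int(2 * (4 * n_embd) / 3)                 # 10922 for n_embd = 4096
    return n_mult * ((n_ff + n_mult - 1) // n_mult)  # round up to a multiple of n_mult

print(guess_n_ff(4096, 256))  # 11008 (matches LLaMA-7B's intermediate_size)
print(guess_n_ff(4096, 255))  # 10965 (what the guessed n_mult = 255 path yields)
```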
@@ -173,13 +183,15 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
                             "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
 
         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = n_mult,
-            n_head  = n_head,
-            n_layer = n_layer,
-            n_ctx   = n_ctx,
-            n_kv_head = n_kv_head,
+            n_vocab    = n_vocab,
+            n_embd     = n_embd,
+            n_mult     = n_mult,
+            n_layer    = n_layer,
+            n_ctx      = n_ctx,
+            n_ff       = n_ff,
+            n_head     = n_head,
+            n_head_kv  = n_head_kv,
+            f_norm_eps = f_norm_eps,
         )
 
     # LLaMA v2 70B params.json
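For reference, these are the config.json keys the Hugging Face loader now reads; the values below are representative LLaMA-2-7B numbers, shown only for orientation:

```python
# config.json keys read by loadHFTransformerJson(), with representative
# (illustrative) LLaMA-2-7B values -- consult the model's own config.json.
config = {
    "vocab_size":          32000,   # -> n_vocab
    "hidden_size":         4096,    # -> n_embd
    "num_hidden_layers":   32,      # -> n_layer
    "intermediate_size":   11008,   # -> n_ff
    "num_attention_heads": 32,      # -> n_head
    "num_key_value_heads": 32,      # -> n_head_kv (8 on the 70B GQA model)
    "rms_norm_eps":        1e-05,   # -> f_norm_eps
}
```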
@@ -188,23 +200,32 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
     def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
         config = json.load(open(config_path))
 
-        n_vocab = config["vocab_size"];
-        n_embd  = config["dim"];
-        n_head  = config["n_heads"];
-        n_layer = config["n_layers"];
-        n_mult  = config["multiple_of"];
+        n_vocab    = config["vocab_size"];
+        n_embd     = config["dim"];
+        n_layer    = config["n_layers"];
+        n_mult     = config["multiple_of"];
+        n_ctx      = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2
+        n_ff       = -1;
+        n_head     = config["n_heads"];
+        n_head_kv  = config["n_kv_head"] if "n_kv_head" in config else n_head;
+        f_norm_eps = config["norm_eps"];
 
         if n_vocab == -1:
             n_vocab = model["tok_embeddings.weight"].shape[0]
 
+        if n_ff == -1:
+            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
+
         return Params(
-            n_vocab = n_vocab,
-            n_embd  = n_embd,
-            n_mult  = n_mult,
-            n_head  = n_head,
-            n_layer = n_layer,
-            n_ctx   = -1,
-            n_kv_head = None,
+            n_vocab    = n_vocab,
+            n_embd     = n_embd,
+            n_mult     = n_mult,
+            n_layer    = n_layer,
+            n_ctx      = n_ctx,
+            n_ff       = n_ff,
+            n_head     = n_head,
+            n_head_kv  = n_head_kv,
+            f_norm_eps = f_norm_eps,
         )
 
     @staticmethod
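A minimal sketch (illustrative values, not from the commit) of the two fallbacks added to loadOriginalParamsJson(): n_ff is taken from the leading dimension of the first feed-forward weight, and n_ctx is inferred from norm_eps, since LLaMA v1 shipped with norm_eps = 1e-06 and a 2048 context while v2 uses 1e-05 and 4096:

```python
# Sketch of the new fallbacks (toy data; the real w1 tensor is (11008, 4096)).
import numpy as np

model  = {"layers.0.feed_forward.w1.weight": np.zeros((11008, 8), dtype=np.float16)}
config = {"norm_eps": 1e-05}

n_ff  = model["layers.0.feed_forward.w1.weight"].shape[0]  # 11008
n_ctx = 2048 if config["norm_eps"] == 1e-06 else 4096      # 4096 -> treated as v2
print(n_ff, n_ctx)
```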
@@ -310,9 +331,9 @@ def __repr__(self) -> str:
 Vocab = Union[BpeVocab, SentencePieceVocab]
 
 
-def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray:
-    if n_kv_head is not None and n_head != n_kv_head:
-        n_head //= n_kv_head
+def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
+    if n_head_kv is not None and n_head != n_head_kv:
+        n_head //= n_head_kv
     return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
             .swapaxes(1, 2)
             .reshape(weights.shape))
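The permute() change itself is mechanical (the optional n_kv_head becomes a required n_head_kv), but for readers new to this file, here is a toy demonstration of what the reshape/swapaxes does: within each attention head it interleaves the two halves of the projection rows, which, as I understand it, undoes the half-split rotary layout the Hugging Face export applies to Q and K. Illustrative only:

```python
# Toy demonstration (not from the commit) of permute() on a tiny Q/K weight.
import numpy as np

def permute(weights, n_head, n_head_kv):
    if n_head_kv is not None and n_head != n_head_kv:
        n_head //= n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))

w = np.arange(8)[:, None] * np.ones((8, 3))     # row i holds the value i; 2 heads x 4 rows
print(permute(w, n_head=2, n_head_kv=2)[:, 0])  # [0. 2. 1. 3. 4. 6. 5. 7.]
```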
@@ -324,7 +345,7 @@ class Tensor(metaclass=ABCMeta):
     @abstractmethod
     def astype(self, data_type: DataType) -> 'Tensor': ...
     @abstractmethod
-    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ...
+    def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ...
     @abstractmethod
     def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
     @abstractmethod
@@ -362,8 +383,8 @@ def part(self, n_part: int) -> 'UnquantizedTensor':
         r = self.ndarray.shape[0] // 3
         return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])
 
-    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor':
-        return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head))
+    def permute(self, n_head: int, n_head_kv: int) -> 'UnquantizedTensor':
+        return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv))
 
 
 def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray:
@@ -386,18 +407,18 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv
 
 
 class DeferredPermutedTensor(Tensor):
-    def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None:
+    def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None:
         self.base = base
         self.n_head = n_head
         self.data_type = self.base.data_type
 
     def astype(self, data_type: DataType) -> Tensor:
-        return self.base.astype(data_type).permute(self.n_head, self.n_kv_head)
+        return self.base.astype(data_type).permute(self.n_head, self.n_head_kv)
 
     def to_ggml(self) -> GGMLCompatibleTensor:
-        return self.base.to_ggml().permute(self.n_head, self.n_kv_head)
+        return self.base.to_ggml().permute(self.n_head, self.n_head_kv)
 
-    def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor:
+    def permute(self, n_head: int, n_head_kv: int) -> Tensor:
         raise Exception("shouldn't permute twice")
 
 
@@ -493,10 +514,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus:
     return ModelPlus(model, paths, format, vocab)
 
 
-def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor:
+def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
     def load() -> Tensor:
-        return lazy_tensor.load().permute(n_head, n_kv_head)
-    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description)
+        return lazy_tensor.load().permute(n_head, n_head_kv)
+    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description)
 
 def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
     def load() -> Tensor:
@@ -521,7 +542,7 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
     for i in itertools.count():
         if f"model.layers.{i}.self_attn.q_proj.weight" in model:
             out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
-            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head)
+            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
             out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
         elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
             out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
@@ -732,9 +753,15 @@ def __init__(self, fname_out: Path) -> None:
     def write_file_header(self, params: Params, file_type: GGMLFileType) -> None:
         llm_arch = "llama"
 
-        self.gguf.add_architecture(llm_arch)
-        self.gguf.add_context_length(llm_arch, params.n_ctx)
-        self.gguf.add_embedding_length(llm_arch, params.n_embd)
+        self.gguf.add_architecture        (llm_arch)
+        self.gguf.add_context_length      (llm_arch, params.n_ctx)
+        self.gguf.add_embedding_length    (llm_arch, params.n_embd)
+        self.gguf.add_block_count         (llm_arch, params.n_layer)
+        self.gguf.add_feed_forward_length (llm_arch, params.n_ff)
+        self.gguf.add_rope_dimension_count(llm_arch, params.n_embd // params.n_head)
+        self.gguf.add_head_count          (llm_arch, params.n_head)
+        self.gguf.add_head_count_kv       (llm_arch, params.n_head_kv)
+        self.gguf.add_layer_norm_rms_eps  (llm_arch, params.f_norm_eps)
 
     def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
         sname = name.encode('utf-8')
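This hunk is the heart of the commit: instead of hand-packing a binary header, each hyperparameter goes through a gguf add_* helper, and the RoPE dimension count is derived as n_embd // n_head rather than stored in Params. As a rough guide, the calls above correspond to key-value pairs like the following (key names per the GGUF naming convention used by the gguf helper at the time; verify against gguf.py before relying on them, and the values are LLaMA-7B-like, purely illustrative):

```python
# Illustrative mapping from Params fields to GGUF metadata keys.
n_ctx, n_embd, n_layer, n_ff  = 2048, 4096, 32, 11008
n_head, n_head_kv, f_norm_eps = 32, 32, 1e-06

kv = {
    "general.architecture":                   "llama",
    "llama.context_length":                   n_ctx,
    "llama.embedding_length":                 n_embd,
    "llama.block_count":                      n_layer,
    "llama.feed_forward_length":              n_ff,
    "llama.rope.dimension_count":             n_embd // n_head,  # 128
    "llama.attention.head_count":             n_head,
    "llama.attention.head_count_kv":          n_head_kv,
    "llama.attention.layer_norm_rms_epsilon": f_norm_eps,
}
```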
@@ -744,15 +771,22 @@ def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataTy
         self.fout.seek((self.fout.tell() + 31) & -32)
 
     def write_vocab(self, vocab: Vocab) -> None:
+        tokens = []
+        scores = []
         for text, score in vocab.all_tokens():
-            self.fout.write(struct.pack("i", len(text)))
-            self.fout.write(text)
-            self.fout.write(struct.pack("f", score))
+            tokens.append(text)
+            scores.append(score)
+
+        self.gguf.add_tokenizer_model("llama")
+        self.gguf.add_token_list(tokens)
+        self.gguf.add_token_scores(scores)
+        #self.gguf.add_token_types(toktypes) # TODO: add this
+
+        # TODO: added / special tokens
 
     @staticmethod
-    def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
+    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
         of = OutputFile(fname_out)
         of.write_file_header(params, file_type=GGMLFileType.AllF32)
         of.write_vocab(vocab)
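write_vocab() now buffers the (token, score) pairs and hands them to the gguf writer rather than struct-packing them itself, and write_vocab_only() takes a real Params, since the old inline placeholder would no longer construct once Params gained the extra required fields. A tiny stub (hypothetical, not from the converter) showing the shape of the data being buffered:

```python
# Hypothetical stub illustrating what all_tokens() yields: (token bytes, score).
class StubVocab:
    def all_tokens(self):
        yield (b"<unk>", 0.0)
        yield (b"hello", -1.5)

tokens, scores = [], []
for text, score in StubVocab().all_tokens():
    tokens.append(text)
    scores.append(score)

print(tokens, scores)  # [b'<unk>', b'hello'] [0.0, -1.5]
```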
@@ -941,12 +975,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
     args = parser.parse_args(args_in)
 
-    vocab: Vocab
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
         do_dump_model(model_plus)
 
     model_plus = load_some_model(args.model)
+
     params = Params.load(model_plus)
     if params.n_ctx == -1:
         if args.ctx is None:
@@ -958,6 +992,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
 
     print(f"params = {params}")
 
+    vocab: Vocab
     if args.vocab_only:
         vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
         assert args.outfile, "need --outfile if using --vocab-only"
@@ -968,6 +1003,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     if args.dump:
         do_dump_model(model_plus)
         return
+
     if model_plus.vocab is not None and args.vocab_dir is None:
         vocab = model_plus.vocab
     else:

0 comments