
Commit 86bc9d2

convert-new.py : tensor name mapping
1 parent e970845 commit 86bc9d2

File tree

2 files changed: +205 -113 lines changed


convert-new.py

Lines changed: 34 additions & 91 deletions
@@ -45,14 +45,6 @@ class UnquantizedDataType:
 
 DataType = Union[UnquantizedDataType]
 
-DATA_TYPE_TO_FTYPE: Dict[DataType, int] = {
-    DT_F32: 0,
-    DT_F16: 1,
-}
-
-FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
-    {ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()}
-
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
     DT_BF16: np.dtype(np.uint16),
     DT_F16: np.dtype(np.float16),
@@ -78,31 +70,6 @@ def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
         else:
             raise ValueError(self)
 
-# TODO: this is LLaMA specific
-def make_tensors_list() -> List[str]:
-    ret = [
-        'tok_embeddings.weight',
-        'norm.weight',
-        'output.weight',
-    ]
-    for i in range(80): # maximum number of layer
-        ret += [
-            f'layers.{i}.attention.wq.weight',
-            f'layers.{i}.attention.wk.weight',
-            f'layers.{i}.attention.wv.weight',
-            f'layers.{i}.attention.wo.weight',
-            f'layers.{i}.attention_norm.weight',
-            f'layers.{i}.feed_forward.w1.weight',
-            f'layers.{i}.feed_forward.w2.weight',
-            f'layers.{i}.feed_forward.w3.weight',
-            f'layers.{i}.ffn_norm.weight',
-        ]
-    return ret
-
-# TODO: this should be generalized for non-LLaMA models
-TENSORS_LIST = make_tensors_list()
-TENSORS_SET = set(TENSORS_LIST)
-
 def find_n_mult(n_ff: int, n_embd: int) -> int:
     # hardcoded magic range
     for n_mult in range(8192, 1, -1):
@@ -533,34 +500,6 @@ def load() -> Tensor:
     s[0] = s[0] // 3
     return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)
 
-def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
-    out: LazyModel = {}
-    out["tok_embeddings.weight"] = model["model.embed_tokens.weight"]
-    out["norm.weight"] = model["model.norm.weight"]
-    out["output.weight"] = model["lm_head.weight"]
-
-    for i in itertools.count():
-        if f"model.layers.{i}.self_attn.q_proj.weight" in model:
-            out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head_kv)
-            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
-            out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
-        elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
-            out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
-            out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
-            out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
-        else:
-            break
-
-        out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
-
-        out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]
-        out[f"layers.{i}.feed_forward.w2.weight"] = model[f"model.layers.{i}.mlp.down_proj.weight"]
-        out[f"layers.{i}.feed_forward.w3.weight"] = model[f"model.layers.{i}.mlp.up_proj.weight"]
-
-        out[f"layers.{i}.attention_norm.weight"] = model[f"model.layers.{i}.input_layernorm.weight"]
-        out[f"layers.{i}.ffn_norm.weight"] = model[f"model.layers.{i}.post_attention_layernorm.weight"]
-    return out
-
 
 # Functionality that simulates `torch.load` but where individual tensors are
 # only loaded into memory on demand, not all at once.
@@ -750,8 +689,8 @@ class OutputFile:
     def __init__(self, fname_out: Path) -> None:
         self.gguf = gguf.GGUFWriter.open(fname_out)
 
-    def add_meta_arch(self, params: Params, file_type: GGMLFileType) -> None:
-        llm_arch = "llama"
+    def add_meta_arch(self, params: Params) -> None:
+        llm_arch = gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA]
 
         self.gguf.add_architecture (llm_arch)
         self.gguf.add_context_length (llm_arch, params.n_ctx)
@@ -763,13 +702,6 @@ def add_meta_arch(self, params: Params, file_type: GGMLFileType) -> None:
         self.gguf.add_head_count_kv (llm_arch, params.n_head_kv)
         self.gguf.add_layer_norm_rms_eps (llm_arch, params.f_norm_eps)
 
-    #def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None:
-    #    sname = name.encode('utf-8')
-    #    self.fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[data_type]))
-    #    self.fout.write(struct.pack("i" * len(shape), *shape[::-1]))
-    #    self.fout.write(sname)
-    #    self.fout.seek((self.fout.tell() + 31) & -32)
-
     def add_meta_vocab(self, vocab: Vocab) -> None:
         tokens = []
         scores = []
@@ -794,17 +726,17 @@ def close(self) -> None:
     @staticmethod
     def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None:
         of = OutputFile(fname_out)
-        of.add_meta_arch(params, file_type=GGMLFileType.AllF32)
+        of.add_meta_arch(params)
         of.add_meta_vocab(vocab)
         of.write_meta()
         of.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, file_type: GGMLFileType, model: LazyModel, vocab: Vocab) -> None:
+    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
         check_vocab_size(params, vocab)
 
         of = OutputFile(fname_out)
-        of.add_meta_arch(params, file_type)
+        of.add_meta_arch(params)
         of.add_meta_vocab(vocab)
 
         def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
@@ -822,21 +754,39 @@ def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
 
 
 def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
-    wq_type = model["layers.0.attention.wq.weight"].data_type
+    wq_type = model[gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
+
     if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16):
         return GGMLFileType.MostlyF16
+
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
+
     raise Exception(f"Unexpected combination of types: {name_to_type}")
 
 
-def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel:
-    if "lm_head.weight" in model:
-        model = convert_transformers_to_orig(model, params)
-    model = filter_and_sort_tensors(model)
+def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
+    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, params.n_layer)
 
-    return model
+    out: LazyModel = {}
+    for name, lazy_tensor in model.items():
+        name_new = name
+
+        if name in tmap:
+            name_new = tmap[name]
+        elif name.endswith(".weight") and name[:-7] in tmap:
+            name_new = tmap[name[:-7]] + ".weight"
+        elif name.endswith(".bias") and name[:-5] in tmap:
+            name_new = tmap[name[:-5]] + ".bias"
+        else:
+            raise Exception(f"Unexpected tensor name: {name}")
+
+        out[name_new] = lazy_tensor
+
+        print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+
+    return out
 
 
 def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
@@ -893,11 +843,6 @@ def load_some_model(path: Path) -> ModelPlus:
             # Try the PyTorch patterns too, with lower priority
             globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
             files = [file for glob in globs for file in path.glob(glob)]
-        if not files:
-            # Try GGML too, but with lower priority, since if both a non-GGML
-            # model and a GGML model exist in the same directory, we assume the
-            # latter was converted from the former.
-            files = list(path.glob("ggml-model*.bin*"))
         if not files:
             raise Exception(f"Can't find model in directory {path}")
         if len(files) > 1:
@@ -914,10 +859,6 @@ def load_some_model(path: Path) -> ModelPlus:
     return model_plus
 
 
-def filter_and_sort_tensors(model: LazyModel) -> LazyModel:
-    return {name: model[name] for name in TENSORS_LIST if name in model}
-
-
 def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
     # Be extra-friendly and accept either a file or a directory. Also, if it's
     # a directory, it might be the model directory, and tokenizer.model might
@@ -937,8 +878,10 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]:
             raise FileNotFoundError(
                 f"Could not find tokenizer.model in {path} or its parent; "
                 "if it's in another directory, pass the directory as --vocab-dir")
-    added_tokens_path = path.parent / "added_tokens.json"
+
     print(f"Loading vocab file '{path}', type '{vocabtype}'")
+
+    added_tokens_path = path.parent / "added_tokens.json"
     if vocabtype == "bpe":
         return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
     elif vocabtype == "spm":
@@ -1018,12 +961,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab = load_vocab(vocab_dir, args.vocabtype)
 
     model = model_plus.model
-    model = do_necessary_conversions(model, params) # TODO: utilize gguf.get_tensor_name_map
+    model = convert_model_names(model, params) # TODO: utilize gguf.get_tensor_name_map
     output_type = pick_output_type(model, args.outtype)
     model = convert_to_output_type(model, output_type)
     outfile = args.outfile or default_outfile(model_plus.paths, output_type)
 
-    OutputFile.write_all(outfile, params, output_type, model, vocab)
+    OutputFile.write_all(outfile, params, model, vocab)
     print(f"Wrote {outfile}")
 
 
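The substance of the change is the new convert_model_names(), which replaces the hard-coded convert_transformers_to_orig() / filter_and_sort_tensors() pair with a lookup table obtained from gguf.get_tensor_name_map(). The sketch below isolates that suffix-aware lookup in plain Python; it assumes the table behaves like a simple dict from source tensor names (without their ".weight"/".bias" suffix) to GGUF base names, and the hand-written base_map entries are illustrative stand-ins, not the actual table shipped by the gguf package.

# Minimal sketch of the suffix-aware lookup used by convert_model_names().
# `base_map` is a hypothetical stand-in for gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, n_layer);
# the entries below are illustrative only.
from typing import Dict

base_map: Dict[str, str] = {
    "model.embed_tokens":              "token_embd",
    "model.layers.0.self_attn.q_proj": "blk.0.attn_q",
    "model.layers.0.mlp.down_proj":    "blk.0.ffn_down",
}

def map_name(name: str, tmap: Dict[str, str]) -> str:
    # Try an exact match first, then retry with the ".weight"/".bias" suffix stripped
    # and re-attached to the mapped base name, mirroring the three branches in the diff.
    if name in tmap:
        return tmap[name]
    if name.endswith(".weight") and name[:-7] in tmap:
        return tmap[name[:-7]] + ".weight"
    if name.endswith(".bias") and name[:-5] in tmap:
        return tmap[name[:-5]] + ".bias"
    raise Exception(f"Unexpected tensor name: {name}")

print(map_name("model.layers.0.self_attn.q_proj.weight", base_map))  # blk.0.attn_q.weight

pick_output_type() now builds its lookup key from the other side of the same table: it formats the ATTN_Q name template for block 0 (something like "blk.{bid}.attn_q".format(bid=0), assuming that template shape) and appends ".weight" before indexing the model dict.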
