Commit 7eaa315

convert-new.py : add map for skipping tensor serialization
1 parent 86bc9d2 commit 7eaa315

File tree

2 files changed: +65 -37 lines


convert-new.py

Lines changed: 54 additions & 34 deletions
@@ -34,6 +34,13 @@
 
 NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
 
+ARCH=gguf.MODEL_ARCH.LLAMA
+NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
+
+#
+# data types
+#
+
 @dataclass(frozen=True)
 class UnquantizedDataType:
     name: str
@@ -55,6 +62,13 @@ class UnquantizedDataType:
 NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
     {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
 
+SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
+    'BF16': DT_BF16,
+    'F16':  DT_F16,
+    'F32':  DT_F32,
+    'I32':  DT_I32,
+}
+
 class GGMLFileType(enum.Enum):
     AllF32    = 0
     MostlyF16 = 1  # except 1d tensors
@@ -70,14 +84,10 @@ def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
         else:
             raise ValueError(self)
 
-def find_n_mult(n_ff: int, n_embd: int) -> int:
-    # hardcoded magic range
-    for n_mult in range(8192, 1, -1):
-        calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
-        if calc_ff == n_ff:
-            return n_mult
-    raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
 
+#
+# hparams loading
+#
 
 @dataclass
 class Params:
@@ -91,6 +101,15 @@ class Params:
     n_head_kv:  int
     f_norm_eps: float
 
+    @staticmethod
+    def find_n_mult(n_ff: int, n_embd: int) -> int:
+        # hardcoded magic range
+        for n_mult in range(8192, 1, -1):
+            calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
+            if calc_ff == n_ff:
+                return n_mult
+        raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
+
     @staticmethod
     def guessed(model: 'LazyModel') -> 'Params':
         # try transformer naming first
@@ -139,7 +158,7 @@ def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
     n_head_kv  = config["num_key_value_heads"];
     f_norm_eps = config["rms_norm_eps"];
 
-    n_mult = find_n_mult(n_ff, n_embd);
+    n_mult = Params.find_n_mult(n_ff, n_embd);
 
     if "max_sequence_length" in config:
         n_ctx = config["max_sequence_length"]
@@ -210,6 +229,10 @@ def load(model_plus: 'ModelPlus') -> 'Params':
         return params
 
 
+#
+# vocab
+#
+
 class BpeVocab:
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
         self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
@@ -294,10 +317,14 @@ def all_tokens(self) -> Iterable[Tuple[bytes, float]]:
     def __repr__(self) -> str:
         return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
-
 Vocab = Union[BpeVocab, SentencePieceVocab]
 
 
+#
+# data loading
+# TODO: reuse (probably move to gguf.py?)
+#
+
 def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
     if n_head_kv is not None and n_head != n_head_kv:
         n_head //= n_head_kv
@@ -593,14 +620,6 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus:
     return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None)
 
 
-SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
-    'BF16': DT_BF16,
-    'F16':  DT_F16,
-    'F32':  DT_F32,
-    'I32':  DT_I32,
-}
-
-
 def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus:
     header_size, = struct.unpack('<Q', fp.read(8))
     header: Dict[str, Dict[str, Any]] = json.loads(fp.read(header_size))
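
The loader depends on the safetensors container layout: an 8-byte little-endian length prefix followed by that many bytes of JSON metadata. A minimal round-trip sketch, with an illustrative tensor entry that is not part of this commit:

    import io, json, struct

    # build a tiny in-memory file with a safetensors-style header
    header = {"foo.weight": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]}}
    blob = json.dumps(header).encode("utf-8")
    fp = io.BytesIO(struct.pack("<Q", len(blob)) + blob + b"\x00" * 16)

    # the same two reads lazy_load_safetensors_file starts with
    header_size, = struct.unpack("<Q", fp.read(8))
    parsed = json.loads(fp.read(header_size))
    assert parsed["foo.weight"]["dtype"] == "F32"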
@@ -690,17 +709,16 @@ def __init__(self, fname_out: Path) -> None:
         self.gguf = gguf.GGUFWriter.open(fname_out)
 
     def add_meta_arch(self, params: Params) -> None:
-        llm_arch = gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA]
-
-        self.gguf.add_architecture        (llm_arch)
-        self.gguf.add_context_length      (llm_arch, params.n_ctx)
-        self.gguf.add_embedding_length    (llm_arch, params.n_embd)
-        self.gguf.add_block_count         (llm_arch, params.n_layer)
-        self.gguf.add_feed_forward_length (llm_arch, params.n_ff)
-        self.gguf.add_rope_dimension_count(llm_arch, params.n_embd // params.n_head)
-        self.gguf.add_head_count          (llm_arch, params.n_head)
-        self.gguf.add_head_count_kv       (llm_arch, params.n_head_kv)
-        self.gguf.add_layer_norm_rms_eps  (llm_arch, params.f_norm_eps)
+        arch = gguf.MODEL_ARCH_NAMES[ARCH]
+        self.gguf.add_architecture        (arch)
+        self.gguf.add_context_length      (arch, params.n_ctx)
+        self.gguf.add_embedding_length    (arch, params.n_embd)
+        self.gguf.add_block_count         (arch, params.n_layer)
+        self.gguf.add_feed_forward_length (arch, params.n_ff)
+        self.gguf.add_rope_dimension_count(arch, params.n_embd // params.n_head)
+        self.gguf.add_head_count          (arch, params.n_head)
+        self.gguf.add_head_count_kv       (arch, params.n_head_kv)
+        self.gguf.add_layer_norm_rms_eps  (arch, params.f_norm_eps)
 
     def add_meta_vocab(self, vocab: Vocab) -> None:
         tokens = []
@@ -754,7 +772,7 @@ def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
 
 
 def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
-    wq_type = model[gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
+    wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
 
     if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
         return GGMLFileType.AllF32
@@ -767,7 +785,7 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi
 
 
 def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
-    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, params.n_layer)
+    tmap = gguf.get_tensor_name_map(ARCH, params.n_layer)
 
     out: LazyModel = {}
     for name, lazy_tensor in model.items():
@@ -782,9 +800,11 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
         else:
             raise Exception(f"Unexpected tensor name: {name}")
 
-        out[name_new] = lazy_tensor
-
-        print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+        if gguf.should_skip_tensor(ARCH, params.n_layer, name_new):
+            print(f"skipping tensor {name_new}")
+        else:
+            print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}")
+            out[name_new] = lazy_tensor
 
     return out
 
@@ -961,7 +981,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab = load_vocab(vocab_dir, args.vocabtype)
 
     model = model_plus.model
-    model = convert_model_names(model, params) # TODO: utilize gguf.get_tensor_name_map
+    model = convert_model_names(model, params)
     output_type = pick_output_type(model, args.outtype)
     model = convert_to_output_type(model, output_type)
     outfile = args.outfile or default_outfile(model_plus.paths, output_type)
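
For reference, here is a standalone sketch of how the relocated Params.find_n_mult helper behaves. The function body is copied from the hunk above; the 7B-style sizes in the check are illustrative, not taken from the commit:

    # standalone copy of Params.find_n_mult from the diff above
    def find_n_mult(n_ff: int, n_embd: int) -> int:
        # hardcoded magic range
        for n_mult in range(8192, 1, -1):
            calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
            if calc_ff == n_ff:
                return n_mult
        raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")

    # illustrative LLaMA-7B-style sizes
    n_embd, n_ff = 4096, 11008
    n_mult = find_n_mult(n_ff, n_embd)
    # whatever multiplier the scan finds must reproduce n_ff exactly
    assert (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult == n_ff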

gguf.py

Lines changed: 11 additions & 3 deletions
@@ -159,11 +159,19 @@ class MODEL_TENSOR(IntEnum):
 
 # tensors that will not be serialized
 MODEL_TENSOR_SKIP = {
-    MODEL_ARCH.LLAMA : {
+    MODEL_ARCH.LLAMA : [
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
-    },
-},
+    ],
+}
+
+def should_skip_tensor(arch : MODEL_ARCH, n_blocks : int, name : str) -> bool:
+    for skip in MODEL_TENSOR_SKIP.get(arch, []):
+        for i in range(n_blocks):
+            if name == MODEL_TENSOR_NAMES[arch][skip].format(bid=i):
+                return True
+
+    return False
 
 def get_tensor_name_map(arch : MODEL_ARCH, n_blocks : int) -> dict:
     tensor_map = {}
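
A minimal self-contained sketch of the new skip check (the real implementation above reads MODEL_TENSOR_NAMES; the template strings here are assumed stand-ins for the ROPE_FREQS and ATTN_ROT_EMBD name formats):

    # assumed stand-ins for MODEL_TENSOR_NAMES[MODEL_ARCH.LLAMA] skip entries
    SKIP_TEMPLATES = ["rope_freqs", "blk.{bid}.attn_rot_embd"]

    def should_skip_tensor(n_blocks: int, name: str) -> bool:
        # a tensor is skipped if it matches any skip template at any block index
        for tmpl in SKIP_TEMPLATES:
            for bid in range(n_blocks):
                if name == tmpl.format(bid=bid):
                    return True
        return False

    assert should_skip_tensor(32, "blk.7.attn_rot_embd")      # per-block skip
    assert not should_skip_tensor(32, "blk.7.attn_q.weight")  # regular tensor kept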
