Skip to content

Commit 32fc925

Browse files
committed
convert.py : add ftype when converting (does not work)
1 parent bee1f0e commit 32fc925

File tree

2 files changed: 23 additions & 6 deletions

2 files changed: 23 additions & 6 deletions

convert.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ class UnquantizedDataType:
6969
'I32': DT_I32,
7070
}
7171

72+
# TODO: match this with `llama_ftype`
73+
# TODO: rename to LLAMAFileType
74+
# TODO: move to `gguf.py`
7275
class GGMLFileType(enum.Enum):
7376
AllF32 = 0
7477
MostlyF16 = 1 # except 1d tensors
@@ -101,6 +104,8 @@ class Params:
101104
n_head_kv: int
102105
f_norm_eps: float
103106

107+
ftype: Optional[GGMLFileType] = None
108+
104109
@staticmethod
105110
def find_n_mult(n_ff: int, n_embd: int) -> int:
106111
# hardcoded magic range
@@ -738,6 +743,9 @@ def add_meta_arch(self, params: Params) -> None:
738743
self.gguf.add_head_count_kv (params.n_head_kv)
739744
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
740745

746+
if params.ftype:
747+
self.gguf.add_file_type(params.ftype)
748+
741749
def add_meta_vocab(self, vocab: Vocab) -> None:
742750
tokens = []
743751
scores = []
@@ -1020,6 +1028,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
10201028
" - LLaMA v2: --ctx 4096\n")
10211029
params.n_ctx = args.ctx
10221030

1031+
if args.outtype:
1032+
params.ftype = {
1033+
"f32": GGMLFileType.AllF32,
1034+
"f16": GGMLFileType.MostlyF16,
1035+
}[args.outtype]
1036+
10231037
print(f"params = {params}")
10241038

10251039
vocab: Vocab
@@ -1040,11 +1054,14 @@ def main(args_in: Optional[List[str]] = None) -> None:
10401054
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
10411055
vocab = load_vocab(vocab_dir, args.vocabtype)
10421056

1043-
model = model_plus.model
1044-
model = convert_model_names(model, params)
1045-
output_type = pick_output_type(model, args.outtype)
1046-
model = convert_to_output_type(model, output_type)
1047-
outfile = args.outfile or default_outfile(model_plus.paths, output_type)
1057+
model = model_plus.model
1058+
model = convert_model_names(model, params)
1059+
ftype = pick_output_type(model, args.outtype)
1060+
model = convert_to_output_type(model, ftype)
1061+
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
1062+
1063+
params.ftype = ftype
1064+
print(f"Writing {outfile}, format {ftype}")
10481065

10491066
OutputFile.write_all(outfile, params, model, vocab)
10501067
print(f"Wrote {outfile}")

gguf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ def add_source_hf_repo(self, repo: str):
597597
self.add_string(KEY_GENERAL_SOURCE_HF_REPO, repo)
598598

599599
def add_file_type(self, ftype: int):
600-
self.add_string(KEY_GENERAL_FILE_TYPE, file_type)
600+
self.add_uint32(KEY_GENERAL_FILE_TYPE, ftype)
601601

602602
def add_name(self, name: str):
603603
self.add_string(KEY_GENERAL_NAME, name)

0 commit comments

Comments
 (0)