5
5
6
6
import logging
7
7
import argparse
8
- import contextlib
9
- import json
10
8
import os
11
- import re
12
9
import sys
13
10
import types
14
- from enum import IntEnum
15
11
from pathlib import Path
16
- from hashlib import sha256
17
- from typing import TYPE_CHECKING , Any , Callable , ContextManager , Iterable , Iterator , Literal , Sequence , TypeVar , cast
12
+ from typing import TYPE_CHECKING , Iterable , Iterator
18
13
19
- import math
20
- import numpy as np
21
14
import torch
22
15
23
16
if TYPE_CHECKING :
32
25
33
26
logger = logging .getLogger ("lora-to-gguf" )
34
27
28
+
35
29
def parse_args () -> argparse .Namespace :
36
- all_models = ", " .join ([arch for arch in Model ._model_classes .keys ()])
37
30
parser = argparse .ArgumentParser (
38
- description = "Convert a huggingface model to a GGML compatible file" )
31
+ description = "Convert a huggingface PEFT LoRA adapter to a GGML compatible file" )
39
32
parser .add_argument (
40
33
"--outfile" , type = Path ,
41
34
help = "path to write to; default: based on input." ,
42
35
)
43
36
parser .add_argument (
44
- "--outtype" , type = str , choices = ["f32" , "f16" , "bf16" , "q8_0" , "auto" ], default = "f16" ,
45
- help = "output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type" ,
46
- )
47
- parser .add_argument (
48
- "--arch" , type = str ,
49
- help = f"Arch of the base model, must be one of: { all_models } (default: LlamaForCausalLM)" ,
50
- default = "LlamaForCausalLM"
37
+ "--outtype" , type = str , choices = ["f32" , "f16" , "bf16" , "q8_0" ], default = "f16" ,
38
+ help = "output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0" ,
51
39
)
52
40
parser .add_argument (
53
41
"--bigendian" , action = "store_true" ,
@@ -73,14 +61,13 @@ def parse_args() -> argparse.Namespace:
73
61
args = parse_args ()
74
62
logging .basicConfig (level = logging .DEBUG if args .verbose else logging .INFO )
75
63
76
- # FIXME: outtype is not working
77
64
ftype_map : dict [str , gguf .LlamaFileType ] = {
78
65
"f32" : gguf .LlamaFileType .ALL_F32 ,
79
66
"f16" : gguf .LlamaFileType .MOSTLY_F16 ,
80
67
"bf16" : gguf .LlamaFileType .MOSTLY_BF16 ,
81
68
"q8_0" : gguf .LlamaFileType .MOSTLY_Q8_0 ,
82
- "auto" : gguf .LlamaFileType .GUESSED ,
83
69
}
70
+ ftype = ftype_map [args .outtype ]
84
71
85
72
dir_base_model = args .base
86
73
dir_lora = args .lora_path
@@ -110,7 +97,7 @@ def parse_args() -> argparse.Namespace:
110
97
logger .error (f"Model { hparams ['architectures' ][0 ]} is not supported" )
111
98
sys .exit (1 )
112
99
113
- model_instance = model_class (dir_base_model , ftype_map [ args . outtype ] , fname_out , args .bigendian , False , False , None )
100
+ model_instance = model_class (dir_base_model , ftype , fname_out , args .bigendian , False , False , None )
114
101
logger .info ("Set model parameters" )
115
102
model_instance .set_gguf_parameters ()
116
103
@@ -140,16 +127,18 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
140
127
# overwrite method
141
128
def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
142
129
del bid # unused
130
+ # TODO: This will not take into account tensor transformations
143
131
return [(name , data_torch )]
144
132
145
133
# overwrite method
146
134
def extra_f16_tensors (self , name : str , new_name : str , bid : int | None , n_dims : int ) -> bool :
147
135
del name , new_name , bid , n_dims # unused
148
- return True
136
+ return ftype != gguf . LlamaFileType . ALL_F32
149
137
150
138
model_instance .get_tensors = types .MethodType (get_tensors , model_instance )
151
139
model_instance .modify_tensors = types .MethodType (modify_tensors , model_instance )
152
140
model_instance .extra_f16_tensors = types .MethodType (extra_f16_tensors , model_instance )
141
+
153
142
model_instance .gguf_writer .add_quantization_version (gguf .GGML_QUANT_VERSION )
154
143
logger .info ("Exporting model..." )
155
144
model_instance .write ()
0 commit comments