Skip to content

Commit 7a83f20

Browse files
committed
fix ftype
1 parent 84288ff commit 7a83f20

File tree

1 file changed

+10
-21
lines changed

1 file changed

+10
-21
lines changed

convert_lora_to_gguf.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,12 @@
55

66
import logging
77
import argparse
8-
import contextlib
9-
import json
108
import os
11-
import re
129
import sys
1310
import types
14-
from enum import IntEnum
1511
from pathlib import Path
16-
from hashlib import sha256
17-
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
12+
from typing import TYPE_CHECKING, Iterable, Iterator
1813

19-
import math
20-
import numpy as np
2114
import torch
2215

2316
if TYPE_CHECKING:
@@ -32,22 +25,17 @@
3225

3326
logger = logging.getLogger("lora-to-gguf")
3427

28+
3529
def parse_args() -> argparse.Namespace:
36-
all_models = ", ".join([arch for arch in Model._model_classes.keys()])
3730
parser = argparse.ArgumentParser(
38-
description="Convert a huggingface model to a GGML compatible file")
31+
description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
3932
parser.add_argument(
4033
"--outfile", type=Path,
4134
help="path to write to; default: based on input.",
4235
)
4336
parser.add_argument(
44-
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
45-
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
46-
)
47-
parser.add_argument(
48-
"--arch", type=str,
49-
help=f"Arch of the base model, must be one of: {all_models} (default: LlamaForCausalLM)",
50-
default="LlamaForCausalLM"
37+
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
38+
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0",
5139
)
5240
parser.add_argument(
5341
"--bigendian", action="store_true",
@@ -73,14 +61,13 @@ def parse_args() -> argparse.Namespace:
7361
args = parse_args()
7462
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
7563

76-
# FIXME: outtype is not working
7764
ftype_map: dict[str, gguf.LlamaFileType] = {
7865
"f32": gguf.LlamaFileType.ALL_F32,
7966
"f16": gguf.LlamaFileType.MOSTLY_F16,
8067
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
8168
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
82-
"auto": gguf.LlamaFileType.GUESSED,
8369
}
70+
ftype = ftype_map[args.outtype]
8471

8572
dir_base_model = args.base
8673
dir_lora = args.lora_path
@@ -110,7 +97,7 @@ def parse_args() -> argparse.Namespace:
11097
logger.error(f"Model {hparams['architectures'][0]} is not supported")
11198
sys.exit(1)
11299

113-
model_instance = model_class(dir_base_model, ftype_map[args.outtype], fname_out, args.bigendian, False, False, None)
100+
model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
114101
logger.info("Set model parameters")
115102
model_instance.set_gguf_parameters()
116103

@@ -140,16 +127,18 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
140127
# overwrite method
141128
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
142129
del bid # unused
130+
# TODO: This will not take into account tensor transformations
143131
return [(name, data_torch)]
144132

145133
# overwrite method
146134
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
147135
del name, new_name, bid, n_dims # unused
148-
return True
136+
return ftype != gguf.LlamaFileType.ALL_F32
149137

150138
model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
151139
model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance)
152140
model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)
141+
153142
model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
154143
logger.info("Exporting model...")
155144
model_instance.write()

0 commit comments

Comments
 (0)