#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import logging
import argparse
import os
import sys
import types
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, Iterator

import torch

if TYPE_CHECKING:
    from torch import Tensor

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf

# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import Model

logger = logging.getLogger("lora-to-gguf")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert a Hugging Face PEFT LoRA adapter to a GGUF compatible file")
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0",
    )
    parser.add_argument(
        "--bigendian", action="store_true",
        help="model is executed on a big-endian machine",
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="increase output verbosity",
    )
    parser.add_argument(
        "--base", type=Path, required=True,
        help="directory containing base model file",
    )
    parser.add_argument(
        "lora_path", type=Path,
        help="directory containing LoRA adapter file",
    )

    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    ftype_map: dict[str, gguf.LlamaFileType] = {
        "f32": gguf.LlamaFileType.ALL_F32,
        "f16": gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
    }
    ftype = ftype_map[args.outtype]

    dir_base_model = args.base
    dir_lora = args.lora_path
    input_json = os.path.join(dir_lora, "adapter_config.json")
    input_model = os.path.join(dir_lora, "adapter_model.bin")
    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'

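    # PEFT adapters store their weights either as a PyTorch checkpoint
    # (adapter_model.bin) or as adapter_model.safetensors; try the .bin first
    # and fall back to the safetensors file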
    if os.path.exists(input_model):
        lora_model = torch.load(input_model, map_location="cpu")
    else:
        input_model = os.path.join(dir_lora, "adapter_model.safetensors")
        # lazy import load_file only if lora is in safetensors format.
        from safetensors.torch import load_file
        lora_model = load_file(input_model, device="cpu")

    # load base model
    logger.info(f"Loading base model: {dir_base_model.name}")
    hparams = Model.load_hparams(dir_base_model)
    with torch.inference_mode():
        try:
            model_class = Model.from_model_architecture(hparams["architectures"][0])
        except NotImplementedError:
            logger.error(f"Model {hparams['architectures'][0]} is not supported")
            sys.exit(1)

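        # instantiate the architecture-specific Model subclass; only its tensor-name
        # mapping and GGUF writer are reused here, not its weight loading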
        model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
        logger.info("Set model parameters")
        model_instance.set_gguf_parameters()

        # adapter_config = json.load(input_json)
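        # tag the output as a LoRA adapter so consumers can tell it apart from a full model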
        model_instance.gguf_writer.add_string("training.type", "finetune_lora")

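        # Map PEFT tensor names to GGUF names: strip the "base_model.model." prefix
        # and the ".lora_A/.lora_B.weight" suffix, translate the remaining base name
        # with map_tensor_name(), then re-append ".lora_a" / ".lora_b"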
        map_tensors: dict[str, Tensor] = {}
        for tensor_name, tensor in lora_model.items():
            orig_name = tensor_name.replace("base_model.model.", "")
            orig_name = orig_name.replace(".lora_A.weight", ".weight")
            orig_name = orig_name.replace(".lora_B.weight", ".weight")
            is_lora_a = ".lora_A.weight" in tensor_name
            is_lora_b = ".lora_B.weight" in tensor_name
            if not is_lora_a and not is_lora_b:
                logger.error(f"Unexpected name '{tensor_name}': Not a lora_A or lora_B tensor")
                sys.exit(1)
            dest_name = model_instance.map_tensor_name(orig_name)
            dest_name = f"{dest_name}.lora_a" if is_lora_a else f"{dest_name}.lora_b"
            # logger.info(f"{orig_name} --> {dest_name}")
            map_tensors[dest_name] = tensor

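        # Override the Model methods used by write(): get_tensors() yields the remapped
        # LoRA tensors instead of reading the base model weights, modify_tensors() passes
        # them through unchanged, and extra_f16_tensors() keeps them in the requested
        # output type; they are bound to the instance via types.MethodType below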
        # overwrite method
        def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
            for name, tensor in map_tensors.items():
                yield (name, tensor)

        # overwrite method
        def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
            del bid  # unused
            # TODO: This will not take into account tensor transformations
            return [(name, data_torch)]

        # overwrite method
        def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
            del name, new_name, bid, n_dims  # unused
            return ftype != gguf.LlamaFileType.ALL_F32

        model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
        model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance)
        model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)

        model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
        logger.info("Exporting model...")
        model_instance.write()
        logger.info(f"Model successfully exported to {fname_out}")
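
# Example invocation (script and directory names are illustrative):
#   python convert_lora_to_gguf.py --outtype f16 --base ./base-model ./lora-adapter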