Commit 9d96328

convert_lora : MoE LoRA conversion support
* convert_lora : prefer safetensors, similarly to convert_hf
1 parent 916e959 commit 9d96328
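
The adapter tensors this conversion handles follow the usual LoRA decomposition: each adapted weight is patched by a low-rank product of a lora_A and a lora_B factor. A minimal sketch of that relationship (illustrative only, not code from this commit; the converter keeps the factors separate and writes them out as `.lora_a` / `.lora_b` GGUF tensors rather than materializing the sum):

# Hedged sketch, not part of the commit: how a LoRA pair relates to the weight it patches.
import torch

n_out, n_in, rank = 64, 32, 8
W = torch.randn(n_out, n_in)        # frozen base weight from the base model
lora_A = torch.randn(rank, n_in)    # low-rank "down" factor (exported as <name>.lora_a)
lora_B = torch.randn(n_out, rank)   # low-rank "up" factor (exported as <name>.lora_b)
scale = 1.0                         # alpha / rank in a real adapter

# The converter never materializes this sum; llama.cpp applies the adapter at load time.
W_effective = W + scale * (lora_B @ lora_A)
print(W_effective.shape)  # torch.Size([64, 32])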

2 files changed: 216 additions, 57 deletions


convert_hf_to_gguf.py

Lines changed: 2 additions & 9 deletions
@@ -373,9 +373,6 @@ def from_model_architecture(cls, arch: str) -> type[Model]:
         except KeyError:
             raise NotImplementedError(f'Architecture {arch!r} not supported!') from None

-    def support_lora(self) -> bool:
-        return False
-
     # used for GPT-2 BPE and WordPiece vocabs
     def get_vocab_base(self) -> tuple[list[str], list[int], str]:
         tokens: list[str] = []
@@ -1415,9 +1412,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")

-        if name.endswith(("q_proj.weight", "q_proj.bias", "q_proj.lora_B.weight")):
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-        if name.endswith(("k_proj.weight", "k_proj.bias", "k_proj.lora_B.weight")):
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

         # process the experts separately
@@ -1465,10 +1462,6 @@ def write_tensors(self):
         if len(experts) > 0:
             raise ValueError(f"Unprocessed experts: {experts}")

-    def support_lora(self) -> bool:
-        # TODO: support lora conversion for MOE
-        return "num_local_experts" not in self.hparams
-

 @Model.register("BitnetForCausalLM")
 class BitnetModel(Model):
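
The q_proj/k_proj hunk above drops the lora_B-specific suffixes because permutation of a wrapped LoRA tensor is now handled generically on the converter side. A small check of the property this relies on (my own example, not from the commit): reordering the output rows of the product B @ A is the same as reordering the rows of B alone, so only the lora_B factor needs to be touched:

import torch

n_out, n_in, rank = 6, 4, 2
A = torch.randn(rank, n_in)
B = torch.randn(n_out, rank)
perm = torch.randperm(n_out)  # any reordering of the output rows

# permuting the dense product == permuting lora_B and leaving lora_A untouched
assert torch.allclose((B @ A)[perm], B[perm] @ A)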

convert_lora_to_gguf.py

Lines changed: 214 additions & 48 deletions
@@ -3,13 +3,14 @@

 from __future__ import annotations

+from dataclasses import dataclass
 import logging
 import argparse
 import os
 import sys
-import types
 from pathlib import Path
-from typing import TYPE_CHECKING, Iterator
+from types import EllipsisType
+from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast

 import torch

@@ -26,6 +27,169 @@
 logger = logging.getLogger("lora-to-gguf")


+@dataclass
+class PartialLoraTensor:
+    A: Tensor | None = None
+    B: Tensor | None = None
+
+
+# magic to support tensor shape modifications and splitting
+class LoraTorchTensor:
+    _lora_A: Tensor
+    _lora_B: Tensor
+    _rank: int
+
+    def __init__(self, A: Tensor, B: Tensor):
+        assert len(A.shape) == len(B.shape)
+        if A.dtype != B.dtype:
+            A = A.to(torch.float32)
+            B = B.to(torch.float32)
+        self._lora_A = A
+        self._lora_B = B
+        assert self._lora_A.shape[-2] == self._lora_B.shape[-1]
+        self._rank = self._lora_B.shape[-1]
+
+    def __getitem__(
+        self,
+        indices: (
+            SupportsIndex
+            | slice
+            | tuple[SupportsIndex | slice | EllipsisType | Tensor, ...]
+        ),
+    ) -> LoraTorchTensor:
+        shape = self.shape
+        if isinstance(indices, (SupportsIndex, slice)):
+            if len(shape) > 2:
+                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
+            else:
+                raise NotImplementedError
+        elif isinstance(indices, tuple):
+            assert len(indices) > 0
+            if isinstance(indices[-1], EllipsisType):
+                return self[indices[:-1]]
+            # expand ellipsis
+            indices = tuple(
+                u
+                for v in (
+                    (
+                        (slice(None, None) for _ in range(len(indices) - 1))
+                        if isinstance(i, EllipsisType)
+                        else (i,)
+                    )
+                    for i in indices
+                )
+                for u in v
+            )
+
+            if len(indices) < len(shape):
+                indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
+
+            # TODO: make sure this is correct
+            # lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1])
+            indices_A = (
+                *(
+                    0 if isinstance(i, SupportsIndex) else slice(None, None)
+                    for i in indices[:-2]
+                ),
+                slice(None, None),
+                indices[-1],
+            )
+            indices_B = indices[:-1]
+            return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
+        else:
+            raise NotImplementedError
+
+    @property
+    def dtype(self) -> torch.dtype:
+        assert self._lora_A.dtype == self._lora_B.dtype
+        return self._lora_A.dtype
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
+
+    def size(self, dim=None):
+        assert dim is None
+        return self.shape
+
+    def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor:
+        if isinstance(shape[0], tuple):
+            new_shape: tuple[int] = shape[0]
+        else:
+            new_shape = cast(tuple[int], shape)
+        orig_shape = self.shape
+        if new_shape[-1] != orig_shape[-1]:
+            raise NotImplementedError
+        return LoraTorchTensor(
+            self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])),
+            self._lora_B.reshape((*new_shape[:-1], self._rank)),
+        )
+
+    def reshape_as(self, other: Tensor) -> LoraTorchTensor:
+        return self.reshape(*other.shape)
+
+    def view(self, *size: int) -> LoraTorchTensor:
+        return self.reshape(*size)
+
+    def permute(self, *dims: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
+        if dims[-1] == -2 and dims[-2] == -1:
+            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
+        else:
+            assert dims[-1] == -1
+            assert all(dim == 1 for dim in self._lora_A.shape[:-2])
+            return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
+
+    def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = [i for i in range(len(shape))]
+        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
+        return self.permute(*dims)
+
+    def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
+        return self.transpose(axis0, axis1)
+
+    def to(self, *args, **kwargs):
+        return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
+
+    @classmethod
+    def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
+        del types # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.permute:
+            return type(args[0]).permute(*args, **kwargs)
+        elif func is torch.reshape:
+            return type(args[0]).reshape(*args, **kwargs)
+        elif func is torch.stack:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            return LoraTorchTensor(
+                torch.stack([a._lora_A for a in args[0]], dim),
+                torch.stack([b._lora_B for b in args[0]], dim),
+            )
+        elif func is torch.cat:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            if len(args[0][0].shape) > 2:
+                return LoraTorchTensor(
+                    torch.cat([a._lora_A for a in args[0]], dim),
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+            else:
+                return LoraTorchTensor(
+                    args[0][0]._lora_A, # TODO: is this correct? (can't cat over the rank)
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+        else:
+            raise NotImplementedError
+
+
 def get_base_tensor_name(lora_tensor_name: str) -> str:
     base_name = lora_tensor_name.replace("base_model.model.", "")
     base_name = base_name.replace(".lora_A.weight", ".weight")
@@ -79,20 +243,21 @@ def parse_args() -> argparse.Namespace:
     dir_base_model = args.base
     dir_lora = args.lora_path
     input_json = os.path.join(dir_lora, "adapter_config.json")
-    input_model = os.path.join(dir_lora, "adapter_model.bin")
+    input_model = os.path.join(dir_lora, "adapter_model.safetensors")
     if args.outfile is not None:
         fname_out = args.outfile
     else:
         # output in the same directory as the model by default
         fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'

     if os.path.exists(input_model):
-        lora_model = torch.load(input_model, map_location="cpu")
-    else:
-        input_model = os.path.join(dir_lora, "adapter_model.safetensors")
         # lazy import load_file only if lora is in safetensors format.
         from safetensors.torch import load_file
+
         lora_model = load_file(input_model, device="cpu")
+    else:
+        input_model = os.path.join(dir_lora, "adapter_model.bin")
+        lora_model = torch.load(input_model, map_location="cpu", weights_only=True)

     # load base model
     logger.info(f"Loading base model: {dir_base_model.name}")
@@ -104,53 +269,54 @@ def parse_args() -> argparse.Namespace:
         logger.error(f"Model {hparams['architectures'][0]} is not supported")
         sys.exit(1)

-    model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
-    logger.info("Set model parameters")
-    model_instance.set_gguf_parameters()
+    class LoraModel(model_class):
+        model_arch = model_class.model_arch

-    # adapter_config = json.load(input_json)
-    model_instance.gguf_writer.add_string("training.type", "finetune_lora")
-    if not model_instance.support_lora():
-        logger.error("LoRA conversion is not yet supported for this model")
-        sys.exit(1)
-
-    # map original name to gguf name
-    map_name: dict[str, str] = {}
-    for tensor_name, tensor in lora_model.items():
-        base_name = get_base_tensor_name(tensor_name)
-        is_lora_a = ".lora_A.weight" in tensor_name
-        is_lora_b = ".lora_B.weight" in tensor_name
-        if not is_lora_a and not is_lora_b:
-            logger.error(f"Unexpected name '{tensor_name}': Not a lora_A or lora_B tensor")
-            sys.exit(1)
-        dest_name = model_instance.map_tensor_name(base_name)
-        dest_name = f"{dest_name}.lora_a" if is_lora_a else f"{dest_name}.lora_b"
-        map_name[tensor_name] = dest_name
+        def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+            tensor_map: dict[str, PartialLoraTensor] = {}

-    # overwrite method
-    def map_tensor_name(self, name: str) -> str:
-        return map_name[name]
+            for name, tensor in lora_model.items():
+                base_name = get_base_tensor_name(name)
+                is_lora_a = ".lora_A.weight" in name
+                is_lora_b = ".lora_B.weight" in name
+                if not is_lora_a and not is_lora_b:
+                    if ".base_layer.weight" in name:
+                        continue
+                    logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+                    sys.exit(1)

-    # overwrite method
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        for name, tensor in lora_model.items():
-            yield (name, tensor)
+                if base_name in tensor_map:
+                    if is_lora_a:
+                        tensor_map[base_name].A = tensor
+                    else:
+                        tensor_map[base_name].B = tensor
+                else:
+                    if is_lora_a:
+                        tensor_map[base_name] = PartialLoraTensor(A=tensor)
+                    else:
+                        tensor_map[base_name] = PartialLoraTensor(B=tensor)

-    # overwrite method
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del name, new_name, bid, n_dims # unused
-        return ftype != gguf.LlamaFileType.ALL_F32
+            for name, tensor in tensor_map.items():
+                assert tensor.A is not None
+                assert tensor.B is not None
+                yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))

-    model_instance._map_tensor_name = model_instance.map_tensor_name # type: ignore
-    model_instance.map_tensor_name = types.MethodType(map_tensor_name, model_instance)
+        def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+            dest = super().modify_tensors(data_torch, name, bid)
+            for dest_name, dest_data in dest:
+                assert isinstance(dest_data, LoraTorchTensor)
+                # logger.info(f"{orig_name} --> {dest_name}")
+                yield (dest_name + ".lora_a", dest_data._lora_A)
+                yield (dest_name + ".lora_b", dest_data._lora_B)

-    model_instance._get_tensors = model_instance.get_tensors # type: ignore
-    model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
+    model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
+    logger.info("Set model parameters")
+    model_instance.set_gguf_parameters()

-    model_instance._extra_f16_tensors = model_instance.extra_f16_tensors # type: ignore
-    model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)
+    # adapter_config = json.load(input_json)
+    model_instance.gguf_writer.add_string("training.type", "finetune_lora")

-    model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
-    logger.info("Exporting model...")
-    model_instance.write()
-    logger.info(f"Model successfully exported to {model_instance.fname_out}")
+    model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+    logger.info("Exporting model...")
+    model_instance.write()
+    logger.info(f"Model successfully exported to {model_instance.fname_out}")
