
from __future__ import annotations

+from dataclasses import dataclass
import logging
import argparse
import os
import sys
-import types
from pathlib import Path
-from typing import TYPE_CHECKING, Iterator
+from types import EllipsisType
+from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast

import torch

logger = logging.getLogger("lora-to-gguf")


+@dataclass
+class PartialLoraTensor:
+    A: Tensor | None = None
+    B: Tensor | None = None
+
+
+# magic to support tensor shape modifications and splitting
+class LoraTorchTensor:
+    _lora_A: Tensor
+    _lora_B: Tensor
+    _rank: int
+
+    def __init__(self, A: Tensor, B: Tensor):
+        assert len(A.shape) == len(B.shape)
+        if A.dtype != B.dtype:
+            A = A.to(torch.float32)
+            B = B.to(torch.float32)
+        self._lora_A = A
+        self._lora_B = B
+        assert self._lora_A.shape[-2] == self._lora_B.shape[-1]
+        self._rank = self._lora_B.shape[-1]
+
+    def __getitem__(
+        self,
+        indices: (
+            SupportsIndex
+            | slice
+            | tuple[SupportsIndex | slice | EllipsisType | Tensor, ...]
+        ),
+    ) -> LoraTorchTensor:
+        shape = self.shape
+        if isinstance(indices, (SupportsIndex, slice)):
+            if len(shape) > 2:
+                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
+            else:
+                raise NotImplementedError
+        elif isinstance(indices, tuple):
+            assert len(indices) > 0
+            if isinstance(indices[-1], EllipsisType):
+                return self[indices[:-1]]
+            # expand ellipsis
+            indices = tuple(
+                u
+                for v in (
+                    (
+                        (slice(None, None) for _ in range(len(indices) - 1))
+                        if isinstance(i, EllipsisType)
+                        else (i,)
+                    )
+                    for i in indices
+                )
+                for u in v
+            )
+
+            if len(indices) < len(shape):
+                indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
+
+            # TODO: make sure this is correct
+            # lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1])
+            indices_A = (
+                *(
+                    0 if isinstance(i, SupportsIndex) else slice(None, None)
+                    for i in indices[:-2]
+                ),
+                slice(None, None),
+                indices[-1],
+            )
+            indices_B = indices[:-1]
+            return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
+        else:
+            raise NotImplementedError
+
+    @property
+    def dtype(self) -> torch.dtype:
+        assert self._lora_A.dtype == self._lora_B.dtype
+        return self._lora_A.dtype
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])
+
+    def size(self, dim=None):
+        assert dim is None
+        return self.shape
+
+    def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor:
+        if isinstance(shape[0], tuple):
+            new_shape: tuple[int] = shape[0]
+        else:
+            new_shape = cast(tuple[int], shape)
+        orig_shape = self.shape
+        if new_shape[-1] != orig_shape[-1]:
+            raise NotImplementedError
+        return LoraTorchTensor(
+            self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])),
+            self._lora_B.reshape((*new_shape[:-1], self._rank)),
+        )
+
+    def reshape_as(self, other: Tensor) -> LoraTorchTensor:
+        return self.reshape(*other.shape)
+
+    def view(self, *size: int) -> LoraTorchTensor:
+        return self.reshape(*size)
+
+    def permute(self, *dims: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
+        if dims[-1] == -2 and dims[-2] == -1:
+            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
+        else:
+            assert dims[-1] == -1
+            assert all(dim == 1 for dim in self._lora_A.shape[:-2])
+            return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
+
+    def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
+        shape = self.shape
+        dims = [i for i in range(len(shape))]
+        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
+        return self.permute(*dims)
+
+    def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
+        return self.transpose(axis0, axis1)
+
+    def to(self, *args, **kwargs):
+        return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
+
+    @classmethod
+    def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
+        del types  # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.permute:
+            return type(args[0]).permute(*args, **kwargs)
+        elif func is torch.reshape:
+            return type(args[0]).reshape(*args, **kwargs)
+        elif func is torch.stack:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            return LoraTorchTensor(
+                torch.stack([a._lora_A for a in args[0]], dim),
+                torch.stack([b._lora_B for b in args[0]], dim),
+            )
+        elif func is torch.cat:
+            assert isinstance(args[0], Sequence)
+            dim = kwargs.get("dim", 0)
+            assert dim == 0
+            if len(args[0][0].shape) > 2:
+                return LoraTorchTensor(
+                    torch.cat([a._lora_A for a in args[0]], dim),
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+            else:
+                return LoraTorchTensor(
+                    args[0][0]._lora_A,  # TODO: is this correct? (can't cat over the rank)
+                    torch.cat([b._lora_B for b in args[0]], dim),
+                )
+        else:
+            raise NotImplementedError
+
+
def get_base_tensor_name(lora_tensor_name: str) -> str:
    base_name = lora_tensor_name.replace("base_model.model.", "")
    base_name = base_name.replace(".lora_A.weight", ".weight")
@@ -79,20 +243,21 @@ def parse_args() -> argparse.Namespace:
    dir_base_model = args.base
    dir_lora = args.lora_path
    input_json = os.path.join(dir_lora, "adapter_config.json")
-    input_model = os.path.join(dir_lora, "adapter_model.bin")
+    input_model = os.path.join(dir_lora, "adapter_model.safetensors")
    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_lora / 'ggml-lora-{ftype}.gguf'

    if os.path.exists(input_model):
-        lora_model = torch.load(input_model, map_location="cpu")
-    else:
-        input_model = os.path.join(dir_lora, "adapter_model.safetensors")
        # lazy import load_file only if lora is in safetensors format.
        from safetensors.torch import load_file
+
        lora_model = load_file(input_model, device="cpu")
+    else:
+        input_model = os.path.join(dir_lora, "adapter_model.bin")
+        lora_model = torch.load(input_model, map_location="cpu", weights_only=True)

    # load base model
    logger.info(f"Loading base model: {dir_base_model.name}")
@@ -104,53 +269,54 @@ def parse_args() -> argparse.Namespace:
        logger.error(f"Model {hparams['architectures'][0]} is not supported")
        sys.exit(1)

-    model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
-    logger.info("Set model parameters")
-    model_instance.set_gguf_parameters()
+    class LoraModel(model_class):
+        model_arch = model_class.model_arch

-    # adapter_config = json.load(input_json)
-    model_instance.gguf_writer.add_string("training.type", "finetune_lora")
-    if not model_instance.support_lora():
-        logger.error("LoRA conversion is not yet supported for this model")
-        sys.exit(1)
-
-    # map original name to gguf name
-    map_name: dict[str, str] = {}
-    for tensor_name, tensor in lora_model.items():
-        base_name = get_base_tensor_name(tensor_name)
-        is_lora_a = ".lora_A.weight" in tensor_name
-        is_lora_b = ".lora_B.weight" in tensor_name
-        if not is_lora_a and not is_lora_b:
-            logger.error(f"Unexpected name '{tensor_name}': Not a lora_A or lora_B tensor")
-            sys.exit(1)
-        dest_name = model_instance.map_tensor_name(base_name)
-        dest_name = f"{dest_name}.lora_a" if is_lora_a else f"{dest_name}.lora_b"
-        map_name[tensor_name] = dest_name
+        def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+            tensor_map: dict[str, PartialLoraTensor] = {}

-    # overwrite method
-    def map_tensor_name(self, name: str) -> str:
-        return map_name[name]
+            for name, tensor in lora_model.items():
+                base_name = get_base_tensor_name(name)
+                is_lora_a = ".lora_A.weight" in name
+                is_lora_b = ".lora_B.weight" in name
+                if not is_lora_a and not is_lora_b:
+                    if ".base_layer.weight" in name:
+                        continue
+                    logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+                    sys.exit(1)

-    # overwrite method
-    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
-        for name, tensor in lora_model.items():
-            yield (name, tensor)
+                if base_name in tensor_map:
+                    if is_lora_a:
+                        tensor_map[base_name].A = tensor
+                    else:
+                        tensor_map[base_name].B = tensor
+                else:
+                    if is_lora_a:
+                        tensor_map[base_name] = PartialLoraTensor(A=tensor)
+                    else:
+                        tensor_map[base_name] = PartialLoraTensor(B=tensor)

-    # overwrite method
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del name, new_name, bid, n_dims  # unused
-        return ftype != gguf.LlamaFileType.ALL_F32
+            for name, tensor in tensor_map.items():
+                assert tensor.A is not None
+                assert tensor.B is not None
+                yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))

-    model_instance._map_tensor_name = model_instance.map_tensor_name  # type: ignore
-    model_instance.map_tensor_name = types.MethodType(map_tensor_name, model_instance)
+        def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+            dest = super().modify_tensors(data_torch, name, bid)
+            for dest_name, dest_data in dest:
+                assert isinstance(dest_data, LoraTorchTensor)
+                # logger.info(f"{orig_name} --> {dest_name}")
+                yield (dest_name + ".lora_a", dest_data._lora_A)
+                yield (dest_name + ".lora_b", dest_data._lora_B)

-    model_instance._get_tensors = model_instance.get_tensors  # type: ignore
-    model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
+    model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
+    logger.info("Set model parameters")
+    model_instance.set_gguf_parameters()

-    model_instance._extra_f16_tensors = model_instance.extra_f16_tensors  # type: ignore
-    model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)
+    # adapter_config = json.load(input_json)
+    model_instance.gguf_writer.add_string("training.type", "finetune_lora")

-    model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
-    logger.info("Exporting model...")
-    model_instance.write()
-    logger.info(f"Model successfully exported to {model_instance.fname_out}")
+    model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+    logger.info("Exporting model...")
+    model_instance.write()
+    logger.info(f"Model successfully exported to {model_instance.fname_out}")
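
Note (not part of the patch): a minimal usage sketch of the LoraTorchTensor wrapper introduced above, assuming the class is in scope and using made-up dimensions. In the usual PEFT convention the adapter delta is lora_B @ lora_A, with lora_A shaped (rank, n_in) and lora_B shaped (n_out, rank), so the wrapper reports the combined shape (n_out, n_in) without materializing the product; scaling factors such as lora_alpha are ignored here.

# Illustrative only: hypothetical sizes, LoraTorchTensor assumed importable from the script above.
import torch

n_out, n_in, rank = 8, 16, 4
lora_A = torch.randn(rank, n_in)    # (rank, n_in)
lora_B = torch.randn(n_out, rank)   # (n_out, rank)

t = LoraTorchTensor(lora_A, lora_B)
assert t.shape == (n_out, n_in)     # combined shape; B @ A is never computed

# transpose swaps the roles of A and B, mirroring (B @ A).T == A.T @ B.T
assert t.transpose(0, 1).shape == (n_in, n_out)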