Commit 3eb7b62

ngxson authored and mglambda committed
lora : improve compat with mergekit-extract-lora (ggml-org#11131)
* (wip) support mergekit-extracted lora
* support mergekit-extract-lora
* use lora->get_scale
* correct comment
* correct norm name & condition
* add some hints
1 parent b64bb18 commit 3eb7b62


4 files changed, +74 -12 lines changed


convert_lora_to_gguf.py

Lines changed: 31 additions & 3 deletions
@@ -226,6 +226,9 @@ def get_base_tensor_name(lora_tensor_name: str) -> str:
     base_name = lora_tensor_name.replace("base_model.model.", "")
     base_name = base_name.replace(".lora_A.weight", ".weight")
     base_name = base_name.replace(".lora_B.weight", ".weight")
+    # models produced by mergekit-extract-lora have token embeddings in the adapter
+    base_name = base_name.replace(".lora_embedding_A", ".weight")
+    base_name = base_name.replace(".lora_embedding_B", ".weight")
     return base_name

@@ -260,6 +263,10 @@ def parse_args() -> argparse.Namespace:
         "--base", type=Path,
         help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
     )
+    parser.add_argument(
+        "--base-model-id", type=str,
+        help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')",
+    )
     parser.add_argument(
         "lora_path", type=Path,
         help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)",
@@ -290,6 +297,7 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
 
     dir_base_model: Path | None = args.base
     dir_lora: Path = args.lora_path
+    base_model_id: str | None = args.base_model_id
     lora_config = dir_lora / "adapter_config.json"
     input_model = dir_lora / "adapter_model.safetensors"
 
@@ -313,7 +321,10 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
         lparams: dict[str, Any] = json.load(f)
 
     # load base model
-    if dir_base_model is None:
+    if base_model_id is not None:
+        logger.info(f"Loading base model from Hugging Face: {base_model_id}")
+        hparams = load_hparams_from_hf(base_model_id)
+    elif dir_base_model is None:
         if "base_model_name_or_path" in lparams:
             model_id = lparams["base_model_name_or_path"]
             logger.info(f"Loading base model from Hugging Face: {model_id}")
@@ -371,11 +382,16 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
             if self.lazy:
                 tensor = LazyTorchTensor.from_eager(tensor)
             base_name = get_base_tensor_name(name)
-            is_lora_a = ".lora_A.weight" in name
-            is_lora_b = ".lora_B.weight" in name
+            # note: mergekit-extract-lora also adds token embeddings to the adapter
+            is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
+            is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
             if not is_lora_a and not is_lora_b:
                 if ".base_layer.weight" in name:
                     continue
+                # mergekit-extract-lora add these layernorm to the adapter, we need to keep them
+                if "_layernorm" in name or ".norm" in name:
+                    yield (base_name, tensor)
+                    continue
                 logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
                 if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
                     logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
@@ -407,9 +423,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             if name == "lm_head.weight" and len(dest) == 0:
                 raise ValueError("lm_head is present in adapter, but is ignored in base model")
             for dest_name, dest_data in dest:
+                # mergekit-extract-lora add these layernorm to the adapter
+                if "_norm" in dest_name:
+                    assert dest_data.dim() == 1
+                    yield (dest_name, dest_data)
+                    continue
+
+                # otherwise, we must get the lora_A and lora_B tensors
                 assert isinstance(dest_data, LoraTorchTensor)
                 lora_a, lora_b = dest_data.get_lora_A_B()
 
+                # note: mergekit-extract-lora flip and transpose A and B
+                # here we only need to transpose token_embd.lora_a, see llm_build_inp_embd()
+                if "token_embd.weight" in dest_name:
+                    lora_a = lora_a.T
+
                 yield (dest_name + ".lora_a", lora_a)
                 yield (dest_name + ".lora_b", lora_b)
 
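For a concrete picture of the renaming this adds, here is a small sketch (not part of the commit; the adapter tensor name below is hypothetical) of how get_base_tensor_name() maps a mergekit-extract-lora embedding tensor back to a base tensor name. The _layernorm/.norm vectors that get_tensors() now passes through keep that base name unchanged, while token_embd gets its lora_a written transposed, as in modify_tensors() above.

```python
# Hypothetical adapter tensor name, for illustration only.
name = "base_model.model.model.embed_tokens.lora_embedding_A"

base_name = (name.replace("base_model.model.", "")
                 .replace(".lora_embedding_A", ".weight"))
print(base_name)  # -> "model.embed_tokens.weight"
```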

src/llama-adapter.cpp

Lines changed: 18 additions & 6 deletions
@@ -242,6 +242,10 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
             } else {
                 ab_map[name].b = cur;
             }
+        } else if (str_endswith(name, "_norm.weight")) {
+            // TODO: add support for norm vector
+            // for now, we don't really care because most adapters still work fine without it
+            continue;
         } else {
             throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
         }
@@ -251,6 +255,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
     for (auto & it : ab_map) {
         const std::string & name = it.first;
         llama_lora_weight & w = it.second;
+        bool is_token_embd = str_endswith(name, "token_embd.weight");
 
         if (!w.a || !w.b) {
             throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
@@ -259,16 +264,23 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
         // device buft and device ctx
         auto * model_tensor = llama_model_get_tensor(model, name.c_str());
         if (!model_tensor) {
-            throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
+            throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
         }
 
         struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
         // validate tensor shape
-        if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
-            throw std::runtime_error("tensor '" + name + "' has incorrect shape");
-        }
-        if (w.a->ne[1] != w.b->ne[0]) {
-            throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
+        if (is_token_embd) {
+            // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
+            if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) {
+                throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)");
+            }
+        } else {
+            if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
+                throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)");
+            }
+            if (w.a->ne[1] != w.b->ne[0]) {
+                throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
+            }
         }
 
         // save tensor to adapter
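The two validation branches above expect different layouts. A rough Python mirror of the checks (assuming ggml's ne[] order, where ne[0] is the innermost dimension; all sizes are made up for illustration):

```python
# Sketch of the shape checks, not code from the commit. ne[] tuples are (ne[0], ne[1]).
rank, n_embd, n_vocab, n_ff = 8, 4096, 32000, 11008

def check_regular(model_ne, a_ne, b_ne):
    # non-embedding branch: W is [n_in, n_out], A is [n_in, rank], B is [rank, n_out]
    assert model_ne[0] == a_ne[0] and model_ne[1] == b_ne[1]
    assert a_ne[1] == b_ne[0]  # otherwise "lora_a tensor is not transposed"

def check_token_embd(model_ne, a_ne, b_ne):
    # is_token_embd branch: A and B are flipped, B is non-transposed
    assert model_ne[0] == b_ne[1] and model_ne[1] == a_ne[1]

check_regular((n_embd, n_ff), (n_embd, rank), (rank, n_ff))           # e.g. ffn_up.weight
check_token_embd((n_embd, n_vocab), (rank, n_vocab), (rank, n_embd))  # token_embd.weight
```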

src/llama-adapter.h

Lines changed: 7 additions & 0 deletions
@@ -45,6 +45,13 @@ struct llama_lora_weight {
     struct ggml_tensor * a = nullptr;
     struct ggml_tensor * b = nullptr;
 
+    // get actual scale based on rank and alpha
+    float get_scale(float alpha, float adapter_scale) {
+        const float rank = (float) b->ne[0];
+        const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale;
+        return scale;
+    }
+
     llama_lora_weight() = default;
     llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
 };
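A quick numeric sketch of what get_scale() evaluates to (the numbers are made up): with alpha = 16, rank = b->ne[0] = 8 and an adapter scale of 1.0, the effective scale is 1.0 * 16 / 8 = 2.0; when no alpha is stored (alpha == 0), the adapter scale is used unchanged.

```python
# Python mirror of llama_lora_weight::get_scale(); values are illustrative.
def get_scale(alpha: float, adapter_scale: float, rank: int) -> float:
    return adapter_scale * alpha / rank if alpha else adapter_scale

print(get_scale(16.0, 1.0, 8))  # 2.0
print(get_scale(0.0, 0.5, 8))   # 0.5 (no alpha: fall back to the adapter scale)
```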

src/llama.cpp

Lines changed: 18 additions & 3 deletions
@@ -2545,6 +2545,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
+
+        // apply lora for embedding tokens if needed
+        for (auto & it : lctx.lora_adapters) {
+            struct llama_lora_weight * lora = it.first->get_weight(tok_embd);
+            if (lora == nullptr) {
+                continue;
+            }
+            const float adapter_scale = it.second;
+            const float scale = lora->get_scale(it.first->alpha, adapter_scale);
+            struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
+                ctx, lora->b, // non-transposed lora_b
+                ggml_get_rows(ctx, lora->a, lctx.inp_tokens)
+            ), scale);
+            inpL = ggml_add(ctx, inpL, inpL_delta);
+        }
     } else {
         lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
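In row-major (NumPy) terms, the loop added above computes roughly the following per-token delta; this is an illustrative sketch under assumed sizes, not code from the commit:

```python
import numpy as np

# All sizes and values are hypothetical.
n_vocab, n_embd, rank = 100, 16, 4
rng = np.random.default_rng(0)

tok_embd = rng.standard_normal((n_vocab, n_embd))  # base embedding table
lora_a   = rng.standard_normal((n_vocab, rank))    # token_embd lora_a (stored transposed by the converter)
lora_b   = rng.standard_normal((n_embd, rank))     # token_embd lora_b (non-transposed)
scale    = 2.0                                     # lora->get_scale(alpha, adapter_scale)

tokens = np.array([3, 42, 7])

# inpL = get_rows(tok_embd, tokens) + scale * mul_mat(lora_b, get_rows(lora_a, tokens))
inpL = tok_embd[tokens] + scale * (lora_a[tokens] @ lora_b.T)
print(inpL.shape)  # (3, 16) == (n_tokens, n_embd)
```
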
@@ -2617,9 +2632,8 @@ static struct ggml_tensor * llm_build_lora_mm(
         if (lora == nullptr) {
             continue;
         }
-        const float alpha = it.first->alpha;
-        const float rank = (float) lora->b->ne[0];
-        const float scale = alpha ? it.second * alpha / rank : it.second;
+        const float adapter_scale = it.second;
+        const float scale = lora->get_scale(it.first->alpha, adapter_scale);
         struct ggml_tensor * ab_cur = ggml_mul_mat(
             ctx0, lora->b,
             ggml_mul_mat(ctx0, lora->a, cur)
@@ -3967,6 +3981,7 @@ struct llm_build_context {
 
         // feed-forward network
         if (model.layers[il].ffn_gate_inp == nullptr) {
+
             cur = llm_build_norm(ctx0, ffn_inp, hparams,
                     model.layers[il].ffn_norm, NULL,
                     LLM_NORM_RMS, cb, il);
