Commit 916e959

llm_build_lora_mm_id

1 parent e68344c commit 916e959


src/llama.cpp

Lines changed: 28 additions & 4 deletions
@@ -7882,7 +7882,6 @@ static struct ggml_tensor * llm_build_lora_mm(
         if (lora == nullptr) {
             continue;
         }
-        // TODO: check if lora_a need transpose
         struct ggml_tensor * ab_cur = ggml_mul_mat(
             ctx0, lora->b,
             ggml_mul_mat(ctx0, lora->a, cur)
@@ -7893,6 +7892,31 @@ static struct ggml_tensor * llm_build_lora_mm(
     return res;
 }
 
+// do mat_mul_id, while optionally apply lora
+static struct ggml_tensor * llm_build_lora_mm_id(
+        struct llama_context & lctx,
+         struct ggml_context * ctx0,
+          struct ggml_tensor * w,   // struct ggml_tensor * as
+          struct ggml_tensor * cur, // struct ggml_tensor * b
+          struct ggml_tensor * ids) {
+    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        float scale = it.second;
+        if (lora == nullptr) {
+            continue;
+        }
+        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
+            ctx0, lora->b,
+            ggml_mul_mat_id(ctx0, lora->a, cur, ids),
+            ids
+        );
+        ab_cur = ggml_scale_inplace(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+    return res;
+}
+
 static struct ggml_tensor * llm_build_norm(
         struct ggml_context * ctx,
          struct ggml_tensor * cur,
@@ -8103,10 +8127,10 @@ static struct ggml_tensor * llm_build_moe_ffn(
     }
 
     cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-    ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(gate, "ffn_moe_gate", il);
 
     switch (type_op) {
@@ -8127,7 +8151,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
     ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
     cb(par, "ffn_moe_gate_par", il);
 
-    ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     experts = ggml_mul(ctx, experts, weights);
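For context: the new helper applies the usual LoRA correction on top of the expert matmul, i.e. res = W*cur + scale*(B*(A*cur)), where A and B are the adapter's low-rank factors (lora->a, lora->b) and the weight W is picked per expert through ids. The sketch below shows that arithmetic for a single dense weight using plain float arrays; the sizes, names, and main() harness are illustrative only and not part of the commit (per-expert selection via ids is omitted in this single-matrix version):

    #include <cstdio>
    #include <vector>

    // Hypothetical illustration of the update computed by llm_build_lora_mm_id:
    //   res = W * x + scale * (B * (A * x))
    // W is [n_out x n_in], A is [r x n_in], B is [n_out x r]. In llama.cpp the
    // per-expert selection via `ids` is handled by ggml_mul_mat_id; it is
    // omitted here to keep the arithmetic visible.

    static std::vector<float> matvec(const std::vector<float> & m,
                                     const std::vector<float> & x,
                                     int rows, int cols) {
        std::vector<float> y(rows, 0.0f);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                y[i] += m[i*cols + j] * x[j];
            }
        }
        return y;
    }

    int main() {
        const int n_in = 4, n_out = 3, r = 2;   // r = LoRA rank (illustrative sizes)
        std::vector<float> W(n_out*n_in, 0.1f); // base weight
        std::vector<float> A(r*n_in,     0.2f); // lora->a
        std::vector<float> B(n_out*r,    0.3f); // lora->b
        std::vector<float> x = {1, 2, 3, 4};    // cur
        const float scale = 0.5f;               // per-adapter scale (it.second)

        std::vector<float> res = matvec(W, x, n_out, n_in); // ggml_mul_mat_id(w, cur, ids)
        std::vector<float> ax  = matvec(A, x, r, n_in);     // ggml_mul_mat_id(lora->a, cur, ids)
        std::vector<float> ab  = matvec(B, ax, n_out, r);   // ggml_mul_mat_id(lora->b, ..., ids)
        for (int i = 0; i < n_out; ++i) {
            res[i] += scale * ab[i];                        // ggml_scale_inplace + ggml_add
        }

        for (float v : res) printf("%.3f ", v);
        printf("\n");
        return 0;
    }

In the actual graph code every one of those matmuls is a ggml_mul_mat_id call, so the same update is evaluated for each expert chosen by selected_experts in llm_build_moe_ffn.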

0 commit comments