Commit b936d64

rwkv: better handling for models without gate
Signed-off-by: Molly Sophia <[email protected]>
1 parent c137349 commit b936d64
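
In short, as the diff below shows: a has_gating flag is derived from whether the time_mix_g1 and time_mix_g2 tensors were loaded, the broadcast dummy tensor is sized from time_mix_lerp_fused->ne[2] instead of a hard-coded 6, and the xg view and the final gating multiply are only created when gating is present.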

File tree: 1 file changed (+5 lines, -3 lines)

src/llama.cpp

Lines changed: 5 additions & 3 deletions
@@ -1057,8 +1057,10 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
+    bool has_gating = layer->time_mix_g1 && layer->time_mix_g2;
+
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
-    struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, 6);
+    struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, layer->time_mix_lerp_fused->ne[2]);
     sx = ggml_repeat(ctx, sx, dummy);
 
     struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_fused), cur);
@@ -1068,7 +1070,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
     struct ggml_tensor * xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
     struct ggml_tensor * xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
     struct ggml_tensor * xa = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
-    struct ggml_tensor * xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float));
+    struct ggml_tensor * xg = has_gating ? ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) : nullptr;
 
     struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
     // Assume that there won't be lora adapters on these "lora" matmuls?
@@ -1142,7 +1144,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
         ggml_mul(ctx, ggml_mul(ctx, k, r), ggml_reshape_2d(ctx, layer->time_mix_r_k, head_size, head_count)));
     cur = ggml_add(ctx, cur, ggml_reshape_2d(ctx, ggml_mul(ctx, v, rk), n_embd, n_tokens));
 
-    if (g) {
+    if (has_gating) {
         cur = ggml_mul(ctx, cur, g);
     }
     cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
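
For readers skimming the diff, here is a minimal self-contained sketch of the pattern the commit applies, written against hypothetical stand-in types rather than the real ggml API: treat the gate as optional, derive a flag from the presence of its weight tensors, and take the broadcast channel count from the tensor that was actually loaded.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for ggml_tensor and the layer weights; only the
// fields needed to illustrate the pattern are modeled here.
struct Tensor {
    int64_t ne[4]; // per-dimension sizes, as in ggml_tensor::ne
};

struct Layer {
    Tensor * time_mix_g1;         // gate projection; null for gate-less models
    Tensor * time_mix_g2;         // gate projection; null for gate-less models
    Tensor * time_mix_lerp_fused; // fused lerp coefficients; ne[2] is 5 or 6
};

int main() {
    // A gate-less model: the fused lerp tensor has 5 channels, not 6.
    Tensor lerp_fused = {{ 768, 1, 5, 1 }};
    Layer layer = { nullptr, nullptr, &lerp_fused };

    // Presence of both gate tensors decides whether gating is applied at all.
    bool has_gating = layer.time_mix_g1 && layer.time_mix_g2;

    // The broadcast count is read from the loaded tensor; hard-coding 6 here
    // would presumably mismatch the 5-channel fused lerp tensor of a
    // gate-less model.
    int64_t n_fused = layer.time_mix_lerp_fused->ne[2];

    std::printf("has_gating=%d n_fused=%lld\n", has_gating, (long long) n_fused);
    return 0;
}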
