@@ -1057,8 +1057,10 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
+    bool has_gating = layer->time_mix_g1 && layer->time_mix_g2;
+
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
-    struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, 6);
+    struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, layer->time_mix_lerp_fused->ne[2]);
     sx = ggml_repeat(ctx, sx, dummy);
 
     struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_fused), cur);
@@ -1068,7 +1070,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
     struct ggml_tensor * xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
     struct ggml_tensor * xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
     struct ggml_tensor * xa = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
-    struct ggml_tensor * xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float));
+    struct ggml_tensor * xg = has_gating ? ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 5 * sizeof(float)) : nullptr;
 
     struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
     // Assume that there won't be lora adapters on these "lora" matmuls?
@@ -1142,7 +1144,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
             ggml_mul(ctx, ggml_mul(ctx, k, r), ggml_reshape_2d(ctx, layer->time_mix_r_k, head_size, head_count)));
     cur = ggml_add(ctx, cur, ggml_reshape_2d(ctx, ggml_mul(ctx, v, rk), n_embd, n_tokens));
 
-    if (g) {
+    if (has_gating) {
         cur = ggml_mul(ctx, cur, g);
     }
     cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);