@@ -9571,7 +9571,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
9571
9571
struct ggml_tensor * cur,
9572
9572
struct ggml_tensor * x_prev,
9573
9573
struct ggml_tensor ** wkv_state) {
9574
- size_t n_embed = cur->ne[0];
9574
+ size_t n_embd = cur->ne[0];
9575
9575
size_t n_seq_tokens = cur->ne[1];
9576
9576
size_t n_seqs = cur->ne[2];
9577
9577
@@ -9582,8 +9582,8 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
9582
9582
9583
9583
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
9584
9584
9585
- sx = ggml_reshape_2d(ctx, sx, n_embed , n_tokens);
9586
- cur = ggml_reshape_2d(ctx, cur, n_embed , n_tokens);
9585
+ sx = ggml_reshape_2d(ctx, sx, n_embd , n_tokens);
9586
+ cur = ggml_reshape_2d(ctx, cur, n_embd , n_tokens);
9587
9587
9588
9588
struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
9589
9589
@@ -9608,11 +9608,11 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
9608
9608
xxx
9609
9609
);
9610
9610
9611
- struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embed , n_tokens, xxx->nb[1], 0);
9612
- struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embed , n_tokens, xxx->nb[1], n_embed * n_tokens * sizeof(float));
9613
- struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embed , n_tokens, xxx->nb[1], n_embed * n_tokens * 2 * sizeof(float));
9614
- struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embed , n_tokens, xxx->nb[1], n_embed * n_tokens * 3 * sizeof(float));
9615
- struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embed , n_tokens, xxx->nb[1], n_embed * n_tokens * 4 * sizeof(float));
9611
+ struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd , n_tokens, xxx->nb[1], 0);
9612
+ struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd , n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
9613
+ struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd , n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
9614
+ struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd , n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
9615
+ struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd , n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
9616
9616
9617
9617
struct ggml_tensor * xw = ggml_add(
9618
9618
ctx,
@@ -9681,7 +9681,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
9681
9681
)
9682
9682
);
9683
9683
9684
- w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed ));
9684
+ w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd ));
9685
9685
w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
9686
9686
w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
9687
9687
@@ -9690,21 +9690,21 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
9690
9690
r = ggml_transpose(ctx, r);
9691
9691
9692
9692
struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
9693
- cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
9694
- *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));
9693
+ cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
9694
+ *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
9695
9695
9696
9696
// group norm with head_count groups
9697
- cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
9697
+ cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
9698
9698
cur = ggml_norm(ctx, cur, 64e-5f);
9699
9699
9700
9700
// Convert back to regular vectors.
9701
- cur = ggml_reshape_2d(ctx, cur, n_embed , n_tokens);
9701
+ cur = ggml_reshape_2d(ctx, cur, n_embd , n_tokens);
9702
9702
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
9703
9703
9704
9704
cur = ggml_mul(ctx, cur, g);
9705
9705
cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
9706
9706
9707
- return ggml_reshape_3d(ctx, cur, n_embed , n_seq_tokens, n_seqs);
9707
+ return ggml_reshape_3d(ctx, cur, n_embd , n_seq_tokens, n_seqs);
9708
9708
}
9709
9709
9710
9710
static struct ggml_tensor * llm_build_rwkv6_channel_mix(
0 commit comments