Skip to content

Commit 35405cd

Browse files
committed
rwkv6: rename to wkv6
1 parent c02e5ab commit 35405cd

File tree

7 files changed: +38 −36 lines changed

ggml/include/ggml.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ extern "C" {
510510
GGML_OP_WIN_UNPART,
511511
GGML_OP_GET_REL_POS,
512512
GGML_OP_ADD_REL_POS,
513-
GGML_OP_RWKV_WKV,
513+
GGML_OP_RWKV_WKV6,
514514

515515
GGML_OP_UNARY,
516516

@@ -1887,7 +1887,7 @@ extern "C" {
18871887
struct ggml_tensor * pw,
18881888
struct ggml_tensor * ph);
18891889

1890-
GGML_API struct ggml_tensor * ggml_rwkv_wkv(
1890+
GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
18911891
struct ggml_context * ctx,
18921892
struct ggml_tensor * k,
18931893
struct ggml_tensor * v,

ggml/src/ggml-cuda.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2313,8 +2313,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23132313
case GGML_OP_CROSS_ENTROPY_LOSS:
23142314
ggml_cuda_cross_entropy_loss(ctx, dst);
23152315
break;
2316-
case GGML_OP_RWKV_WKV:
2317-
ggml_cuda_op_rwkv_wkv(ctx, dst);
2316+
case GGML_OP_RWKV_WKV6:
2317+
ggml_cuda_op_rwkv_wkv6(ctx, dst);
23182318
break;
23192319
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
23202320
ggml_cuda_cross_entropy_loss_back(ctx, dst);
@@ -3147,7 +3147,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31473147
case GGML_OP_ARANGE:
31483148
case GGML_OP_TIMESTEP_EMBEDDING:
31493149
case GGML_OP_LEAKY_RELU:
3150-
case GGML_OP_RWKV_WKV:
3150+
case GGML_OP_RWKV_WKV6:
31513151
return true;
31523152
case GGML_OP_FLASH_ATTN_EXT: {
31533153
#ifndef FLASH_ATTN_AVAILABLE

ggml/src/ggml-cuda/rwkv-wkv.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const
6464
}
6565
}
6666

67-
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
67+
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
6868
const float * k_d = (const float *)dst->src[0]->data;
6969
const float * v_d = (const float *)dst->src[1]->data;
7070
const float * r_d = (const float *)dst->src[2]->data;

ggml/src/ggml-cuda/rwkv-wkv.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
#define CUDA_WKV_BLOCK_SIZE 64
44

5-
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
5+
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml.c

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3172,7 +3172,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
31723172
"win_unpart(x)",
31733173
"get_rel_pos(x)",
31743174
"add_rel_pos(x)",
3175-
"rwkv_wkv(k, v, r, tf, td, s)",
3175+
"rwkv_wkv6(k, v, r, tf, td, s)",
31763176

31773177
"unary(x)",
31783178

@@ -7452,9 +7452,9 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
74527452
return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
74537453
}
74547454

7455-
// ggml_rwkv_wkv
7455+
// ggml_rwkv_wkv6
74567456

7457-
struct ggml_tensor * ggml_rwkv_wkv(
7457+
struct ggml_tensor * ggml_rwkv_wkv6(
74587458
struct ggml_context * ctx,
74597459
struct ggml_tensor * k,
74607460
struct ggml_tensor * v,
@@ -7486,7 +7486,7 @@ struct ggml_tensor * ggml_rwkv_wkv(
74867486
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
74877487
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
74887488

7489-
result->op = GGML_OP_RWKV_WKV;
7489+
result->op = GGML_OP_RWKV_WKV6;
74907490
result->src[0] = k;
74917491
result->src[1] = v;
74927492
result->src[2] = r;
@@ -16695,15 +16695,16 @@ static void ggml_compute_forward_add_rel_pos(
1669516695
}
1669616696
}
1669716697

16698-
// ggml_compute_forward_rwkv_wkv
16698+
// ggml_compute_forward_rwkv_wkv6
1669916699

16700-
static void ggml_compute_forward_rwkv_wkv_f32(
16700+
static void ggml_compute_forward_rwkv_wkv6_f32(
1670116701
const struct ggml_compute_params * params,
1670216702
struct ggml_tensor * dst) {
1670316703
const size_t T = dst->src[1]->ne[3];
1670416704
const size_t C = dst->ne[0];
1670516705
const size_t H = dst->src[1]->ne[2];
1670616706
const size_t n_seqs = dst->src[5]->ne[1];
16707+
const size_t head_size = C / H;
1670716708

1670816709
float * dst_data = (float *) dst->data;
1670916710
float * state = ((float *) dst->data) + C * T;
@@ -16720,18 +16721,18 @@ static void ggml_compute_forward_rwkv_wkv_f32(
1672016721
float * time_faaaa = (float *) dst->src[3]->data;
1672116722
float * time_decay = (float *) dst->src[4]->data;
1672216723

16723-
size_t t_stride = H * (C / H);
16724+
size_t t_stride = H * head_size;
1672416725

1672516726
size_t h_stride = C / H;
16726-
size_t h_stride_2d = (C / H) * (C / H);
16727+
size_t h_stride_2d = head_size * head_size;
1672716728

1672816729
// basically fused operations:
1672916730
// dst = r @ (time_faaaa * (k @ v) + state),
1673016731
// state = time_decay * state + (k @ v),
1673116732
// recursive through each token
1673216733
for (size_t t = 0; t < T; t++) {
1673316734
size_t t_offset = t * t_stride;
16734-
size_t state_offset = (C / H) * C * (t / (T / n_seqs));
16735+
size_t state_offset = head_size * C * (t / (T / n_seqs));
1673516736
float * state_cur = state + state_offset;
1673616737
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
1673716738

@@ -16740,7 +16741,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
1674016741
size_t t_h_offset = t_offset + h_offset;
1674116742
size_t h_2d_offset = h * h_stride_2d;
1674216743

16743-
for (size_t i = 0; i < C / H; i++) {
16744+
for (size_t i = 0; i < head_size; i++) {
1674416745
size_t t_h_i_offset = t_h_offset + i;
1674516746
size_t h_i_offset = h_offset + i;
1674616747
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
@@ -16751,7 +16752,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
1675116752
// RWKV v6: different time_decay for each token.
1675216753
float time_decay_val = time_decay[t_h_i_offset];
1675316754

16754-
for (size_t j = 0; j < C / H; j ++) {
16755+
for (size_t j = 0; j < head_size; j ++) {
1675516756
size_t t_h_j_offset = t_h_offset + j;
1675616757
size_t h_2d_i_j_offset = h_2d_i_offset + j;
1675716758

@@ -16767,7 +16768,8 @@ static void ggml_compute_forward_rwkv_wkv_f32(
1676716768
}
1676816769
}
1676916770

16770-
static void ggml_compute_forward_rwkv_wkv(
16771+
16772+
static void ggml_compute_forward_rwkv_wkv6(
1677116773
const struct ggml_compute_params * params,
1677216774
struct ggml_tensor * dst) {
1677316775

@@ -16776,7 +16778,7 @@ static void ggml_compute_forward_rwkv_wkv(
1677616778
switch (src0->type) {
1677716779
case GGML_TYPE_F32:
1677816780
{
16779-
ggml_compute_forward_rwkv_wkv_f32(params, dst);
16781+
ggml_compute_forward_rwkv_wkv6_f32(params, dst);
1678016782
} break;
1678116783
default:
1678216784
{
@@ -17528,9 +17530,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1752817530
{
1752917531
ggml_compute_forward_add_rel_pos(params, tensor);
1753017532
} break;
17531-
case GGML_OP_RWKV_WKV:
17533+
case GGML_OP_RWKV_WKV6:
1753217534
{
17533-
ggml_compute_forward_rwkv_wkv(params, tensor);
17535+
ggml_compute_forward_rwkv_wkv6(params, tensor);
1753417536
} break;
1753517537
case GGML_OP_MAP_UNARY:
1753617538
{
@@ -18719,7 +18721,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1871918721
} break;
1872018722
case GGML_OP_GET_REL_POS:
1872118723
case GGML_OP_ADD_REL_POS:
18722-
case GGML_OP_RWKV_WKV:
18724+
case GGML_OP_RWKV_WKV6:
1872318725
case GGML_OP_MAP_UNARY:
1872418726
case GGML_OP_MAP_BINARY:
1872518727
case GGML_OP_MAP_CUSTOM1_F32:
@@ -19369,7 +19371,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1936919371
case GGML_OP_WIN_PART:
1937019372
case GGML_OP_WIN_UNPART:
1937119373
case GGML_OP_GET_REL_POS:
19372-
case GGML_OP_RWKV_WKV:
19374+
case GGML_OP_RWKV_WKV6:
1937319375
case GGML_OP_MAP_UNARY:
1937419376
case GGML_OP_MAP_BINARY:
1937519377
case GGML_OP_MAP_CUSTOM1_F32:

src/llama.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7070,7 +7070,7 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
70707070
{LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
70717071
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
70727072
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
7073-
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV}},
7073+
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
70747074
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
70757075
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
70767076
{LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -7185,7 +7185,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
71857185
ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
71867186
op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
71877187
} break;
7188-
case GGML_OP_RWKV_WKV:
7188+
case GGML_OP_RWKV_WKV6:
71897189
{
71907190
// FIXME
71917191
const int64_t S = 123;
@@ -7198,7 +7198,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
71987198
ggml_tensor * tf = w;
71997199
ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
72007200
ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
7201-
op_tensor = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
7201+
op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
72027202
} break;
72037203
default:
72047204
GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
@@ -10141,7 +10141,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
1014110141
v = ggml_transpose(ctx, v);
1014210142
r = ggml_transpose(ctx, r);
1014310143

10144-
struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
10144+
struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
1014510145
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
1014610146
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
1014710147

tests/test-backend-ops.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,8 +1613,8 @@ struct test_ssm_scan : public test_case {
16131613
}
16141614
};
16151615

1616-
// GGML_OP_RWKV_WKV
1617-
struct test_rwkv_wkv : public test_case {
1616+
// GGML_OP_RWKV_WKV6
1617+
struct test_rwkv_wkv6 : public test_case {
16181618
const ggml_type type;
16191619

16201620
const int64_t head_count;
@@ -1626,7 +1626,7 @@ struct test_rwkv_wkv : public test_case {
16261626
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
16271627
}
16281628

1629-
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
1629+
test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32,
16301630
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
16311631
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
16321632

@@ -1638,7 +1638,7 @@ struct test_rwkv_wkv : public test_case {
16381638
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
16391639
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
16401640
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
1641-
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
1641+
ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
16421642
return out;
16431643
}
16441644
};
@@ -3498,10 +3498,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
34983498

34993499
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
35003500

3501-
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
3502-
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
3503-
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
3504-
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
3501+
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
3502+
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
3503+
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
3504+
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
35053505

35063506
#if 1
35073507
for (ggml_type type_a : base_types) {

Comments (0) — 0 commit comments