
Commit 194ead5

ggml: Add op rwkv_wkv7

Signed-off-by: Molly Sophia <[email protected]>
1 parent 98eff12

File tree: 19 files changed, +1269 -327 lines


ggml/include/ggml.h

Lines changed: 11 additions & 0 deletions
@@ -503,6 +503,7 @@ extern "C" {
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
         GGML_OP_GATED_LINEAR_ATTN,
+        GGML_OP_RWKV_WKV7,

         GGML_OP_UNARY,

@@ -1903,6 +1904,16 @@ extern "C" {
             struct ggml_tensor * state,
             float scale);

+    GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
+            struct ggml_context * ctx,
+            struct ggml_tensor * r,
+            struct ggml_tensor * w,
+            struct ggml_tensor * k,
+            struct ggml_tensor * v,
+            struct ggml_tensor * a,
+            struct ggml_tensor * b,
+            struct ggml_tensor * state);
+
     // custom operators

     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
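As a usage sketch, here is how the new op might be assembled into a graph. The shapes are inferred from the CPU kernel in ggml-cpu.c below (r/w/k/v/a/b as [head_size, n_head, n_tokens], state as [head_size * head_size * n_head, n_seqs]); head_size, n_head, n_tokens, and n_seqs are hypothetical placeholders, not names from this commit:

    // hypothetical sketch; shapes are an inference from the CPU kernel, not stated in the header
    struct ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_head, n_tokens);
    struct ggml_tensor * s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_size * head_size * n_head, n_seqs);

    // judging by the CPU kernel, the result packs C * n_tokens output floats
    // followed by the updated per-sequence states
    struct ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);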

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 185 additions & 2 deletions
@@ -13667,6 +13667,184 @@ static void ggml_compute_forward_gla(
 }
 }

+// ggml_compute_forward_rwkv_wkv7
+
+static void ggml_compute_forward_rwkv_wkv7_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[6]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * r = (float *) dst->src[0]->data;
+    float * w = (float *) dst->src[1]->data;
+    float * k = (float *) dst->src[2]->data;
+    float * v = (float *) dst->src[3]->data;
+    float * a = (float *) dst->src[4]->data;
+    float * b = (float *) dst->src[5]->data;
+
+    int64_t t_stride = HEADS * head_size; // same as C
+
+    int64_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    int64_t h_stride_2d = head_size * head_size;
+
+#if defined(GGML_SIMD)
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t ii = 0; ii < head_size; ii++) {
+                int64_t t_h_i_offset = t_h_offset + ii;
+                int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
+
+                GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
+
+                float sa = 0;
+                {
+                    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                    GGML_F32_VEC ax[GGML_F32_ARR];
+                    GGML_F32_VEC ay[GGML_F32_ARR];
+                    for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
+                        for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                            ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
+                            ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
+                            sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
+                        }
+                    }
+                    GGML_F32_VEC_REDUCE(sa, sum);
+                }
+
+                GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
+
+                int64_t j = 0;
+                GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+                for (; j < head_size; j += GGML_F32_STEP) {
+                    for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
+                        int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
+                        int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
+
+                        GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
+                        GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
+                        GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
+                        GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
+
+                        k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
+
+                        GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
+                        // kv + s * decay + sa * b
+                        state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
+                        state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
+                        GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
+
+                        result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
+                    }
+                }
+                GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
+
+                // handle any left-overs (there shouldn't be any)
+                for (; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v[t_h_i_offset] * k_val;
+
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
+                }
+            }
+        }
+    }
+#else
+    for (int64_t t = 0; t < T; t++) {
+        int64_t t_offset = t * t_stride;
+        int64_t state_offset = head_size * C * (t / (T / n_seqs));
+        float * state_cur = state + state_offset;
+        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
+
+        for (int64_t h = h_start; h < h_end; h++) {
+            int64_t h_offset = h * h_stride;
+            int64_t t_h_offset = t_offset + h_offset;
+            int64_t h_2d_offset = h * h_stride_2d;
+
+            for (int64_t i = 0; i < head_size; i++) {
+                int64_t t_h_i_offset = t_h_offset + i;
+                int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                float v_val = v[t_h_i_offset];
+
+                float sa = 0, result = 0;
+                for (int64_t j = 0; j < head_size; j++) {
+                    sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
+                }
+
+                for (int64_t j = 0; j < head_size; j++) {
+                    int64_t t_h_j_offset = t_h_offset + j;
+                    int64_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    float r_val = r[t_h_j_offset];
+                    float w_val = w[t_h_j_offset];
+                    float k_val = k[t_h_j_offset];
+                    float b_val = b[t_h_j_offset];
+                    float kv_val = v_val * k_val;
+                    float prev_state_val = state_prev[h_2d_i_j_offset];
+                    state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    result += state_cur[h_2d_i_j_offset] * r_val;
+                }
+                dst_data[t_h_i_offset] = result;
+            }
+        }
+    }
+#endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv7(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(
@@ -14424,6 +14602,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gla(params, tensor);
             } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                ggml_compute_forward_rwkv_wkv7(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -14716,14 +14898,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV6:
-        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
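For readability, the recurrence that the scalar (#else) branch implements can be written per head and per timestep, with S the head_size x head_size state matrix, i indexing the value dimension and j the key dimension (a restatement of the kernel's inner loops, not text from the commit):

    sa_i   = \sum_j a_j \, S^{prev}_{ij}
    S_{ij} = S^{prev}_{ij} \, w_j + v_i \, k_j + sa_i \, b_j
    y_i    = \sum_j S_{ij} \, r_j

The SIMD branch computes the same three steps, vectorizing over j and reducing sa_i and y_i with GGML_F32_VEC_REDUCE.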

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 1 deletion
@@ -36,7 +36,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml.h"

@@ -2307,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_cuda_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -3217,6 +3220,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
