Skip to content

Commit c4bd31e

Browse files
committed
wkv7 CUDA impl
Signed-off-by: Molly Sophia <[email protected]>
1 parent 8674cdb commit c4bd31e

File tree

5 files changed

+197
-90
lines changed

5 files changed

+197
-90
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include "ggml-cuda/tsembd.cuh"
3737
#include "ggml-cuda/unary.cuh"
3838
#include "ggml-cuda/upscale.cuh"
39-
#include "ggml-cuda/wkv6.cuh"
39+
#include "ggml-cuda/wkv.cuh"
4040
#include "ggml-cuda/gla.cuh"
4141
#include "ggml.h"
4242

@@ -2298,6 +2298,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
22982298
case GGML_OP_RWKV_WKV6:
22992299
ggml_cuda_op_rwkv_wkv6(ctx, dst);
23002300
break;
2301+
case GGML_OP_RWKV_WKV7:
2302+
ggml_cuda_op_rwkv_wkv7(ctx, dst);
2303+
break;
23012304
case GGML_OP_GATED_LINEAR_ATTN:
23022305
ggml_cuda_op_gated_linear_attn(ctx, dst);
23032306
break;
@@ -3196,6 +3199,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31963199
case GGML_OP_TIMESTEP_EMBEDDING:
31973200
case GGML_OP_LEAKY_RELU:
31983201
case GGML_OP_RWKV_WKV6:
3202+
case GGML_OP_RWKV_WKV7:
31993203
case GGML_OP_GATED_LINEAR_ATTN:
32003204
return true;
32013205
case GGML_OP_FLASH_ATTN_EXT: {

ggml/src/ggml-cuda/wkv.cu

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#include "common.cuh"
2+
#include "wkv.cuh"
3+
4+
static __global__ void rwkv_wkv6_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
5+
const int tid = threadIdx.x;
6+
const int bid = blockIdx.x;
7+
8+
const int head_size = CUDA_WKV_BLOCK_SIZE;
9+
const int batch_i = bid / H;
10+
const int head_i = bid % H;
11+
const int state_size = C * head_size;
12+
const int n_seq_tokens = T / B;
13+
14+
float state[head_size];
15+
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
16+
17+
#pragma unroll
18+
for (int i = 0; i < head_size; i++) {
19+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
20+
}
21+
22+
__syncthreads();
23+
_tf[tid] = tf[head_i * head_size + tid];
24+
__syncthreads();
25+
26+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
27+
__syncthreads();
28+
_k[tid] = k[t];
29+
_r[tid] = r[t];
30+
_td[tid] = td[t];
31+
__syncthreads();
32+
33+
const float _v = v[t];
34+
float y = 0;
35+
for (int j = 0; j < head_size; j += 4) {
36+
const float4& k = (float4&)(_k[j]);
37+
const float4& r = (float4&)(_r[j]);
38+
const float4& tf = (float4&)(_tf[j]);
39+
const float4& td = (float4&)(_td[j]);
40+
float4& s = (float4&)(state[j]);
41+
float4 kv;
42+
43+
kv.x = k.x * _v;
44+
kv.y = k.y * _v;
45+
kv.z = k.z * _v;
46+
kv.w = k.w * _v;
47+
48+
y += r.x * (tf.x * kv.x + s.x);
49+
y += r.y * (tf.y * kv.y + s.y);
50+
y += r.z * (tf.z * kv.z + s.z);
51+
y += r.w * (tf.w * kv.w + s.w);
52+
53+
s.x = s.x * td.x + kv.x;
54+
s.y = s.y * td.y + kv.y;
55+
s.z = s.z * td.z + kv.z;
56+
s.w = s.w * td.w + kv.w;
57+
}
58+
dst[t] = y;
59+
}
60+
61+
#pragma unroll
62+
for (int i = 0; i < head_size; i++) {
63+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
64+
}
65+
}
66+
67+
static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
68+
const int tid = threadIdx.x;
69+
const int bid = blockIdx.x;
70+
71+
const int head_size = CUDA_WKV_BLOCK_SIZE;
72+
const int batch_i = bid / H;
73+
const int head_i = bid % H;
74+
const int state_size = C * head_size;
75+
const int n_seq_tokens = T / B;
76+
77+
float state[head_size];
78+
__shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
79+
80+
#pragma unroll
81+
for (int i = 0; i < head_size; i++) {
82+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
83+
}
84+
85+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
86+
__syncthreads();
87+
_r[tid] = r[t];
88+
_w[tid] = w[t];
89+
_k[tid] = k[t];
90+
_a[tid] = a[t];
91+
_b[tid] = b[t];
92+
__syncthreads();
93+
94+
float sa = 0;
95+
#pragma unroll
96+
for (int j = 0; j < head_size; j += 4)
97+
{
98+
const float4& a = (float4&)(_a[j]);
99+
const float4& s = (float4&)(state[j]);
100+
sa += a.x * s.x;
101+
sa += a.y * s.y;
102+
sa += a.z * s.z;
103+
sa += a.w * s.w;
104+
}
105+
106+
const float _v = v[t];
107+
float y = 0;
108+
for (int j = 0; j < head_size; j += 4) {
109+
const float4& r = (float4&)(_r[j]);
110+
const float4& w = (float4&)(_w[j]);
111+
const float4& k = (float4&)(_k[j]);
112+
const float4& b = (float4&)(_b[j]);
113+
float4& s = (float4&)(state[j]);
114+
float4 kv;
115+
116+
kv.x = k.x * _v;
117+
kv.y = k.y * _v;
118+
kv.z = k.z * _v;
119+
kv.w = k.w * _v;
120+
121+
s.x = s.x * w.x + kv.x + sa * b.x;
122+
s.y = s.y * w.y + kv.y + sa * b.y;
123+
s.z = s.z * w.z + kv.z + sa * b.z;
124+
s.w = s.w * w.w + kv.w + sa * b.w;
125+
126+
y += s.x * r.x;
127+
y += s.y * r.y;
128+
y += s.z * r.z;
129+
y += s.w * r.w;
130+
}
131+
dst[t] = y;
132+
}
133+
134+
#pragma unroll
135+
for (int i = 0; i < head_size; i++) {
136+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
137+
}
138+
}
139+
140+
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
141+
const float * k_d = (const float *)dst->src[0]->data;
142+
const float * v_d = (const float *)dst->src[1]->data;
143+
const float * r_d = (const float *)dst->src[2]->data;
144+
const float * tf_d = (const float *)dst->src[3]->data;
145+
const float * td_d = (const float *)dst->src[4]->data;
146+
const float * s_d = (const float *)dst->src[5]->data;
147+
148+
const int64_t B = dst->src[5]->ne[1];
149+
const int64_t T = dst->src[0]->ne[2];
150+
const int64_t C = dst->ne[0];
151+
const int64_t H = dst->src[0]->ne[1];
152+
153+
float * dst_d = (float *)dst->data;
154+
155+
cudaStream_t stream = ctx.stream();
156+
157+
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
158+
GGML_ASSERT(C % H == 0);
159+
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
160+
161+
rwkv_wkv6_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
162+
}
163+
164+
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
165+
const float * r_d = (const float *)dst->src[0]->data;
166+
const float * w_d = (const float *)dst->src[1]->data;
167+
const float * k_d = (const float *)dst->src[2]->data;
168+
const float * v_d = (const float *)dst->src[3]->data;
169+
const float * a_d = (const float *)dst->src[4]->data;
170+
const float * b_d = (const float *)dst->src[5]->data;
171+
const float * s_d = (const float *)dst->src[6]->data;
172+
173+
const int64_t B = dst->src[6]->ne[1];
174+
const int64_t T = dst->src[0]->ne[2];
175+
const int64_t C = dst->ne[0];
176+
const int64_t H = dst->src[0]->ne[1];
177+
178+
float * dst_d = (float *)dst->data;
179+
180+
cudaStream_t stream = ctx.stream();
181+
182+
GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
183+
GGML_ASSERT(C % H == 0);
184+
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
185+
186+
rwkv_wkv7_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
187+
}

ggml/src/ggml-cuda/wkv6.cuh renamed to ggml/src/ggml-cuda/wkv.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@
33
#define CUDA_WKV_BLOCK_SIZE 64
44

55
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6+
7+
void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/wkv6.cu

Lines changed: 0 additions & 89 deletions
This file was deleted.

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,6 +1893,9 @@ struct test_rwkv_wkv7 : public test_case {
18931893
ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
18941894
ggml_tensor * a = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
18951895
ggml_tensor * b = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
1896+
// Outputs may become NaN with long seqlen without these normalization
1897+
a = ggml_l2_norm(ctx, a, 1e-7F);
1898+
b = ggml_l2_norm(ctx, b, 1e-7F);
18961899
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
18971900
ggml_tensor * out = ggml_rwkv_wkv7(ctx, r, w, k, v, a, b, s);
18981901
return out;

0 commit comments

Comments
 (0)