
Commit 160bc03

rwkv6: add wkv6 support for Vulkan backend (#10829)
* rwkv_wkv6 vulkan shader

* RWKV_WKV6 Vulkan op tests passed

Signed-off-by: Molly Sophia <[email protected]>

* Apply code format changes

Signed-off-by: Molly Sophia <[email protected]>

* add [[unroll]] and remove unnecessary conditions

* add uma support

* fix errors in EditorConfig Checker

---------

Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: Molly Sophia <[email protected]>
1 parent 08ea539 commit 160bc03
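For orientation: the op this backend change accelerates is created at the graph level by ggml_rwkv_wkv6, which takes the six operands that the new Vulkan path binds as buffers (k, v, r, tf, td and the recurrent state), and the new pipeline is only selected when the inputs and the destination are F32. Below is a minimal sketch of building such a node; it is an illustration inferred from the calls visible in this diff, not code from the commit, and the helper name is hypothetical.

```cpp
// Hypothetical helper (not part of this commit): build a GGML_OP_RWKV_WKV6 node
// from its six operands so that a backend such as the Vulkan path added here
// can execute it. Operands are assumed to be GGML_TYPE_F32, matching the check
// in ggml_vk_op_get_pipeline.
#include "ggml.h"

static struct ggml_tensor * build_wkv6_node(struct ggml_context * ctx,
                                            struct ggml_tensor * k,
                                            struct ggml_tensor * v,
                                            struct ggml_tensor * r,
                                            struct ggml_tensor * tf,
                                            struct ggml_tensor * td,
                                            struct ggml_tensor * state) {
    return ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
}
```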

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp

3 files changed: +245 -1 lines changed


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 156 additions & 1 deletion
@@ -245,6 +245,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
+    vk_pipeline pipeline_rwkv_wkv6_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -528,6 +529,13 @@ struct vk_op_pool2d_push_constants {
     int32_t p0; int32_t p1;
 };
 
+struct vk_op_rwkv_wkv6_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -2014,6 +2022,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -5022,6 +5032,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_pool2d_f32;
         }
         return nullptr;
+    case GGML_OP_RWKV_WKV6:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv6_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -5424,6 +5439,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
+    const ggml_tensor * k = dst->src[0];
+    const ggml_tensor * v = dst->src[1];
+    const ggml_tensor * r = dst->src[2];
+    const ggml_tensor * tf = dst->src[3];
+    const ggml_tensor * td = dst->src[4];
+    const ggml_tensor * state = dst->src[5];
+
+    GGML_ASSERT(!ggml_is_quantized(k->type));
+    GGML_ASSERT(!ggml_is_quantized(v->type));
+    GGML_ASSERT(!ggml_is_quantized(r->type));
+    GGML_ASSERT(!ggml_is_quantized(tf->type));
+    GGML_ASSERT(!ggml_is_quantized(td->type));
+    GGML_ASSERT(!ggml_is_quantized(state->type));
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    GGML_ASSERT(pipeline != nullptr);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
+    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
+    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
+    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
+    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
+    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
+    uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
+    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
+        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
+        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
+        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
+        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+
+        K_uma = d_K != nullptr;
+        V_uma = d_V != nullptr;
+        R_uma = d_R != nullptr;
+        TF_uma = d_TF != nullptr;
+        TD_uma = d_TD != nullptr;
+        STATE_uma = d_State != nullptr;
+        DST_uma = d_D != nullptr;
+    }
+
+    if (!K_uma) {
+        d_K = k_buf_ctx->dev_buffer;
+        k_offset = vk_tensor_offset(k) + k->view_offs;
+    }
+    if (!V_uma) {
+        d_V = v_buf_ctx->dev_buffer;
+        v_offset = vk_tensor_offset(v) + v->view_offs;
+    }
+    if (!R_uma) {
+        d_R = r_buf_ctx->dev_buffer;
+        r_offset = vk_tensor_offset(r) + r->view_offs;
+    }
+    if (!TF_uma) {
+        d_TF = tf_buf_ctx->dev_buffer;
+        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
+    }
+    if (!TD_uma) {
+        d_TD = td_buf_ctx->dev_buffer;
+        td_offset = vk_tensor_offset(td) + td->view_offs;
+    }
+    if (!STATE_uma) {
+        d_State = state_buf_ctx->dev_buffer;
+        state_offset = vk_tensor_offset(state) + state->view_offs;
+    }
+    if (!DST_uma) {
+        d_D = dst_buf_ctx->dev_buffer;
+        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+    }
+
+    const uint64_t k_size = ggml_nbytes(k);
+    const uint64_t v_size = ggml_nbytes(v);
+    const uint64_t r_size = ggml_nbytes(r);
+    const uint64_t tf_size = ggml_nbytes(tf);
+    const uint64_t td_size = ggml_nbytes(td);
+    const uint64_t state_size = ggml_nbytes(state);
+    const uint64_t dst_size = ggml_nbytes(dst);
+
+    std::array<uint32_t, 3> elements = {
+        (uint32_t)(pc.B * pc.H),
+        1,
+        1
+    };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+        vk_subbuffer{ d_K, k_offset, k_size },
+        vk_subbuffer{ d_V, v_offset, v_size },
+        vk_subbuffer{ d_R, r_offset, r_size },
+        vk_subbuffer{ d_TF, tf_offset, tf_size },
+        vk_subbuffer{ d_TD, td_offset, td_size },
+        vk_subbuffer{ d_State, state_offset, state_size },
+        vk_subbuffer{ d_D, dst_offset, dst_size }
+    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+}
+
+static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    ggml_vk_op_f32_rwkv6(
+        ctx, subctx, dst,
+        {
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
+        },
+        dryrun
+    );
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -6569,6 +6712,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
         break;
@@ -6768,6 +6912,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node
     case GGML_OP_FLASH_ATTN_EXT:
         ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
 
+        break;
+
+    case GGML_OP_RWKV_WKV6:
+        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
+
         break;
     default:
         return false;
@@ -6848,6 +6997,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
         buf = tensor->buffer;
@@ -7724,6 +7874,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
         return true;
     default:
@@ -8300,7 +8451,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
-    } else {
+    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
+        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
+                                      tensor->src[4], tensor->src[5]);
+    }
+    else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
         GGML_ABORT("fatal error");
     }
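Two details of the new dispatch path are worth spelling out: ggml_vk_rwkv_wkv6 packs B = n_seqs, T = seq_length, C = n_embed and H = n_heads into the push constants, and ggml_vk_op_f32_rwkv6 launches B * H workgroups (one per sequence/head pair), each with the shader's fixed 64 invocations. The snippet below only reproduces that arithmetic for a hypothetical configuration to make the launch size concrete; the numbers are illustrative assumptions, not taken from the commit.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical configuration; head size 64 matches BLOCK_SIZE in wkv6.comp.
    const uint32_t n_seqs = 2, seq_length = 32, n_heads = 32, head_size = 64;
    const uint32_t n_embed = n_heads * head_size;    // C in vk_op_rwkv_wkv6_push_constants

    const uint32_t workgroups  = n_seqs * n_heads;   // elements = { B * H, 1, 1 }
    const uint32_t invocations = workgroups * 64;    // local_size_x = BLOCK_SIZE = 64

    std::printf("C = %u, workgroups = %u, invocations = %u, tokens per sequence = %u\n",
                n_embed, workgroups, invocations, seq_length / n_seqs);
    return 0;
}
```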

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
@@ -479,6 +479,8 @@ void process_shaders() {
 
     string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
+    string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
     for (auto &c : compiles) {
         c.wait();
     }

ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#define BLOCK_SIZE 64
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+    uint B;
+    uint T;
+    uint C;
+    uint H;
+};
+
+layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; };
+layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; };
+layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; };
+layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; };
+layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; };
+layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 6) buffer DstBuf { A_TYPE dst[]; };
+
+shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE];
+
+void main() {
+    const uint head_size = BLOCK_SIZE;
+    const uint batch_id = gl_WorkGroupID.x / H;
+    const uint head_id = gl_WorkGroupID.x % H;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    A_TYPE state[BLOCK_SIZE];
+    [[unroll]] for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                            + i * head_size + tid];
+    }
+
+    barrier();
+    _tf[tid] = tf[head_id * head_size + tid];
+    barrier();
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        barrier();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        barrier();
+
+        const A_TYPE v_val = v[t];
+        A_TYPE y = 0.0;
+
+        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
+            vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            vec4 kv = k_vec * v_val;
+
+            vec4 temp = tf_vec * kv + s_vec;
+            y += dot(r_vec, temp);
+
+            s_vec = s_vec * td_vec + kv;
+            state[j] = s_vec.x;
+            state[j+1] = s_vec.y;
+            state[j+2] = s_vec.z;
+            state[j+3] = s_vec.w;
+        }
+
+        dst[t] = y;
+    }
+
+    [[unroll]] for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + i * head_size + tid] = state[i];
+    }
+}
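Read per head, the time loop of wkv6.comp above computes the following recurrence, where i is the channel owned by an invocation (tid), j runs over the 64 channels of the head, and s_{j,i} is the state element state[j] held by that invocation (this is a restatement of the shader, not an addition to it):

    y_{t,i} = \sum_{j=0}^{63} r_{t,j} \left( tf_{j}\, k_{t,j}\, v_{t,i} + s_{j,i} \right)

    s_{j,i} \leftarrow td_{t,j}\, s_{j,i} + k_{t,j}\, v_{t,i}

After the last timestep each invocation writes its 64 state elements back to dst at offset T * C, which is how the updated recurrent state is returned alongside the per-token outputs.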
