@@ -245,6 +245,7 @@ struct vk_device_struct {
    vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
    vk_pipeline pipeline_timestep_embedding_f32;
    vk_pipeline pipeline_pool2d_f32;
+    vk_pipeline pipeline_rwkv_wkv6_f32;

    // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
    vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -528,6 +529,13 @@ struct vk_op_pool2d_push_constants {
    int32_t p0; int32_t p1;
};

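+// Push constants for the RWKV WKV6 shader; filled in by ggml_vk_rwkv_wkv6(): B = n_seqs, T = sequence length, C = n_embed, H = n_heads.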
+struct vk_op_rwkv_wkv6_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
// Allow pre-recording command buffers
struct vk_staging_memcpy {
    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -2014,6 +2022,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

    ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);

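+    // The rwkv_wkv6_f32 shader binds 7 buffers (k, v, r, tf, td, state, dst) and receives the device subgroup size as a specialization constant.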
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
    for (auto &c : compiles) {
        c.wait();
    }
@@ -5022,6 +5032,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_pool2d_f32;
        }
        return nullptr;
+    case GGML_OP_RWKV_WKV6:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv6_f32;
+        }
+        return nullptr;
    case GGML_OP_LEAKY_RELU:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_leaky_relu_f32;
@@ -5424,6 +5439,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
    }, dryrun);
}

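+// Sets up buffers and dispatches the WKV6 compute shader: resolves the six inputs (k, v, r, tf, td, state) and dst to Vulkan buffers, then launches B * H workgroups.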
+static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
+    const ggml_tensor * k = dst->src[0];
+    const ggml_tensor * v = dst->src[1];
+    const ggml_tensor * r = dst->src[2];
+    const ggml_tensor * tf = dst->src[3];
+    const ggml_tensor * td = dst->src[4];
+    const ggml_tensor * state = dst->src[5];
+
+    GGML_ASSERT(!ggml_is_quantized(k->type));
+    GGML_ASSERT(!ggml_is_quantized(v->type));
+    GGML_ASSERT(!ggml_is_quantized(r->type));
+    GGML_ASSERT(!ggml_is_quantized(tf->type));
+    GGML_ASSERT(!ggml_is_quantized(td->type));
+    GGML_ASSERT(!ggml_is_quantized(state->type));
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    GGML_ASSERT(pipeline != nullptr);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
+    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
+    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
+    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
+    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
+    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
+    uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
+    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+
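+    // On UMA devices the tensors may already be accessible through host memory; try to resolve host pointers first and fall back to the device buffers below.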
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
+        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
+        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
+        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
+        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
+        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+
+        K_uma = d_K != nullptr;
+        V_uma = d_V != nullptr;
+        R_uma = d_R != nullptr;
+        TF_uma = d_TF != nullptr;
+        TD_uma = d_TD != nullptr;
+        STATE_uma = d_State != nullptr;
+        DST_uma = d_D != nullptr;
+    }
+
+    if (!K_uma) {
+        d_K = k_buf_ctx->dev_buffer;
+        k_offset = vk_tensor_offset(k) + k->view_offs;
+    }
+    if (!V_uma) {
+        d_V = v_buf_ctx->dev_buffer;
+        v_offset = vk_tensor_offset(v) + v->view_offs;
+    }
+    if (!R_uma) {
+        d_R = r_buf_ctx->dev_buffer;
+        r_offset = vk_tensor_offset(r) + r->view_offs;
+    }
+    if (!TF_uma) {
+        d_TF = tf_buf_ctx->dev_buffer;
+        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
+    }
+    if (!TD_uma) {
+        d_TD = td_buf_ctx->dev_buffer;
+        td_offset = vk_tensor_offset(td) + td->view_offs;
+    }
+    if (!STATE_uma) {
+        d_State = state_buf_ctx->dev_buffer;
+        state_offset = vk_tensor_offset(state) + state->view_offs;
+    }
+    if (!DST_uma) {
+        d_D = dst_buf_ctx->dev_buffer;
+        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
+    }
+
+    const uint64_t k_size = ggml_nbytes(k);
+    const uint64_t v_size = ggml_nbytes(v);
+    const uint64_t r_size = ggml_nbytes(r);
+    const uint64_t tf_size = ggml_nbytes(tf);
+    const uint64_t td_size = ggml_nbytes(td);
+    const uint64_t state_size = ggml_nbytes(state);
+    const uint64_t dst_size = ggml_nbytes(dst);
+
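+    // One workgroup per (sequence, head) pair; T and C are carried in the push constants.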
+    std::array<uint32_t, 3> elements = {
+        (uint32_t)(pc.B * pc.H),
+        1,
+        1
+    };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+        vk_subbuffer{ d_K, k_offset, k_size },
+        vk_subbuffer{ d_V, v_offset, v_size },
+        vk_subbuffer{ d_R, r_offset, r_size },
+        vk_subbuffer{ d_TF, tf_offset, tf_size },
+        vk_subbuffer{ d_TD, td_offset, td_size },
+        vk_subbuffer{ d_State, state_offset, state_size },
+        vk_subbuffer{ d_D, dst_offset, dst_size }
+    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+}
+
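+// Derives B/T/C/H from the tensor shapes and forwards them to ggml_vk_op_f32_rwkv6() as push constants.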
+static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    ggml_vk_op_f32_rwkv6(
+        ctx, subctx, dst,
+        {
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
+        },
+        dryrun
+    );
+}
+
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    int * op_params = (int *)dst->op_params;

@@ -6569,6 +6712,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_IM2COL:
    case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_FLASH_ATTN_EXT:
        break;
@@ -6768,6 +6912,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_FLASH_ATTN_EXT:
        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);

+        break;
+
+    case GGML_OP_RWKV_WKV6:
+        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
+
        break;
    default:
        return false;
@@ -6848,6 +6997,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_IM2COL:
    case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_REPEAT:
        buf = tensor->buffer;
@@ -7724,6 +7874,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
    case GGML_OP_IM2COL:
    case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
    case GGML_OP_LEAKY_RELU:
        return true;
    default:
@@ -8300,7 +8451,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
    } else if (tensor->op == GGML_OP_LEAKY_RELU) {
        const float * op_params = (const float *)tensor->op_params;
        tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
-    } else {
+    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
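+        // Rebuild the op with the regular ggml API so the reference result can be compared against the Vulkan output.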
+        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
+                                      tensor->src[4], tensor->src[5]);
+    }
+    else {
        std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
        GGML_ABORT("fatal error");
    }