@@ -259,6 +259,7 @@ struct vk_device_struct {
    vk_pipeline pipeline_timestep_embedding_f32;
    vk_pipeline pipeline_pool2d_f32;
    vk_pipeline pipeline_rwkv_wkv6_f32;
+   vk_pipeline pipeline_opt_step_adamw_f32;

    // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
    vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -2173,6 +2174,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

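+   // AdamW optimizer step: 5 storage buffers (x, grad, grad moments m/v, adamw params) and the generic vk_op push constants.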
+   ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
    for (auto &c : compiles) {
        c.wait();
    }
@@ -5329,6 +5332,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_rwkv_wkv6_f32;
        }
        return nullptr;
+   case GGML_OP_OPT_STEP_ADAMW:
+       if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+           return ctx->device->pipeline_opt_step_adamw_f32;
+       }
+       return nullptr;
    case GGML_OP_LEAKY_RELU:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_leaky_relu_f32;
@@ -5936,6 +5944,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
    );
}

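+// Records one AdamW optimizer step for GGML_OP_OPT_STEP_ADAMW. dst->src holds the
+// parameter tensor x, its gradient g, the first/second gradient moments gm/gv, and
+// a small tensor p with the AdamW parameters.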
+static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
+    const ggml_tensor * x  = dst->src[0];
+    const ggml_tensor * g  = dst->src[1];
+    const ggml_tensor * gm = dst->src[2];
+    const ggml_tensor * gv = dst->src[3];
+    const ggml_tensor * p  = dst->src[4];
+
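+    // All operands must be contiguous F32 tensors; g, gm and gv match the shape of x,
+    // and p carries exactly 7 scalar AdamW parameters.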
+    GGML_ASSERT(x->type == GGML_TYPE_F32);
+    GGML_ASSERT(g->type == GGML_TYPE_F32);
+    GGML_ASSERT(gm->type == GGML_TYPE_F32);
+    GGML_ASSERT(gv->type == GGML_TYPE_F32);
+    GGML_ASSERT(p->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->buffer != nullptr);
+    GGML_ASSERT(ggml_is_contiguous(x));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(gm));
+    GGML_ASSERT(ggml_is_contiguous(gv));
+    GGML_ASSERT(ggml_is_contiguous(p));
+    GGML_ASSERT(ggml_are_same_shape(x, g));
+    GGML_ASSERT(ggml_are_same_shape(x, gm));
+    GGML_ASSERT(ggml_are_same_shape(x, gv));
+    GGML_ASSERT(ggml_nelements(p) == 7);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
+    GGML_ASSERT(pipeline != nullptr);
+
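+    // Dry-run pass: only request the descriptor sets this node will need, then bail out.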
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
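+    // Resolve the backend buffer context of every operand.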
+    ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
+    ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
+    ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
+    ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
+    ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
+
+    ggml_vk_sync_buffers(subctx);
+
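+    // On UMA devices the tensor data may already be GPU-visible through a host buffer;
+    // any operand not resolved here falls back to its dedicated device buffer below.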
+    vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
+    size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
+    bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
+        ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
+        ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
+        ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
+        ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
+
+        X_uma = d_X != nullptr;
+        G_uma = d_G != nullptr;
+        GM_uma = d_GM != nullptr;
+        GV_uma = d_GV != nullptr;
+        P_uma = d_P != nullptr;
+    }
+
+    if (!X_uma) {
+        d_X = x_buf_ctx->dev_buffer;
+        x_offset = vk_tensor_offset(x) + x->view_offs;
+    }
+    if (!G_uma) {
+        d_G = g_buf_ctx->dev_buffer;
+        g_offset = vk_tensor_offset(g) + g->view_offs;
+    }
+    if (!GM_uma) {
+        d_GM = gm_buf_ctx->dev_buffer;
+        gm_offset = vk_tensor_offset(gm) + gm->view_offs;
+    }
+    if (!GV_uma) {
+        d_GV = gv_buf_ctx->dev_buffer;
+        gv_offset = vk_tensor_offset(gv) + gv->view_offs;
+    }
+    if (!P_uma) {
+        d_P = p_buf_ctx->dev_buffer;
+        p_offset = vk_tensor_offset(p) + p->view_offs;
+    }
+
+    const uint64_t x_size = ggml_nbytes(x);
+    const uint64_t g_size = ggml_nbytes(g);
+    const uint64_t gm_size = ggml_nbytes(gm);
+    const uint64_t gv_size = ggml_nbytes(gv);
+    const uint64_t p_size = ggml_nbytes(p);
+
+    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
+
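+    // Bind the five buffers in the order the shader expects (x, g, gm, gv, params)
+    // and dispatch over every element of x.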
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+        vk_subbuffer{ d_X, x_offset, x_size },
+        vk_subbuffer{ d_G, g_offset, g_size },
+        vk_subbuffer{ d_GM, gm_offset, gm_size },
+        vk_subbuffer{ d_GV, gv_offset, gv_size },
+        vk_subbuffer{ d_P, p_offset, p_size },
+    }, sizeof(vk_op_push_constants), &pc, elements);
+}
+
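+// Graph-level entry point. The push constants only carry the element count; the AdamW
+// hyper-parameters themselves are expected to come from the params tensor dst->src[4].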
+static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t n = ggml_nelements(dst->src[0]);
+
+    ggml_vk_op_f32_opt_step_adamw(
+        ctx, subctx, dst,
+        { (uint32_t)n, 0, 0.0f, 0.0f },
+        dryrun
+    );
+}
+
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    int * op_params = (int *)dst->op_params;

@@ -7100,6 +7213,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_FLASH_ATTN_EXT:
+   case GGML_OP_OPT_STEP_ADAMW:
        break;
    default:
        std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -7322,6 +7436,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_RWKV_WKV6:
        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);

+        break;
+
+    case GGML_OP_OPT_STEP_ADAMW:
+        ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
+
        break;
    default:
        return false;
@@ -7409,6 +7528,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_REPEAT:
+   case GGML_OP_OPT_STEP_ADAMW:
        buf = tensor->buffer;

        break;
@@ -8346,6 +8466,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
    case GGML_OP_POOL_2D:
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_LEAKY_RELU:
+   case GGML_OP_OPT_STEP_ADAMW:
        return true;
    default:
        return false;
@@ -8951,6 +9072,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
            tensor->src[4], tensor->src[5]);
+    } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
+        tensor_clone = ggml_opt_step_adamw(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2],
+            tensor->src[3], tensor->src[4]);
    }
    else {
        std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;