Skip to content

Commit 9526033

Browse files
vulkan: implement GGML_OP_OPT_STEP_ADAMW
1 parent 095f8d1 commit 9526033

File tree

3 files changed

+168
-0
lines changed

3 files changed

+168
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ struct vk_device_struct {
259259
vk_pipeline pipeline_timestep_embedding_f32;
260260
vk_pipeline pipeline_pool2d_f32;
261261
vk_pipeline pipeline_rwkv_wkv6_f32;
262+
vk_pipeline pipeline_opt_step_adamw_f32;
262263

263264
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
264265
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -2173,6 +2174,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
21732174

21742175
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
21752176

2177+
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2178+
21762179
for (auto &c : compiles) {
21772180
c.wait();
21782181
}
@@ -5329,6 +5332,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53295332
return ctx->device->pipeline_rwkv_wkv6_f32;
53305333
}
53315334
return nullptr;
5335+
case GGML_OP_OPT_STEP_ADAMW:
5336+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5337+
return ctx->device->pipeline_opt_step_adamw_f32;
5338+
}
5339+
return nullptr;
53325340
case GGML_OP_LEAKY_RELU:
53335341
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
53345342
return ctx->device->pipeline_leaky_relu_f32;
@@ -5936,6 +5944,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
59365944
);
59375945
}
59385946

5947+
// Record (or, in dry-run mode, pre-reserve descriptors for) one fused AdamW
// optimizer step on the Vulkan backend.
//
// dst->src layout:
//   [0] x  - parameters to update (written in place)
//   [1] g  - gradients
//   [2] gm - first-moment accumulator (written in place)
//   [3] gv - second-moment accumulator (written in place)
//   [4] p  - 7 host-provided hyper-parameters (see opt_step_adamw.comp)
//
// pc.KX must hold ggml_nelements(x); the remaining push constants are unused.
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
    const ggml_tensor * x  = dst->src[0];
    const ggml_tensor * g  = dst->src[1];
    const ggml_tensor * gm = dst->src[2];
    const ggml_tensor * gv = dst->src[3];
    const ggml_tensor * p  = dst->src[4];

    GGML_ASSERT(x->type == GGML_TYPE_F32);
    GGML_ASSERT(g->type == GGML_TYPE_F32);
    GGML_ASSERT(gm->type == GGML_TYPE_F32);
    GGML_ASSERT(gv->type == GGML_TYPE_F32);
    GGML_ASSERT(p->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->buffer != nullptr);
    GGML_ASSERT(ggml_is_contiguous(x));
    GGML_ASSERT(ggml_is_contiguous(g));
    GGML_ASSERT(ggml_is_contiguous(gm));
    GGML_ASSERT(ggml_is_contiguous(gv));
    GGML_ASSERT(ggml_is_contiguous(p));
    GGML_ASSERT(ggml_are_same_shape(x, g));
    GGML_ASSERT(ggml_are_same_shape(x, gm));
    GGML_ASSERT(ggml_are_same_shape(x, gv));
    GGML_ASSERT(ggml_nelements(p) == 7);  // alpha, beta1, beta2, eps, wd + 2 precomputed terms

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
    GGML_ASSERT(pipeline != nullptr);

    if (dryrun) {
        // First graph pass only counts descriptor sets; no work is recorded.
        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
        return;
    }

    ggml_vk_sync_buffers(subctx);

    // Resolve the five tensors to (device buffer, byte offset) pairs.
    const ggml_tensor * tensors[5] = { x, g, gm, gv, p };
    vk_buffer           bufs[5]    = { nullptr, nullptr, nullptr, nullptr, nullptr };
    size_t              offsets[5] = { 0, 0, 0, 0, 0 };

    if (ctx->device->uma) {
        // Unified memory: host pointers may map directly to device buffers.
        for (int i = 0; i < 5; i++) {
            ggml_vk_host_get(ctx->device, tensors[i]->data, bufs[i], offsets[i]);
        }
    }

    for (int i = 0; i < 5; i++) {
        if (bufs[i] == nullptr) {
            // Not UMA-resolved: fall back to the tensor's backend buffer.
            bufs[i]    = ((ggml_backend_vk_buffer_context *)tensors[i]->buffer->context)->dev_buffer;
            offsets[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
        }
    }

    // One invocation per element of x; y/z stay 1 (flattened in the shader).
    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
        vk_subbuffer{ bufs[0], offsets[0], ggml_nbytes(x) },
        vk_subbuffer{ bufs[1], offsets[1], ggml_nbytes(g) },
        vk_subbuffer{ bufs[2], offsets[2], ggml_nbytes(gm) },
        vk_subbuffer{ bufs[3], offsets[3], ggml_nbytes(gv) },
        vk_subbuffer{ bufs[4], offsets[4], ggml_nbytes(p) },
    }, sizeof(vk_op_push_constants), &pc, elements);
}
6041+
6042+
// Graph-level entry point for GGML_OP_OPT_STEP_ADAMW: packs the element count
// into the push constants and forwards to the f32 implementation.
static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
    // KX = number of parameters being updated; KY and the two floats are unused.
    const uint32_t n_elements = (uint32_t)ggml_nelements(dst->src[0]);

    ggml_vk_op_f32_opt_step_adamw(ctx, subctx, dst, { n_elements, 0, 0.0f, 0.0f }, dryrun);
}
6051+
59396052
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
59406053
int * op_params = (int *)dst->op_params;
59416054

@@ -7100,6 +7213,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71007213
case GGML_OP_RWKV_WKV6:
71017214
case GGML_OP_LEAKY_RELU:
71027215
case GGML_OP_FLASH_ATTN_EXT:
7216+
case GGML_OP_OPT_STEP_ADAMW:
71037217
break;
71047218
default:
71057219
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -7322,6 +7436,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73227436
case GGML_OP_RWKV_WKV6:
73237437
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
73247438

7439+
break;
7440+
7441+
case GGML_OP_OPT_STEP_ADAMW:
7442+
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
7443+
73257444
break;
73267445
default:
73277446
return false;
@@ -7409,6 +7528,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
74097528
case GGML_OP_RWKV_WKV6:
74107529
case GGML_OP_LEAKY_RELU:
74117530
case GGML_OP_REPEAT:
7531+
case GGML_OP_OPT_STEP_ADAMW:
74127532
buf = tensor->buffer;
74137533

74147534
break;
@@ -8346,6 +8466,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
83468466
case GGML_OP_POOL_2D:
83478467
case GGML_OP_RWKV_WKV6:
83488468
case GGML_OP_LEAKY_RELU:
8469+
case GGML_OP_OPT_STEP_ADAMW:
83498470
return true;
83508471
default:
83518472
return false;
@@ -8951,6 +9072,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
89519072
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
89529073
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
89539074
tensor->src[4], tensor->src[5]);
9075+
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
9076+
tensor_clone = ggml_opt_step_adamw(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2],
9077+
tensor->src[3], tensor->src[4]);
89549078
}
89559079
else {
89569080
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#version 450
2+
3+
#include "generic_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) buffer X {A_TYPE x[];};
11+
layout (binding = 1) readonly buffer G {A_TYPE grad[];};
12+
layout (binding = 2) buffer GM {A_TYPE gradm[];};
13+
layout (binding = 3) buffer GV {A_TYPE gradv[];};
14+
layout (binding = 4) readonly buffer P {float params[7];};
15+
16+
void main() {
17+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
18+
19+
if (i >= p.KX) {
20+
return;
21+
}
22+
23+
const float alpha = params[0];
24+
const float beta1 = params[1];
25+
const float beta2 = params[2];
26+
const float eps = params[3];
27+
const float wd = params[4];
28+
const float beta1h = params[5];
29+
const float beta2h = params[6];
30+
31+
const float gi = grad[i];
32+
const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
33+
const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);
34+
35+
gradm[i] = gmi;
36+
gradv[i] = gvi;
37+
38+
const float mh = gmi*beta1h;
39+
const float vh = sqrt(gvi*beta2h) + eps;
40+
41+
x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
42+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,8 @@ void process_shaders() {
500500

501501
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
502502

503+
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
504+
503505
for (auto &c : compiles) {
504506
c.wait();
505507
}

0 commit comments

Comments
 (0)