@@ -431,6 +431,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];

+    vk_pipeline pipeline_geglu[2];
+    vk_pipeline pipeline_reglu[2];
+    vk_pipeline pipeline_swiglu[2];
+
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
@@ -2728,6 +2732,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(sigmoid)
 #undef CREATE_UNARY

+#define CREATE_GLU(name) \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+    CREATE_GLU(geglu)
+    CREATE_GLU(reglu)
+    CREATE_GLU(swiglu)
+#undef CREATE_GLU
+
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

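For readability, here is roughly what one CREATE_GLU invocation expands to; the expansion below is shown only for illustration (geglu; the reglu and swiglu cases are analogous):

// Illustrative expansion of CREATE_GLU(geglu): index 0 is the f32 pipeline, index 1 the f16 pipeline.
ggml_vk_create_pipeline(device, device->pipeline_geglu[0], "geglu_f32", geglu_f32_len, geglu_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_geglu[1], "geglu_f16", geglu_f16_len, geglu_f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);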
@@ -6415,6 +6428,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 break;
         }
         return nullptr;
+    case GGML_OP_GLU:
+        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
+            (src0->type != dst->type)) {
+            return nullptr;
+        }
+
+        switch (ggml_get_glu_op(dst)) {
+            case GGML_GLU_OP_GEGLU:
+                return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_REGLU:
+                return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_SWIGLU:
+                return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+            default:
+                break;
+        }
+        return nullptr;
     case GGML_OP_DIAG_MASK_INF:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_diag_mask_inf_f32;
@@ -6791,6 +6822,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_GLU:
         {
             const uint32_t nr = ggml_nrows(src0);
             if (nr > 262144) {
@@ -7507,6 +7539,14 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }

+static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
+
+    const uint32_t swapped = (uint32_t)dst->op_params[1];
+
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     int32_t * op_params = (int32_t *)dst->op_params;
     ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
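A note on the push constants above: the first field carries the full input row width (src0->ne[0], i.e. twice the output width per the GGML_ASSERT), and the second reuses the KY slot for the swapped flag read from dst->op_params[1]. As a hedged sketch of the host-side node this path handles, using the ggml_glu helper whose call appears in the check-results hunk further below (its exact signature is assumed from that call; the make_swiglu wrapper is purely illustrative):

// Sketch only: build a SWIGLU graph node that the Vulkan backend would dispatch via ggml_vk_glu().
// x->ne[0] must be even; the result has half that width. op_params[0] holds the ggml_glu_op,
// op_params[1] the swapped flag.
struct ggml_tensor * make_swiglu(struct ggml_context * ctx, struct ggml_tensor * x) {
    return ggml_glu(ctx, x, GGML_GLU_OP_SWIGLU, /*swapped=*/false);
}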
@@ -8718,6 +8758,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
                 return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                break;
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
@@ -8810,6 +8860,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_UNARY:
+    case GGML_OP_GLU:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
@@ -8947,6 +8998,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
                 return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun);
+                break;
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_DIAG_MASK_INF:
         ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);

@@ -9072,8 +9134,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     if (!ok) {
         if (node->op == GGML_OP_UNARY) {
             std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
-        }
-        else {
+        } else if (node->op == GGML_OP_GLU) {
+            std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
+        } else {
             std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
         }
     }
@@ -9152,6 +9215,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
                 return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(tensor)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                buf = tensor->buffer;
+                break;
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_FLASH_ATTN_EXT:
@@ -9923,6 +9997,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous(op->src[0]) &&
+                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                           (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                           (op->src[0]->type == op->type);
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
@@ -10637,6 +10724,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");
         }
+    } else if (tensor->op == GGML_OP_GLU) {
+        tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
             tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);