@@ -436,6 +436,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];
 
+    vk_pipeline pipeline_geglu[2];
+    vk_pipeline pipeline_reglu[2];
+    vk_pipeline pipeline_swiglu[2];
+
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
@@ -2751,6 +2755,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(sigmoid)
 #undef CREATE_UNARY
 
+#define CREATE_GLU(name) \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+    CREATE_GLU(geglu)
+    CREATE_GLU(reglu)
+    CREATE_GLU(swiglu)
+#undef CREATE_GLU
+
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -6455,6 +6468,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             break;
         }
         return nullptr;
+    case GGML_OP_GLU:
+        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
+            (src0->type != dst->type)) {
+            return nullptr;
+        }
+
+        switch (ggml_get_glu_op(dst)) {
+            case GGML_GLU_OP_GEGLU:
+                return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_REGLU:
+                return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_SWIGLU:
+                return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+            default:
+                break;
+        }
+        return nullptr;
     case GGML_OP_DIAG_MASK_INF:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_diag_mask_inf_f32;
@@ -6831,6 +6862,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_GLU:
         {
             const uint32_t nr = ggml_nrows(src0);
             if (nr > 262144) {
@@ -7547,6 +7579,14 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
 
+static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
+
+    const uint32_t swapped = (uint32_t)dst->op_params[1];
+
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     int32_t * op_params = (int32_t *)dst->op_params;
     ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
@@ -8758,6 +8798,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
             return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                break;
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
@@ -8850,6 +8900,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_UNARY:
+    case GGML_OP_GLU:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
@@ -8987,6 +9038,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
             return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun);
+                break;
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_DIAG_MASK_INF:
         ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
@@ -9112,8 +9174,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     if (!ok) {
         if (node->op == GGML_OP_UNARY) {
             std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
-        }
-        else {
+        } else if (node->op == GGML_OP_GLU) {
+            std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
+        } else {
             std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
         }
     }
@@ -9192,6 +9255,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
             return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(tensor)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                buf = tensor->buffer;
+                break;
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_FLASH_ATTN_EXT:
@@ -9976,6 +10050,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(op)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                return ggml_is_contiguous(op->src[0]) &&
+                       (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                       (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                       (op->src[0]->type == op->type);
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
         {
@@ -10706,6 +10793,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");
         }
+    } else if (tensor->op == GGML_OP_GLU) {
+        tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
             tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
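
Note on the shader side: the new geglu/reglu/swiglu pipelines receive only two meaningful push constants from ggml_vk_glu above, the full row width src0->ne[0] and the swapped flag. The shader sources are not part of these hunks; assuming they mirror ggml's CPU GLU semantics, each output row of ne[0]/2 elements is the activated half of the input row multiplied element-wise by the other half, with swapped selecting which half is activated. Below is a minimal CPU reference sketch of that semantics; the helper names and the erf-based GELU are illustrative, not taken from this patch.

#include <cmath>
#include <cstdint>

enum class glu_op { geglu, reglu, swiglu };

// Exact-erf GELU for clarity; ggml itself uses a tanh approximation.
static float gelu_ref(float x) { return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); }
static float relu_ref(float x) { return x > 0.0f ? x : 0.0f; }
static float silu_ref(float x) { return x / (1.0f + std::exp(-x)); }

// src: one row of ne0 values; dst: ne0/2 values (matches the GGML_ASSERT in ggml_vk_glu).
static void glu_row_ref(glu_op op, const float * src, float * dst, uint32_t ne0, uint32_t swapped) {
    const uint32_t nc = ne0 / 2;
    const float * a = swapped ? src + nc : src;      // half that goes through the activation
    const float * g = swapped ? src      : src + nc; // gating half
    for (uint32_t i = 0; i < nc; ++i) {
        const float act = op == glu_op::geglu ? gelu_ref(a[i])
                        : op == glu_op::reglu ? relu_ref(a[i])
                        :                       silu_ref(a[i]);
        dst[i] = act * g[i];
    }
}

This is also why ggml_vk_check_results_0 can validate the Vulkan output simply by rebuilding the node with ggml_glu(ggml_ctx, src_clone[0], op, swapped) and comparing against the reference backend.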