
Commit 34d1aed

0cc4m authored and qnixsynapse committed
Vulkan: Add GLU ops and shaders
1 parent d593429 commit 34d1aed

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp
ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp
ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

5 files changed: +215 -2 lines changed
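For context (not part of the diff): GGML_OP_GLU takes a single contiguous input whose innermost dimension holds both halves of the gated linear unit. Each row is split at ne[0]/2, one half is run through the op's activation (GELU for GEGLU, ReLU for REGLU, SiLU for SWIGLU), and the result is multiplied elementwise by the other half; op_params[1] ("swapped") selects which half is activated. A minimal scalar sketch of that convention (illustrative only, names are hypothetical):

// Illustrative reference for one row; act() stands in for gelu/relu/silu.
static void glu_row_reference(const float * src, float * dst, int n_in, bool swapped,
                              float (*act)(float)) {
    const int half = n_in / 2;                              // output row width
    const float * a = swapped ? src + half : src;           // half that is activated
    const float * b = swapped ? src        : src + half;    // gating half
    for (int i = 0; i < half; ++i) {
        dst[i] = act(a[i]) * b[i];
    }
}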

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 91 additions & 2 deletions
@@ -436,6 +436,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];
 
+    vk_pipeline pipeline_geglu[2];
+    vk_pipeline pipeline_reglu[2];
+    vk_pipeline pipeline_swiglu[2];
+
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
@@ -2751,6 +2755,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(sigmoid)
 #undef CREATE_UNARY
 
+#define CREATE_GLU(name) \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+
+    CREATE_GLU(geglu)
+    CREATE_GLU(reglu)
+    CREATE_GLU(swiglu)
+#undef CREATE_GLU
+
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
@@ -6455,6 +6468,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 break;
         }
         return nullptr;
+    case GGML_OP_GLU:
+        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
+            (src0->type != dst->type)) {
+            return nullptr;
+        }
+
+        switch (ggml_get_glu_op(dst)) {
+            case GGML_GLU_OP_GEGLU:
+                return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_REGLU:
+                return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_SWIGLU:
+                return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+            default:
+                break;
+        }
+        return nullptr;
     case GGML_OP_DIAG_MASK_INF:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_diag_mask_inf_f32;
@@ -6831,6 +6862,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_GLU:
         {
             const uint32_t nr = ggml_nrows(src0);
             if (nr > 262144) {
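Note on the dispatch above: row-wise ops such as GLU launch one workgroup per row, and once the row count exceeds 262144 (= 512 × 512) it has to spill into the y and z dimensions of the grid; the shaders added below reconstruct the flat row index from the 3D workgroup ID. A small helper showing that reconstruction (illustrative, not part of the diff):

// How the GLU shaders recover their row from gl_WorkGroupID (262144 == 512 * 512).
static uint32_t glu_row_index(uint32_t wg_x, uint32_t wg_y, uint32_t wg_z) {
    return wg_z * 262144u + wg_y * 512u + wg_x;
}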
@@ -7547,6 +7579,14 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
 
+static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
+
+    const uint32_t swapped = (uint32_t)dst->op_params[1];
+
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     int32_t * op_params = (int32_t *)dst->op_params;
     ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
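ggml_vk_glu reuses the generic push constants: KX carries the full input row width src0->ne[0] (the shaders compute the output width as KX / 2) and KY carries the swapped flag from op_params[1]; the two float parameters stay unused. Assumed field layout, matching how the shaders read p.KX and p.KY (the struct definition itself is not shown in this diff):

// Assumed layout of vk_op_push_constants as used by the GLU path.
struct vk_op_push_constants {
    uint32_t KX;      // full input row width, src0->ne[0]
    uint32_t KY;      // swapped flag (0 or 1) from dst->op_params[1]
    float    param1;  // unused for GLU
    float    param2;  // unused for GLU
};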
@@ -8758,6 +8798,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
                 return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                break;
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
@@ -8850,6 +8900,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_UNARY:
+    case GGML_OP_GLU:
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
@@ -8987,6 +9038,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
                 return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun);
+                break;
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_DIAG_MASK_INF:
         ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
 
@@ -9112,8 +9174,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     if (!ok) {
         if (node->op == GGML_OP_UNARY) {
             std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
-        }
-        else {
+        } else if (node->op == GGML_OP_GLU) {
+            std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
+        } else {
             std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
         }
     }
@@ -9192,6 +9255,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
                 return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(tensor)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                buf = tensor->buffer;
+                break;
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_FLASH_ATTN_EXT:
@@ -9976,6 +10050,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(op)) {
+            case GGML_GLU_OP_GEGLU:
+            case GGML_GLU_OP_REGLU:
+            case GGML_GLU_OP_SWIGLU:
+                return ggml_is_contiguous(op->src[0]) &&
+                       (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                       (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                       (op->src[0]->type == op->type);
+            default:
+                return false;
+        }
+        break;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
         {
@@ -10706,6 +10793,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");
         }
+    } else if (tensor->op == GGML_OP_GLU) {
+        tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
             tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);

ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+void main() {
+    const float GELU_COEF_A = 0.044715f;
+    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint col = gl_LocalInvocationID.x;
+
+    const uint offset = p.KX / 2;
+
+    const bool swapped = p.KY > 0;
+
+    if (!swapped) {
+        for (uint i = col; i < offset; i += BLOCK_SIZE) {
+            const uint idx = row * p.KX + i;
+
+            const float xi = float(data_a[idx]);
+            const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
+            data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx + offset]));
+        }
+    } else {
+        for (uint i = col; i < offset; i += BLOCK_SIZE) {
+            const uint idx = row * p.KX + i;
+
+            const float xi = float(data_a[idx + offset]);
+            const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
+            data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx]));
+        }
+    }
+}
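
A note on the formula in this shader: it is the tanh-approximation GELU gate, with the tanh folded into an exponential form via tanh(v) = 1 - 2/(exp(2v) + 1), so 0.5*x*(2 - 2/(exp(2v) + 1)) equals 0.5*x*(1 + tanh(v)). A tiny stand-alone check of that equivalence (hypothetical helpers, not part of the commit):

// Both forms compute the tanh-approximation GELU and agree to float precision.
#include <cmath>

static float gelu_tanh_form(float x) {
    const float v = 0.79788456f * x * (1.0f + 0.044715f * x * x);
    return 0.5f * x * (1.0f + std::tanh(v));
}

static float gelu_shader_form(float x) {
    const float v = 0.79788456f * x * (1.0f + 0.044715f * x * x);
    return 0.5f * x * (2.0f - 2.0f / (std::exp(2.0f * v) + 1.0f));
}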

ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint col = gl_LocalInvocationID.x;
+
+    const uint offset = p.KX / 2;
+
+    const bool swapped = p.KY > 0;
+
+    if (!swapped) {
+        for (uint i = col; i < offset; i += BLOCK_SIZE) {
+            const uint idx = row * p.KX + i;
+
+            data_d[row * offset + i] = D_TYPE(max(float(data_a[idx]), 0.0f) * float(data_a[idx + offset]));
+        }
+    } else {
+        for (uint i = col; i < offset; i += BLOCK_SIZE) {
+            const uint idx = row * p.KX + i;
+
+            data_d[row * offset + i] = D_TYPE(max(float(data_a[idx + offset]), 0.0f) * float(data_a[idx]));
+        }
+    }
+}

ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint col = gl_LocalInvocationID.x;
+
+    const uint offset = p.KX / 2;
+
+    const bool swapped = p.KY > 0;
+
+    if (!swapped) {
+        for (uint i = col; i < offset; i += BLOCK_SIZE) {
+            const uint idx = row * p.KX + i;
+
+            const float xi = float(data_a[idx]);
+            data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx + offset]));
+        }
+    } else {
+        for (uint i = col; i < offset; i += BLOCK_SIZE) {
+            const uint idx = row * p.KX + i;
+
+            const float xi = float(data_a[idx + offset]);
+            data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx]));
+        }
+    }
+}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 7 additions & 0 deletions
@@ -585,6 +585,13 @@ void process_shaders() {
     string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
     string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
+    string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+    string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+
     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 