Skip to content

Commit e759a9e

Browse files
foldlJudd
authored andcommitted
add OP sigmoid (ggml-org#12056)
Co-authored-by: Judd <[email protected]>
1 parent 28d058a commit e759a9e

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ struct vk_device_struct {
249249
vk_pipeline pipeline_relu_f32;
250250
vk_pipeline pipeline_leaky_relu_f32;
251251
vk_pipeline pipeline_tanh_f32;
252+
vk_pipeline pipeline_sigmoid_f32;
252253
vk_pipeline pipeline_diag_mask_inf_f32;
253254
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
254255
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -2189,6 +2190,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21892190
ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21902191
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21912192
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2193+
ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21922194

21932195
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
21942196

@@ -5342,6 +5344,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53425344
return ctx->device->pipeline_tanh_f32;
53435345
}
53445346
break;
5347+
case GGML_UNARY_OP_SIGMOID:
5348+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5349+
return ctx->device->pipeline_sigmoid_f32;
5350+
}
5351+
break;
53455352
default:
53465353
break;
53475354
}
@@ -7335,6 +7342,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73357342
case GGML_UNARY_OP_GELU_QUICK:
73367343
case GGML_UNARY_OP_RELU:
73377344
case GGML_UNARY_OP_TANH:
7345+
case GGML_UNARY_OP_SIGMOID:
73387346
break;
73397347
default:
73407348
return false;
@@ -7551,6 +7559,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
75517559
case GGML_UNARY_OP_GELU_QUICK:
75527560
case GGML_UNARY_OP_RELU:
75537561
case GGML_UNARY_OP_TANH:
7562+
case GGML_UNARY_OP_SIGMOID:
75547563
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
75557564
break;
75567565
default:
@@ -7738,6 +7747,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
77387747
case GGML_UNARY_OP_GELU_QUICK:
77397748
case GGML_UNARY_OP_RELU:
77407749
case GGML_UNARY_OP_TANH:
7750+
case GGML_UNARY_OP_SIGMOID:
77417751
buf = tensor->buffer;
77427752
break;
77437753
default:
@@ -8439,6 +8449,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
84398449
case GGML_UNARY_OP_SILU:
84408450
case GGML_UNARY_OP_RELU:
84418451
case GGML_UNARY_OP_TANH:
8452+
case GGML_UNARY_OP_SIGMOID:
84428453
return ggml_is_contiguous(op->src[0]);
84438454
default:
84448455
return false;
@@ -9105,6 +9116,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
91059116
case GGML_UNARY_OP_TANH:
91069117
tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
91079118
break;
9119+
case GGML_UNARY_OP_SIGMOID:
9120+
tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
9121+
break;
91089122
default:
91099123
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
91109124
GGML_ABORT("fatal error");
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#version 450
2+
3+
#include "generic_head.comp"
4+
#include "types.comp"
5+
6+
#extension GL_EXT_control_flow_attributes : enable
7+
8+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12+
13+
void main() {
14+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
15+
16+
if (i >= p.KX) {
17+
return;
18+
}
19+
data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i])));
20+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ void process_shaders() {
482482
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
483483
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
484484
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
485+
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
485486

486487
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
487488

0 commit comments

Comments
 (0)