@@ -249,6 +249,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_relu_f32;
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_tanh_f32;
+    vk_pipeline pipeline_sigmoid_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -2189,6 +2190,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);

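For reference, the new sigmoid_f32 pipeline is registered exactly like the other element-wise unary pipelines: one shader invocation per element, a 512-wide workgroup, and the shared vk_op_push_constants layout. A minimal CPU sketch of the per-element result the compute shader is expected to reproduce (the helper below is hypothetical and not part of the patch):

#include <cmath>
#include <cstddef>

// Hypothetical reference helper: sigma(x) = 1 / (1 + e^-x), applied element-wise.
// This is the result the sigmoid_f32 compute shader should match on the GPU, and
// it is what the ggml_sigmoid reference path in ggml_vk_check_results_0 (further
// down in this diff) compares against.
static void sigmoid_f32_reference(const float * src, float * dst, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        dst[i] = 1.0f / (1.0f + std::exp(-src[i]));
    }
}
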
@@ -5342,6 +5344,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_tanh_f32;
             }
             break;
+        case GGML_UNARY_OP_SIGMOID:
+            if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+                return ctx->device->pipeline_sigmoid_f32;
+            }
+            break;
         default:
             break;
         }
@@ -7335,6 +7342,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             break;
         default:
             return false;
@@ -7551,6 +7559,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
             break;
         default:
@@ -7738,6 +7747,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             buf = tensor->buffer;
             break;
         default:
@@ -8439,6 +8449,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
+        case GGML_UNARY_OP_SIGMOID:
             return ggml_is_contiguous(op->src[0]);
         default:
             return false;
@@ -9105,6 +9116,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         case GGML_UNARY_OP_TANH:
             tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_SIGMOID:
+            tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
+            break;
         default:
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");
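
The hunks above only extend the Vulkan backend's dispatch tables; from the API side the new path is reached through ggml_sigmoid, the same call the check-results clone uses. A minimal sketch of exercising it end to end, assuming the standard ggml backend interface from around the time of this change (header names and exact signatures may differ between revisions):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-vulkan.h"

int main() {
    ggml_backend_t backend = ggml_backend_vk_init(0); // first Vulkan device

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // tensor data is allocated in the backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_input(x);
    struct ggml_tensor * y = ggml_sigmoid(ctx, x); // GGML_UNARY_OP_SIGMOID

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(galloc, gf);

    const float in[4] = { -2.0f, -1.0f, 0.0f, 1.0f };
    ggml_backend_tensor_set(x, in, 0, sizeof(in));

    ggml_backend_graph_compute(backend, gf); // dispatches pipeline_sigmoid_f32

    float out[4];
    ggml_backend_tensor_get(y, out, 0, sizeof(out));

    ggml_gallocr_free(galloc);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}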