@@ -213,6 +213,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
+    vk_pipeline pipeline_pool2d_f32;
 
     std::unordered_map<std::string, vk_pipeline_ref> pipelines;
     std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -403,6 +404,17 @@ struct vk_op_timestep_embedding_push_constants {
     uint32_t max_period;
 };
 
+struct vk_op_pool2d_push_constants {
+    uint32_t IW; uint32_t IH;
+    uint32_t OW; uint32_t OH;
+    uint32_t OC;
+    uint32_t pelements;
+    uint32_t op;
+    int32_t k0; int32_t k1;
+    int32_t s0; int32_t s1;
+    int32_t p0; int32_t p1;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
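The new push-constant struct has to match the `push_constant` block declared in the pool2d compute shader field-for-field. All thirteen members are 4-byte scalars, so the block packs to 52 bytes, comfortably under Vulkan's guaranteed 128-byte push-constant minimum. A minimal compile-time sanity check, not part of the patch, just to make the layout assumption explicit:

```cpp
// Hypothetical check (not in the patch): the shader-side push_constant block
// must mirror vk_op_pool2d_push_constants exactly. 13 tightly packed 4-byte
// scalars -> 52 bytes, no padding since every member has 4-byte alignment.
static_assert(sizeof(vk_op_pool2d_push_constants) == 13 * sizeof(uint32_t),
              "pool2d push constants must stay tightly packed to match the shader");
```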
@@ -1803,6 +1815,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
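For readers unfamiliar with the helper: the pipeline is compiled from the `pool2d_f32` SPIR-V blob with two descriptor bindings (source and destination buffer), the push-constant range declared above, and a 512-wide one-dimensional workgroup. The same call again, annotated; the argument roles are inferred from the neighboring `timestep_embedding_f32` call site, so treat the comments as assumptions rather than documented API:

```cpp
ggml_vk_create_pipeline(device,
    device->pipeline_pool2d_f32,          // handle on vk_device_struct to fill in
    "pool2d_f32",                         // debug name
    pool2d_f32_len, pool2d_f32_data,      // SPIR-V blob generated from the pool2d shader
    "main",                               // entry point
    2,                                    // descriptor count: src buffer + dst buffer
    sizeof(vk_op_pool2d_push_constants),  // push-constant range declared earlier
    {512, 1, 1},                          // workgroup size the dispatch is divided by
    {},                                   // no specialization constants
    1);                                   // align
```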
@@ -4234,6 +4248,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_timestep_embedding_f32;
         }
         return nullptr;
+    case GGML_OP_POOL_2D:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_pool2d_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -4464,6 +4483,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         uint32_t half_ceil = (dim + 1) / 2;
         elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
     } break;
+    case GGML_OP_POOL_2D:
+        {
+            const uint32_t N = dst->ne[3];
+            const uint32_t OC = dst->ne[2];
+            const uint32_t OH = dst->ne[1];
+            const uint32_t OW = dst->ne[0];
+            elements = { N * OC * OH * OW, 1, 1 };
+        } break;
     case GGML_OP_ADD:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
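The dispatch is one-dimensional: every shader invocation computes exactly one output element, so `elements[0]` carries the full `N * OC * OH * OW` count. A sketch of the workgroup math this implies, assuming the dispatch helper ceil-divides `elements` by the `{512, 1, 1}` workgroup size chosen at pipeline creation and the shader bounds-checks against `pelements`:

```cpp
// Illustrative only: mapping the flat element count onto 1-D workgroups.
// Example: N=1, OC=32, OH=OW=64 -> 131072 elements -> 256 groups of 512.
constexpr uint32_t WG_SIZE  = 512;                                // matches {512, 1, 1} above
const uint32_t     elements = 1 * 32 * 64 * 64;                   // N * OC * OH * OW
const uint32_t     groups_x = (elements + WG_SIZE - 1) / WG_SIZE; // tail guarded by pelements in the shader
```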
@@ -4914,6 +4941,34 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
     }, dryrun);
 }
 
+static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
+    const int32_t k1 = dst->op_params[1];
+    const int32_t k0 = dst->op_params[2];
+    const int32_t s1 = dst->op_params[3];
+    const int32_t s0 = dst->op_params[4];
+    const int32_t p1 = dst->op_params[5];
+    const int32_t p0 = dst->op_params[6];
+
+    const uint32_t IH = src0->ne[1];
+    const uint32_t IW = src0->ne[0];
+
+    const uint32_t N = dst->ne[3];
+
+    const uint32_t OC = dst->ne[2];
+    const uint32_t OH = dst->ne[1];
+    const uint32_t OW = dst->ne[0];
+
+    const uint32_t parallel_elements = N * OC * OH * OW;
+
+    ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
+        IW, IH, OW, OH, OC,
+        parallel_elements,
+        op,
+        k0, k1, s0, s1, p0, p1,
+    }, dryrun);
+}
+
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const float * op_params = (const float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
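For reference, `ggml_pool_2d()` in ggml.c packs its parameters with the `*0` components along `ne[0]` (width) and the `*1` components along `ne[1]` (height), as shown below. Note that the host code above reads the kernel/stride/padding pairs into swapped names (`k1` from `op_params[1]`, which is ggml's width kernel), so slot 0 of the push constants carries the height component; this is only correct as long as the pool2d shader interprets its `k0`/`s0`/`p0` constants as the height axis, which is assumed here to match.

```cpp
// Parameter packing on the ggml side (from ggml_pool_2d in ggml.c):
//   op_params[0] = op              (GGML_OP_POOL_MAX or GGML_OP_POOL_AVG)
//   op_params[1] = k0, op_params[2] = k1   kernel  (width, height)
//   op_params[3] = s0, op_params[4] = s1   stride  (width, height)
//   op_params[5] = p0, op_params[6] = p1   padding (width, height)
int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
ggml_set_op_params(result, params, sizeof(params));
```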
@@ -5792,6 +5847,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_POOL_2D:
     case GGML_OP_LEAKY_RELU:
         break;
     default:
@@ -5927,6 +5983,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_TIMESTEP_EMBEDDING:
         ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_POOL_2D:
+        ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_LEAKY_RELU:
         ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -6018,6 +6078,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
+    case GGML_OP_POOL_2D:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
         buf = tensor->buffer;
@@ -6821,6 +6882,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SUM_ROWS:
         case GGML_OP_IM2COL:
         case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_POOL_2D:
         case GGML_OP_LEAKY_RELU:
             return true;
         default:
@@ -7334,6 +7396,16 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t dim = tensor->op_params[0];
         const int32_t max_period = tensor->op_params[1];
         tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
+    } else if (tensor->op == GGML_OP_POOL_2D) {
+        enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
+        const int32_t k0 = tensor->op_params[1];
+        const int32_t k1 = tensor->op_params[2];
+        const int32_t s0 = tensor->op_params[3];
+        const int32_t s1 = tensor->op_params[4];
+        const int32_t p0 = tensor->op_params[5];
+        const int32_t p1 = tensor->op_params[6];
+
+        tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
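With the check-results branch in place, the new path can be exercised through the public API like any other op. A hypothetical smoke test (tensor sizes chosen arbitrarily; the comparison against the CPU `ggml_pool_2d` clone happens in `ggml_vk_check_results_0` above):

```cpp
// Hypothetical usage: 2x2 max pool, stride 2, no padding, on an
// f32 tensor of shape [W, H, C, N] = [8, 8, 4, 1].
struct ggml_tensor * a   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 8, 4, 1);
struct ggml_tensor * out = ggml_pool_2d(ctx, a, GGML_OP_POOL_MAX,
                                        /*k0*/ 2, /*k1*/ 2,
                                        /*s0*/ 2, /*s1*/ 2,
                                        /*p0*/ 0, /*p1*/ 0);
// out->ne = [4, 4, 4, 1]; on a Vulkan device this now dispatches pool2d_f32.
```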