@@ -254,6 +254,7 @@ struct vk_device_struct {
254
254
vk_pipeline pipeline_argsort_f32;
255
255
vk_pipeline pipeline_sum_rows_f32;
256
256
vk_pipeline pipeline_argmax_f32;
257
+ vk_pipeline pipeline_count_equal_i32;
257
258
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
258
259
vk_pipeline pipeline_timestep_embedding_f32;
259
260
vk_pipeline pipeline_pool2d_f32;
@@ -2157,6 +2158,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2157
2158
2158
2159
ggml_vk_create_pipeline (device, device->pipeline_sum_rows_f32 , " sum_rows_f32" , sum_rows_f32_len, sum_rows_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
2159
2160
2161
+ ggml_vk_create_pipeline (device, device->pipeline_count_equal_i32 , " count_equal_i32" , count_equal_i32_len, count_equal_i32_data, " main" , 3 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 );
2162
+
2160
2163
ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2161
2164
if (device->float_controls_rte_fp16 ) {
2162
2165
ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
@@ -5298,6 +5301,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5298
5301
return ctx->device ->pipeline_argmax_f32 ;
5299
5302
}
5300
5303
return nullptr ;
5304
+ case GGML_OP_COUNT_EQUAL:
5305
+ if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
5306
+ return ctx->device ->pipeline_count_equal_i32 ;
5307
+ }
5308
+ return nullptr ;
5301
5309
case GGML_OP_IM2COL:
5302
5310
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5303
5311
return ctx->device ->pipeline_im2col_f32 ;
@@ -6187,6 +6195,11 @@ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, co
6187
6195
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_ARGMAX, { (uint32_t )src0->ne [0 ], 0 , 0 .0f , 0 .0f }, dryrun);
6188
6196
}
6189
6197
6198
+ static void ggml_vk_count_equal (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
6199
+ ggml_backend_tensor_memset (dst, 0 , 0 , ggml_nbytes (dst));
6200
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_COUNT_EQUAL, { (uint32_t )ggml_nelements (src0), 0 , 0 .0f , 0 .0f }, dryrun);
6201
+ }
6202
+
6190
6203
static void ggml_vk_im2col (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
6191
6204
const int32_t s0 = dst->op_params [0 ];
6192
6205
const int32_t s1 = dst->op_params [1 ];
@@ -7080,6 +7093,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7080
7093
case GGML_OP_SUM:
7081
7094
case GGML_OP_SUM_ROWS:
7082
7095
case GGML_OP_ARGMAX:
7096
+ case GGML_OP_COUNT_EQUAL:
7083
7097
case GGML_OP_IM2COL:
7084
7098
case GGML_OP_TIMESTEP_EMBEDDING:
7085
7099
case GGML_OP_POOL_2D:
@@ -7134,6 +7148,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7134
7148
case GGML_OP_SUM:
7135
7149
case GGML_OP_SUM_ROWS:
7136
7150
case GGML_OP_ARGMAX:
7151
+ case GGML_OP_COUNT_EQUAL:
7137
7152
case GGML_OP_IM2COL:
7138
7153
case GGML_OP_TIMESTEP_EMBEDDING:
7139
7154
case GGML_OP_POOL_2D:
@@ -7269,6 +7284,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7269
7284
case GGML_OP_ARGMAX:
7270
7285
ggml_vk_argmax (ctx, compute_ctx, src0, node, dryrun);
7271
7286
7287
+ break ;
7288
+ case GGML_OP_COUNT_EQUAL:
7289
+ ggml_vk_count_equal (ctx, compute_ctx, src0, src1, node, dryrun);
7290
+
7272
7291
break ;
7273
7292
case GGML_OP_IM2COL:
7274
7293
ggml_vk_im2col (ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7383,6 +7402,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7383
7402
case GGML_OP_SUM:
7384
7403
case GGML_OP_SUM_ROWS:
7385
7404
case GGML_OP_ARGMAX:
7405
+ case GGML_OP_COUNT_EQUAL:
7386
7406
case GGML_OP_IM2COL:
7387
7407
case GGML_OP_TIMESTEP_EMBEDDING:
7388
7408
case GGML_OP_POOL_2D:
@@ -8320,6 +8340,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8320
8340
case GGML_OP_SUM:
8321
8341
case GGML_OP_SUM_ROWS:
8322
8342
case GGML_OP_ARGMAX:
8343
+ case GGML_OP_COUNT_EQUAL:
8323
8344
case GGML_OP_IM2COL:
8324
8345
case GGML_OP_TIMESTEP_EMBEDDING:
8325
8346
case GGML_OP_POOL_2D:
@@ -8898,6 +8919,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8898
8919
tensor_clone = ggml_sum_rows (ggml_ctx, src0_clone);
8899
8920
} else if (tensor->op == GGML_OP_ARGMAX) {
8900
8921
tensor_clone = ggml_argmax (ggml_ctx, src0_clone);
8922
+ } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
8923
+ tensor_clone = ggml_count_equal (ggml_ctx, src0_clone, src1_clone);
8901
8924
} else if (tensor->op == GGML_OP_IM2COL) {
8902
8925
const int32_t s0 = tensor->op_params [0 ];
8903
8926
const int32_t s1 = tensor->op_params [1 ];
@@ -9017,6 +9040,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
9017
9040
} else if (tensor->type == GGML_TYPE_I32) {
9018
9041
correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3 ] + i2*comp_nb[2 ] + i1*comp_nb[1 ] + i0*comp_nb[0 ]);
9019
9042
result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb [3 ] + i2*tensor->nb [2 ] + i1*tensor->nb [1 ] + i0*tensor->nb [0 ]);
9043
+ } else if (tensor->type == GGML_TYPE_I64) {
9044
+ correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3 ] + i2*comp_nb[2 ] + i1*comp_nb[1 ] + i0*comp_nb[0 ]);
9045
+ result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb [3 ] + i2*tensor->nb [2 ] + i1*tensor->nb [1 ] + i0*tensor->nb [0 ]);
9020
9046
} else {
9021
9047
std::cerr << " Results check not implemented for type " << ggml_type_name (tensor->type ) << std::endl;
9022
9048
}
0 commit comments