@@ -252,6 +252,7 @@ struct vk_device_struct {
252
252
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
253
253
vk_pipeline pipeline_argsort_f32;
254
254
vk_pipeline pipeline_sum_rows_f32;
255
+ vk_pipeline pipeline_argmax_f32;
255
256
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
256
257
vk_pipeline pipeline_timestep_embedding_f32;
257
258
vk_pipeline pipeline_pool2d_f32;
@@ -2149,6 +2150,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2149
2150
2150
2151
ggml_vk_create_pipeline (device, device->pipeline_argsort_f32 , " argsort_f32" , argsort_f32_len, argsort_f32_data, " main" , 2 , sizeof (vk_op_argsort_push_constants), {1024 , 1 , 1 }, {}, 1 );
2151
2152
2153
+ ggml_vk_create_pipeline (device, device->pipeline_argmax_f32 , " argmax_f32" , argmax_f32_len, argmax_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
2154
+
2152
2155
ggml_vk_create_pipeline (device, device->pipeline_sum_rows_f32 , " sum_rows_f32" , sum_rows_f32_len, sum_rows_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
2153
2156
2154
2157
ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
@@ -5282,6 +5285,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5282
5285
return ctx->device ->pipeline_sum_rows_f32 ;
5283
5286
}
5284
5287
return nullptr ;
5288
+ case GGML_OP_ARGMAX:
5289
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
5290
+ return ctx->device ->pipeline_argmax_f32 ;
5291
+ }
5292
+ return nullptr ;
5285
5293
case GGML_OP_IM2COL:
5286
5294
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5287
5295
return ctx->device ->pipeline_im2col_f32 ;
@@ -5545,6 +5553,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5545
5553
case GGML_OP_RMS_NORM:
5546
5554
case GGML_OP_SOFT_MAX:
5547
5555
case GGML_OP_SUM_ROWS:
5556
+ case GGML_OP_ARGMAX:
5548
5557
{
5549
5558
const uint32_t nr = ggml_nrows (src0);
5550
5559
if (nr > 262144 ) {
@@ -6149,6 +6158,10 @@ static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
6149
6158
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_SUM_ROWS, { (uint32_t )src0->ne [0 ], 0 , 0 .0f , 0 .0f }, dryrun);
6150
6159
}
6151
6160
6161
+ static void ggml_vk_argmax (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
6162
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_ARGMAX, { (uint32_t )src0->ne [0 ], 0 , 0 .0f , 0 .0f }, dryrun);
6163
+ }
6164
+
6152
6165
static void ggml_vk_im2col (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
6153
6166
const int32_t s0 = dst->op_params [0 ];
6154
6167
const int32_t s1 = dst->op_params [1 ];
@@ -7040,6 +7053,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7040
7053
case GGML_OP_ARGSORT:
7041
7054
case GGML_OP_SUM:
7042
7055
case GGML_OP_SUM_ROWS:
7056
+ case GGML_OP_ARGMAX:
7043
7057
case GGML_OP_IM2COL:
7044
7058
case GGML_OP_TIMESTEP_EMBEDDING:
7045
7059
case GGML_OP_POOL_2D:
@@ -7092,6 +7106,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7092
7106
case GGML_OP_ARGSORT:
7093
7107
case GGML_OP_SUM:
7094
7108
case GGML_OP_SUM_ROWS:
7109
+ case GGML_OP_ARGMAX:
7095
7110
case GGML_OP_IM2COL:
7096
7111
case GGML_OP_TIMESTEP_EMBEDDING:
7097
7112
case GGML_OP_POOL_2D:
@@ -7219,6 +7234,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7219
7234
case GGML_OP_SUM_ROWS:
7220
7235
ggml_vk_sum_rows (ctx, compute_ctx, src0, node, dryrun);
7221
7236
7237
+ break ;
7238
+ case GGML_OP_ARGMAX:
7239
+ ggml_vk_argmax (ctx, compute_ctx, src0, node, dryrun);
7240
+
7222
7241
break ;
7223
7242
case GGML_OP_IM2COL:
7224
7243
ggml_vk_im2col (ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7331,6 +7350,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7331
7350
case GGML_OP_ARGSORT:
7332
7351
case GGML_OP_SUM:
7333
7352
case GGML_OP_SUM_ROWS:
7353
+ case GGML_OP_ARGMAX:
7334
7354
case GGML_OP_IM2COL:
7335
7355
case GGML_OP_TIMESTEP_EMBEDDING:
7336
7356
case GGML_OP_POOL_2D:
@@ -8266,6 +8286,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8266
8286
case GGML_OP_ARGSORT:
8267
8287
case GGML_OP_SUM:
8268
8288
case GGML_OP_SUM_ROWS:
8289
+ case GGML_OP_ARGMAX:
8269
8290
case GGML_OP_IM2COL:
8270
8291
case GGML_OP_TIMESTEP_EMBEDDING:
8271
8292
case GGML_OP_POOL_2D:
@@ -8840,6 +8861,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8840
8861
tensor_clone = ggml_sum (ggml_ctx, src0_clone);
8841
8862
} else if (tensor->op == GGML_OP_SUM_ROWS) {
8842
8863
tensor_clone = ggml_sum_rows (ggml_ctx, src0_clone);
8864
+ } else if (tensor->op == GGML_OP_ARGMAX) {
8865
+ tensor_clone = ggml_argmax (ggml_ctx, src0_clone);
8843
8866
} else if (tensor->op == GGML_OP_IM2COL) {
8844
8867
const int32_t s0 = tensor->op_params [0 ];
8845
8868
const int32_t s1 = tensor->op_params [1 ];
0 commit comments