@@ -5276,6 +5276,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_argsort_f32;
         }
         return nullptr;
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_sum_rows_f32;
@@ -5554,6 +5555,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { nr, 1, 1 };
         }
     } break;
+    case GGML_OP_SUM:
+        // We use GGML_OP_SUM_ROWS with 1 row.
+        elements = { 1, 1, 1 };
+        break;
     case GGML_OP_GROUP_NORM:
         {
             const uint32_t num_groups = dst->op_params[0];
@@ -6136,6 +6141,10 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }, dryrun);
 }
 
+static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
 }
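Note on the approach: ggml_vk_sum adds no dedicated shader. It reuses the sum_rows pipeline, passing ggml_nelements(src0) instead of src0->ne[0] as the row-length push constant, while the dispatch above uses a single workgroup (elements = { 1, 1, 1 }), so the entire tensor is reduced as one contiguous row. A minimal CPU sketch of the identity being relied on; sum_rows_f32 here is an illustrative stand-in for the shader, not code from this file:

#include <stddef.h>
#include <stdio.h>

/* Reference row-sum: one output per row of length k (stand-in for the
   sum_rows shader; the name is hypothetical). */
static void sum_rows_f32(const float * src, float * dst, size_t n_rows, size_t k) {
    for (size_t r = 0; r < n_rows; ++r) {
        float acc = 0.0f;
        for (size_t i = 0; i < k; ++i) {
            acc += src[r * k + i];
        }
        dst[r] = acc;
    }
}

int main(void) {
    float x[6] = { 1, 2, 3, 4, 5, 6 };
    float rows_out[2];
    float sum_out;

    sum_rows_f32(x, rows_out, 2, 3); /* GGML_OP_SUM_ROWS on two rows of 3 -> {6, 15} */
    sum_rows_f32(x, &sum_out, 1, 6); /* GGML_OP_SUM: one "row" of ggml_nelements values -> 21 */

    printf("%.0f %.0f | %.0f\n", rows_out[0], rows_out[1], sum_out);
    return 0;
}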
@@ -7029,6 +7038,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7080,6 +7090,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7200,6 +7211,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ARGSORT:
         ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SUM:
+        ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_SUM_ROWS:
         ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
@@ -7314,6 +7329,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_TRANSPOSE:
    case GGML_OP_NONE:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8248,6 +8264,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ARGSORT:
+    case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8819,6 +8836,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_ARGSORT) {
         tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
+    } else if (tensor->op == GGML_OP_SUM) {
+        tensor_clone = ggml_sum(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {
        tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_IM2COL) {