Skip to content

Commit abf4c2e

Browse files
vulkan: support GGML_OP_SUM
1 parent 5c1d8a9 commit abf4c2e

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5276,6 +5276,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52765276
return ctx->device->pipeline_argsort_f32;
52775277
}
52785278
return nullptr;
5279+
case GGML_OP_SUM:
52795280
case GGML_OP_SUM_ROWS:
52805281
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
52815282
return ctx->device->pipeline_sum_rows_f32;
@@ -5554,6 +5555,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
55545555
elements = { nr, 1, 1 };
55555556
}
55565557
} break;
5558+
case GGML_OP_SUM:
5559+
// We use GGML_OP_SUM_ROWS with 1 row.
5560+
elements = { 1, 1, 1 };
5561+
break;
55575562
case GGML_OP_GROUP_NORM:
55585563
{
55595564
const uint32_t num_groups = dst->op_params[0];
@@ -6136,6 +6141,10 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
61366141
}, dryrun);
61376142
}
61386143

6144+
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6145+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6146+
}
6147+
61396148
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
61406149
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
61416150
}
@@ -7029,6 +7038,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
70297038
case GGML_OP_MUL_MAT:
70307039
case GGML_OP_MUL_MAT_ID:
70317040
case GGML_OP_ARGSORT:
7041+
case GGML_OP_SUM:
70327042
case GGML_OP_SUM_ROWS:
70337043
case GGML_OP_IM2COL:
70347044
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7080,6 +7090,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
70807090
case GGML_OP_SOFT_MAX:
70817091
case GGML_OP_ROPE:
70827092
case GGML_OP_ARGSORT:
7093+
case GGML_OP_SUM:
70837094
case GGML_OP_SUM_ROWS:
70847095
case GGML_OP_IM2COL:
70857096
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7200,6 +7211,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72007211
case GGML_OP_ARGSORT:
72017212
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
72027213

7214+
break;
7215+
case GGML_OP_SUM:
7216+
ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
7217+
72037218
break;
72047219
case GGML_OP_SUM_ROWS:
72057220
ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
@@ -7314,6 +7329,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
73147329
case GGML_OP_TRANSPOSE:
73157330
case GGML_OP_NONE:
73167331
case GGML_OP_ARGSORT:
7332+
case GGML_OP_SUM:
73177333
case GGML_OP_SUM_ROWS:
73187334
case GGML_OP_IM2COL:
73197335
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8248,6 +8264,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
82488264
case GGML_OP_DIAG_MASK_INF:
82498265
case GGML_OP_SOFT_MAX:
82508266
case GGML_OP_ARGSORT:
8267+
case GGML_OP_SUM:
82518268
case GGML_OP_SUM_ROWS:
82528269
case GGML_OP_IM2COL:
82538270
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8819,6 +8836,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
88198836
tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
88208837
} else if (tensor->op == GGML_OP_ARGSORT) {
88218838
tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
8839+
} else if (tensor->op == GGML_OP_SUM) {
8840+
tensor_clone = ggml_sum(ggml_ctx, src0_clone);
88228841
} else if (tensor->op == GGML_OP_SUM_ROWS) {
88238842
tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
88248843
} else if (tensor->op == GGML_OP_IM2COL) {

0 commit comments

Comments
 (0)