
Commit 095f8d1

vulkan: implement GGML_OP_COUNT_EQUAL
1 parent 148f586 commit 095f8d1
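GGML_OP_COUNT_EQUAL compares two tensors elementwise and writes the number of matching elements into a one-element I64 tensor; this commit adds a Vulkan compute path for the I32 case. As a reference for the op's semantics, here is a minimal host-side sketch using the public ggml API on the CPU backend (not part of this commit; pool size, element values, and thread count are arbitrary choices):

#include "ggml.h"

#include <cstdint>
#include <cstdio>

int main() {
    // Small CPU context; the 16 MiB pool size is an arbitrary choice for this sketch.
    struct ggml_init_params params = { 16 * 1024 * 1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    // Two I32 vectors to compare elementwise.
    const int n = 8;
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
    for (int i = 0; i < n; ++i) {
        ((int32_t *) a->data)[i] = i;
        ((int32_t *) b->data)[i] = (i % 2 == 0) ? i : -1; // even positions match
    }

    // out is a one-element I64 tensor holding the number of equal elements.
    struct ggml_tensor * out = ggml_count_equal(ctx, a, b);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    printf("equal elements: %lld\n", (long long) ((int64_t *) out->data)[0]);
    ggml_free(ctx);
    return 0;
}

On the Vulkan backend, the changes below make the same graph dispatch a new count_equal_i32 pipeline instead of falling back to another backend.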

File tree

4 files changed: +61 −2 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp (new)
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 26 additions & 0 deletions
@@ -254,6 +254,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_argmax_f32;
+    vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
@@ -2157,6 +2158,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
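Note that the new pipeline is created with 3 descriptor parameters (src0, src1, dst) rather than the 2 used by single-input ops like sum_rows, and with a workgroup denominator of {512, 1, 1}, matching the CHUNK_SIZE constant in the shader below; the actual local size is supplied through specialization constant 0 as the device subgroup size.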
@@ -5298,6 +5301,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_argmax_f32;
         }
         return nullptr;
+    case GGML_OP_COUNT_EQUAL:
+        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
+            return ctx->device->pipeline_count_equal_i32;
+        }
+        return nullptr;
     case GGML_OP_IM2COL:
         if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_im2col_f32;
@@ -6187,6 +6195,11 @@ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, co
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
 }
 
+static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_backend_tensor_memset(dst, 0, 0, ggml_nbytes(dst));
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const int32_t s0 = dst->op_params[0];
     const int32_t s1 = dst->op_params[1];
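The ggml_backend_tensor_memset call zeroes the 8-byte I64 destination before dispatch: the shader accumulates partial counts from every workgroup into data_d[0] via atomicAdd, so the counter must start at zero. The first push constant (KX) carries ggml_nelements(src0) for the shader's bounds check.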
@@ -7080,6 +7093,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -7134,6 +7148,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -7269,6 +7284,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ARGMAX:
         ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_COUNT_EQUAL:
+        ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_IM2COL:
         ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7383,6 +7402,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -8320,6 +8340,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
        case GGML_OP_IM2COL:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_POOL_2D:
@@ -8898,6 +8919,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_ARGMAX) {
         tensor_clone = ggml_argmax(ggml_ctx, src0_clone);
+    } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
+        tensor_clone = ggml_count_equal(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_IM2COL) {
         const int32_t s0 = tensor->op_params[0];
         const int32_t s1 = tensor->op_params[1];
@@ -9017,6 +9040,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
            } else if (tensor->type == GGML_TYPE_I32) {
                correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
                result  = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
+            } else if (tensor->type == GGML_TYPE_I64) {
+                correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
+                result  = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
            } else {
                std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
            }
ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.comp"
+#include "generic_head.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+const uint CHUNK_SIZE = 512;
+
+void main() {
+    const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
+    const uint col = gl_LocalInvocationID.x;
+
+    uint count = 0;
+    [[unroll]]
+    for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
+        const uint idx = base + i + col;
+        if (idx >= p.KX) {
+            break;
+        }
+        count += uint(data_a[idx] == data_b[idx]);
+    }
+
+    atomicAdd(data_d[0], D_TYPE(count));
+}
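Each workgroup scans a fixed 512-element chunk: invocations stride through the chunk by the workgroup size, tally matches in a private register, and flush once with atomicAdd, so global atomic traffic is one add per invocation rather than one per element. The scalar semantics are equivalent to this plain C++ reference (a hypothetical helper for illustration, not part of the commit):

#include <cstdint>
#include <cstddef>

// Reference: count the positions where a[i] == b[i], as the shader does in parallel.
int64_t count_equal_ref(const int32_t * a, const int32_t * b, size_t n) {
    int64_t count = 0;
    for (size_t i = 0; i < n; ++i) {
        count += (a[i] == b[i]);
    }
    return count;
}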

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -488,6 +488,7 @@ void process_shaders() {
 
     string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
 
     string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));

tests/test-backend-ops.cpp

Lines changed: 3 additions & 2 deletions
@@ -1254,7 +1254,7 @@ struct test_count_equal : public test_case {
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(b, "b");
 
-        ggml_tensor * b_argmax = ggml_argmax(ctx, a);
+        ggml_tensor * b_argmax = ggml_argmax(ctx, b);
         ggml_set_name(b_argmax, "b_argmax");
 
         ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
@@ -3861,7 +3861,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    test_cases.emplace_back(new test_count_equal());
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
 
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));