
Commit 148f586

vulkan: implement GGML_OP_SUB
1 parent deb15e3 commit 148f586

4 files changed: +68 -1 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 35 additions & 0 deletions
@@ -221,6 +221,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
     vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+    vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
     vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
     vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -2100,6 +2101,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -5126,6 +5129,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SUB:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
+        }
+        return nullptr;
     case GGML_OP_MUL:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
@@ -5330,6 +5338,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_CPY:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -5614,6 +5623,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         elements = { N * OC * OH * OW, 1, 1};
     } break;
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
     case GGML_OP_SCALE:
@@ -5745,6 +5755,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -7029,6 +7054,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
     case GGML_OP_ACC:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7083,6 +7109,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7139,6 +7166,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ADD:
         ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_SUB:
+        ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_MUL:
         ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7323,6 +7354,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_ADD:
     case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -8271,6 +8303,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_ACC:
+        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_CONCAT:
@@ -8762,6 +8795,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
    } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
        tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
+    } else if (tensor->op == GGML_OP_SUB) {
+        tensor_clone = ggml_sub(ggml_ctx, src0_clone, src1_clone);
    } else if (tensor->op == GGML_OP_MUL) {
        tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
    } else if (tensor->op == GGML_OP_DIV) {
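
Taken together, the ggml-vulkan.cpp hunks wire GGML_OP_SUB through the same path the backend already uses for ADD, MUL and DIV: the pipelines are registered in ggml_vk_load_shaders, ggml_vk_op_get_pipeline selects the *_norepeat variant when ggml_are_same_shape(src0, src1) holds (no broadcast handling needed), and ggml_vk_sub packs the same vk_op_binary_push_constants layout as ggml_vk_add. The sketch below is not part of the commit; it is a minimal host-side example, assuming the standard ggml backend API, of a graph that would reach the new sub_f32 pipeline:

// Hedged sketch: build d = a - b on the Vulkan backend via the public ggml API.
// Data upload/download and error handling are omitted for brevity.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-vulkan.h"

int main() {
    ggml_backend_t backend = ggml_backend_vk_init(0); // Vulkan device 0

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data lives in a backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * d = ggml_sub(ctx, a, b); // same shapes -> sub_f32_norepeat

    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, d);
    ggml_backend_graph_compute(backend, gf); // dispatches GGML_OP_SUB

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}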
ggml/src/ggml-vulkan/vulkan-shaders/sub.comp

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    uint idx = get_idx();
+
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
+
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
+}
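
The shader follows the pattern of add.comp: each workgroup of 256 threads covers a 512-element tile (hence the num_threads * num_iter == 512 comment, matching the {512, 1, 1} wg_denoms passed to ggml_vk_create_pipeline), with each thread handling two elements 256 apart. Below is a hedged C++ mirror of that coverage; it takes the workgroup's tile start as a given `base`, since get_idx() is defined in generic_binary_head.comp and not shown in this diff:

// Illustrative only: element coverage of one sub.comp workgroup.
// `base` stands in for the value get_idx() yields for thread 0 of the
// workgroup; per the shader comment, tiles advance in steps of 512 per group.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t num_threads = 256;  // local_size_x in sub.comp
    const uint32_t num_iter    = 2;    // num_threads * num_iter == 512
    const uint32_t base        = 0;    // tile start for this workgroup
    const uint32_t ne          = 1000; // p.ne: total elements, bounds-checked

    uint32_t covered = 0;
    for (uint32_t t = 0; t < num_threads; ++t) {
        uint32_t idx = base + t;
        for (uint32_t i = 0; i < num_iter; ++i) {
            if (idx < ne) {
                ++covered; // shader computes data_d[...] = data_a[...] - data_b[...]
            }
            idx += num_threads; // same stride as the shader's `idx += num_threads`
        }
    }
    printf("covered %u elements of tile [%u, %u)\n", covered, base, base + 512);
    return 0;
}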

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
@@ -434,6 +434,8 @@ void process_shaders() {
     string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
 
+    string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
     string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});

tests/test-backend-ops.cpp

Lines changed: 2 additions & 1 deletion
@@ -1511,6 +1511,7 @@ struct test_cont : public test_case {
 };
 
 // GGML_OP_ADD
+// GGML_OP_SUB
 // GGML_OP_MUL
 // GGML_OP_DIV
 struct test_bin_bcast : public test_case {
@@ -3938,7 +3939,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
 
     auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
-        for (auto op : {ggml_add, ggml_mul, ggml_div}) {
+        for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
             test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
         }
     };
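
With ggml_sub added to the op list, every broadcast pattern that add_test_bin_bcast generates now also exercises subtraction and compares the Vulkan result against the CPU reference. Conceptually, each instance builds something like the following; this is a sketch under the assumption (suggested by the test's name and parameters) that src0 is src1's shape ne repeated nr times per dimension:

// Sketch of what one test_bin_bcast instance evaluates (illustrative only):
// `a` has shape ne*nr, `b` has shape ne, and both backends must agree on op(a, b).
#include <array>
#include <cstdint>
#include "ggml.h"

typedef ggml_tensor * (*bin_op)(ggml_context *, ggml_tensor *, ggml_tensor *);

ggml_tensor * build_bin_bcast(ggml_context * ctx, bin_op op, ggml_type type,
                              std::array<int64_t, 4> ne, std::array<int, 4> nr) {
    ggml_tensor * a = ggml_new_tensor_4d(ctx, type,
            ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
    ggml_tensor * b = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
    return op(ctx, a, b); // op is now one of ggml_add, ggml_sub, ggml_mul, ggml_div
}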
