Skip to content

Commit 4908031

Browse files
vulkan: implement GGML_OP_REPEAT_BACK
1 parent 6f58276 commit 4908031

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ struct vk_device_struct {
233233
vk_pipeline pipeline_cos_f32;
234234
vk_pipeline pipeline_clamp_f32;
235235
vk_pipeline pipeline_pad_f32;
236-
vk_pipeline pipeline_repeat_f32;
236+
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
237237
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
238238
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
239239
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -2175,6 +2175,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21752175
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
21762176

21772177
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2178+
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
21782179

21792180
ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21802181
ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -5267,6 +5268,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52675268
return ctx->device->pipeline_repeat_f32;
52685269
}
52695270
return nullptr;
5271+
case GGML_OP_REPEAT_BACK:
5272+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5273+
return ctx->device->pipeline_repeat_back_f32;
5274+
}
5275+
return nullptr;
52705276
case GGML_OP_CPY:
52715277
case GGML_OP_CONT:
52725278
case GGML_OP_DUP:
@@ -5447,6 +5453,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
54475453
case GGML_OP_CLAMP:
54485454
case GGML_OP_PAD:
54495455
case GGML_OP_REPEAT:
5456+
case GGML_OP_REPEAT_BACK:
54505457
case GGML_OP_ROPE:
54515458
return true;
54525459
default:
@@ -5732,6 +5739,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
57325739
case GGML_OP_CLAMP:
57335740
case GGML_OP_PAD:
57345741
case GGML_OP_REPEAT:
5742+
case GGML_OP_REPEAT_BACK:
57355743
case GGML_OP_CPY:
57365744
case GGML_OP_CONCAT:
57375745
case GGML_OP_UPSCALE:
@@ -6265,6 +6273,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
62656273
}, dryrun);
62666274
}
62676275

6276+
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6277+
const uint32_t src0_type_size = ggml_type_size(src0->type);
6278+
const uint32_t dst_type_size = ggml_type_size(dst->type);
6279+
6280+
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
6281+
(uint32_t)ggml_nelements(dst),
6282+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
6283+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
6284+
0,
6285+
0.0f, 0.0f,
6286+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
6287+
}, dryrun);
6288+
}
6289+
62686290
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
62696291
const uint32_t src0_type_size = ggml_type_size(src0->type);
62706292
const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -7268,6 +7290,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72687290
}
72697291
break;
72707292
case GGML_OP_REPEAT:
7293+
case GGML_OP_REPEAT_BACK:
72717294
case GGML_OP_GET_ROWS:
72727295
case GGML_OP_ADD:
72737296
case GGML_OP_ACC:
@@ -7325,6 +7348,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73257348
} else {
73267349
switch (node->op) {
73277350
case GGML_OP_REPEAT:
7351+
case GGML_OP_REPEAT_BACK:
73287352
case GGML_OP_ACC:
73297353
case GGML_OP_GET_ROWS:
73307354
case GGML_OP_ADD:
@@ -7374,6 +7398,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73747398
case GGML_OP_REPEAT:
73757399
ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
73767400

7401+
break;
7402+
case GGML_OP_REPEAT_BACK:
7403+
ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun);
7404+
73777405
break;
73787406
case GGML_OP_ACC:
73797407
ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7619,6 +7647,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
76197647
case GGML_OP_RWKV_WKV6:
76207648
case GGML_OP_LEAKY_RELU:
76217649
case GGML_OP_REPEAT:
7650+
case GGML_OP_REPEAT_BACK:
76227651
case GGML_OP_OPT_STEP_ADAMW:
76237652
buf = tensor->buffer;
76247653

@@ -8517,6 +8546,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
85178546
} break;
85188547
case GGML_OP_REPEAT:
85198548
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
8549+
case GGML_OP_REPEAT_BACK:
8550+
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
85208551
case GGML_OP_ROPE:
85218552
case GGML_OP_NONE:
85228553
case GGML_OP_RESHAPE:
@@ -8922,6 +8953,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
89228953
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
89238954
} else if (tensor->op == GGML_OP_REPEAT) {
89248955
tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
8956+
} else if (tensor->op == GGML_OP_REPEAT_BACK) {
8957+
tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
89258958
} else if (tensor->op == GGML_OP_ADD) {
89268959
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
89278960
} else if (tensor->op == GGML_OP_ACC) {
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_unary_head.comp"
5+
6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
8+
void main() {
9+
const uint idx = get_idx();
10+
11+
if (idx >= p.ne) {
12+
return;
13+
}
14+
15+
// Destination multi-index (inlined dst_idx)
16+
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
17+
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
18+
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
19+
const uint i12_offset = i12*p.ne11*p.ne10;
20+
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
21+
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
22+
const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
23+
24+
// Accumulate from sources
25+
A_TYPE acc = A_TYPE(0);
26+
for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
27+
for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
28+
for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
29+
for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
30+
acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
31+
}
32+
}
33+
}
34+
}
35+
36+
data_d[get_doffset() + d_idx] = D_TYPE(acc);
37+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ void process_shaders() {
454454
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
455455

456456
string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
457+
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
457458

458459
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
459460

0 commit comments

Comments
 (0)