Skip to content

Commit bc34976

Browse files
vulkan: implement GGML_OP_REPEAT_BACK
1 parent e6a2c06 commit bc34976

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ struct vk_device_struct {
232232
vk_pipeline pipeline_cos_f32;
233233
vk_pipeline pipeline_clamp_f32;
234234
vk_pipeline pipeline_pad_f32;
235-
vk_pipeline pipeline_repeat_f32;
235+
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
236236
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
237237
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
238238
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -2127,6 +2127,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21272127
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
21282128

21292129
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2130+
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
21302131

21312132
ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
21322133
ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -5201,6 +5202,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52015202
return ctx->device->pipeline_repeat_f32;
52025203
}
52035204
return nullptr;
5205+
case GGML_OP_REPEAT_BACK:
5206+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5207+
return ctx->device->pipeline_repeat_back_f32;
5208+
}
5209+
return nullptr;
52045210
case GGML_OP_CPY:
52055211
case GGML_OP_CONT:
52065212
case GGML_OP_DUP:
@@ -5365,6 +5371,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
53655371
case GGML_OP_CLAMP:
53665372
case GGML_OP_PAD:
53675373
case GGML_OP_REPEAT:
5374+
case GGML_OP_REPEAT_BACK:
53685375
return true;
53695376
default:
53705377
return false;
@@ -5649,6 +5656,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
56495656
case GGML_OP_CLAMP:
56505657
case GGML_OP_PAD:
56515658
case GGML_OP_REPEAT:
5659+
case GGML_OP_REPEAT_BACK:
56525660
case GGML_OP_CPY:
56535661
case GGML_OP_CONCAT:
56545662
case GGML_OP_UPSCALE:
@@ -6182,6 +6190,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
61826190
}, dryrun);
61836191
}
61846192

6193+
static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6194+
const uint32_t src0_type_size = ggml_type_size(src0->type);
6195+
const uint32_t dst_type_size = ggml_type_size(dst->type);
6196+
6197+
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, {
6198+
(uint32_t)ggml_nelements(dst),
6199+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
6200+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
6201+
0,
6202+
0.0f, 0.0f,
6203+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
6204+
}, dryrun);
6205+
}
6206+
61856207
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
61866208
const uint32_t src0_type_size = ggml_type_size(src0->type);
61876209
const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -7177,6 +7199,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71777199
}
71787200
break;
71797201
case GGML_OP_REPEAT:
7202+
case GGML_OP_REPEAT_BACK:
71807203
case GGML_OP_GET_ROWS:
71817204
case GGML_OP_ADD:
71827205
case GGML_OP_ACC:
@@ -7234,6 +7257,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72347257
} else {
72357258
switch (node->op) {
72367259
case GGML_OP_REPEAT:
7260+
case GGML_OP_REPEAT_BACK:
72377261
case GGML_OP_ACC:
72387262
case GGML_OP_GET_ROWS:
72397263
case GGML_OP_ADD:
@@ -7283,6 +7307,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72837307
case GGML_OP_REPEAT:
72847308
ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
72857309

7310+
break;
7311+
case GGML_OP_REPEAT_BACK:
7312+
ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun);
7313+
72867314
break;
72877315
case GGML_OP_ACC:
72887316
ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7528,6 +7556,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
75287556
case GGML_OP_RWKV_WKV6:
75297557
case GGML_OP_LEAKY_RELU:
75307558
case GGML_OP_REPEAT:
7559+
case GGML_OP_REPEAT_BACK:
75317560
case GGML_OP_OPT_STEP_ADAMW:
75327561
buf = tensor->buffer;
75337562

@@ -8420,6 +8449,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
84208449
} break;
84218450
case GGML_OP_REPEAT:
84228451
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
8452+
case GGML_OP_REPEAT_BACK:
8453+
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
84238454
case GGML_OP_ROPE:
84248455
{
84258456
const int mode = ((const int32_t *) op->op_params)[2];
@@ -8830,6 +8861,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
88308861
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
88318862
} else if (tensor->op == GGML_OP_REPEAT) {
88328863
tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
8864+
} else if (tensor->op == GGML_OP_REPEAT_BACK) {
8865+
tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
88338866
} else if (tensor->op == GGML_OP_ADD) {
88348867
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
88358868
} else if (tensor->op == GGML_OP_ACC) {
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#version 450
2+
3+
#include "types.comp"
4+
#include "generic_unary_head.comp"
5+
6+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7+
8+
void main() {
9+
const uint idx = get_idx();
10+
11+
if (idx >= p.ne) {
12+
return;
13+
}
14+
15+
// Destination multi-index (inlined dst_idx)
16+
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
17+
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
18+
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
19+
const uint i12_offset = i12*p.ne11*p.ne10;
20+
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
21+
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
22+
const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
23+
24+
// Accumulate from sources
25+
A_TYPE acc = A_TYPE(0);
26+
for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
27+
for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
28+
for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
29+
for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
30+
acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
31+
}
32+
}
33+
}
34+
}
35+
36+
data_d[get_doffset() + d_idx] = D_TYPE(acc);
37+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ void process_shaders() {
445445
string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
446446

447447
string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
448+
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
448449

449450
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
450451

0 commit comments

Comments
 (0)