@@ -233,7 +233,7 @@ struct vk_device_struct {
233
233
vk_pipeline pipeline_cos_f32;
234
234
vk_pipeline pipeline_clamp_f32;
235
235
vk_pipeline pipeline_pad_f32;
236
- vk_pipeline pipeline_repeat_f32;
236
+ vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32 ;
237
237
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
238
238
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
239
239
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -2175,6 +2175,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2175
2175
ggml_vk_create_pipeline (device, device->pipeline_pad_f32 , " pad_f32" , pad_f32_len, pad_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
2176
2176
2177
2177
ggml_vk_create_pipeline (device, device->pipeline_repeat_f32 , " repeat_f32" , repeat_f32_len, repeat_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
2178
+ ggml_vk_create_pipeline (device, device->pipeline_repeat_back_f32 , " repeat_back_f32" , repeat_back_f32_len, repeat_back_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
2178
2179
2179
2180
ggml_vk_create_pipeline (device, device->pipeline_gelu_f32 , " gelu_f32" , gelu_f32_len, gelu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
2180
2181
ggml_vk_create_pipeline (device, device->pipeline_gelu_quick_f32 , " gelu_quick_f32" , gelu_quick_f32_len, gelu_quick_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
@@ -5267,6 +5268,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5267
5268
return ctx->device ->pipeline_repeat_f32 ;
5268
5269
}
5269
5270
return nullptr ;
5271
+ case GGML_OP_REPEAT_BACK:
5272
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5273
+ return ctx->device ->pipeline_repeat_back_f32 ;
5274
+ }
5275
+ return nullptr ;
5270
5276
case GGML_OP_CPY:
5271
5277
case GGML_OP_CONT:
5272
5278
case GGML_OP_DUP:
@@ -5447,6 +5453,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
5447
5453
case GGML_OP_CLAMP:
5448
5454
case GGML_OP_PAD:
5449
5455
case GGML_OP_REPEAT:
5456
+ case GGML_OP_REPEAT_BACK:
5450
5457
case GGML_OP_ROPE:
5451
5458
return true ;
5452
5459
default :
@@ -5732,6 +5739,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5732
5739
case GGML_OP_CLAMP:
5733
5740
case GGML_OP_PAD:
5734
5741
case GGML_OP_REPEAT:
5742
+ case GGML_OP_REPEAT_BACK:
5735
5743
case GGML_OP_CPY:
5736
5744
case GGML_OP_CONCAT:
5737
5745
case GGML_OP_UPSCALE:
@@ -6265,6 +6273,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
6265
6273
}, dryrun);
6266
6274
}
6267
6275
6276
+ static void ggml_vk_repeat_back (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
6277
+ const uint32_t src0_type_size = ggml_type_size (src0->type );
6278
+ const uint32_t dst_type_size = ggml_type_size (dst->type );
6279
+
6280
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_REPEAT_BACK, {
6281
+ (uint32_t )ggml_nelements (dst),
6282
+ (uint32_t )src0->ne [0 ], (uint32_t )src0->ne [1 ], (uint32_t )src0->ne [2 ], (uint32_t )src0->ne [3 ], (uint32_t )src0->nb [0 ] / src0_type_size, (uint32_t )src0->nb [1 ] / src0_type_size, (uint32_t )src0->nb [2 ] / src0_type_size, (uint32_t )src0->nb [3 ] / src0_type_size,
6283
+ (uint32_t ) dst->ne [0 ], (uint32_t ) dst->ne [1 ], (uint32_t ) dst->ne [2 ], (uint32_t ) dst->ne [3 ], (uint32_t ) dst->nb [0 ] / dst_type_size, (uint32_t ) dst->nb [1 ] / dst_type_size, (uint32_t ) dst->nb [2 ] / dst_type_size, (uint32_t ) dst->nb [3 ] / dst_type_size,
6284
+ 0 ,
6285
+ 0 .0f , 0 .0f ,
6286
+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
6287
+ }, dryrun);
6288
+ }
6289
+
6268
6290
static void ggml_vk_cpy (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
6269
6291
const uint32_t src0_type_size = ggml_type_size (src0->type );
6270
6292
const uint32_t dst_type_size = ggml_type_size (dst->type );
@@ -7268,6 +7290,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7268
7290
}
7269
7291
break ;
7270
7292
case GGML_OP_REPEAT:
7293
+ case GGML_OP_REPEAT_BACK:
7271
7294
case GGML_OP_GET_ROWS:
7272
7295
case GGML_OP_ADD:
7273
7296
case GGML_OP_ACC:
@@ -7325,6 +7348,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7325
7348
} else {
7326
7349
switch (node->op ) {
7327
7350
case GGML_OP_REPEAT:
7351
+ case GGML_OP_REPEAT_BACK:
7328
7352
case GGML_OP_ACC:
7329
7353
case GGML_OP_GET_ROWS:
7330
7354
case GGML_OP_ADD:
@@ -7374,6 +7398,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7374
7398
case GGML_OP_REPEAT:
7375
7399
ggml_vk_repeat (ctx, compute_ctx, src0, node, dryrun);
7376
7400
7401
+ break ;
7402
+ case GGML_OP_REPEAT_BACK:
7403
+ ggml_vk_repeat_back (ctx, compute_ctx, src0, node, dryrun);
7404
+
7377
7405
break ;
7378
7406
case GGML_OP_ACC:
7379
7407
ggml_vk_acc (ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7619,6 +7647,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7619
7647
case GGML_OP_RWKV_WKV6:
7620
7648
case GGML_OP_LEAKY_RELU:
7621
7649
case GGML_OP_REPEAT:
7650
+ case GGML_OP_REPEAT_BACK:
7622
7651
case GGML_OP_OPT_STEP_ADAMW:
7623
7652
buf = tensor->buffer ;
7624
7653
@@ -8517,6 +8546,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8517
8546
} break ;
8518
8547
case GGML_OP_REPEAT:
8519
8548
return ggml_type_size (op->type ) == sizeof (float ) && ggml_type_size (op->src [0 ]->type ) == sizeof (float );
8549
+ case GGML_OP_REPEAT_BACK:
8550
+ return op->type == GGML_TYPE_F32 && op->src [0 ]->type == GGML_TYPE_F32;
8520
8551
case GGML_OP_ROPE:
8521
8552
case GGML_OP_NONE:
8522
8553
case GGML_OP_RESHAPE:
@@ -8922,6 +8953,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8922
8953
tensor_clone = ggml_pad (ggml_ctx, src_clone[0 ], tensor->ne [0 ] - src_clone[0 ]->ne [0 ], tensor->ne [1 ] - src_clone[0 ]->ne [1 ], tensor->ne [2 ] - src_clone[0 ]->ne [2 ], tensor->ne [3 ] - src_clone[0 ]->ne [3 ]);
8923
8954
} else if (tensor->op == GGML_OP_REPEAT) {
8924
8955
tensor_clone = ggml_repeat (ggml_ctx, src_clone[0 ], tensor);
8956
+ } else if (tensor->op == GGML_OP_REPEAT_BACK) {
8957
+ tensor_clone = ggml_repeat_back (ggml_ctx, src_clone[0 ], tensor);
8925
8958
} else if (tensor->op == GGML_OP_ADD) {
8926
8959
tensor_clone = ggml_add (ggml_ctx, src_clone[0 ], src_clone[1 ]);
8927
8960
} else if (tensor->op == GGML_OP_ACC) {
0 commit comments