@@ -232,7 +232,7 @@ struct vk_device_struct {
232
232
vk_pipeline pipeline_cos_f32;
233
233
vk_pipeline pipeline_clamp_f32;
234
234
vk_pipeline pipeline_pad_f32;
235
- vk_pipeline pipeline_repeat_f32;
235
+ vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32 ;
236
236
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
237
237
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
238
238
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -2127,6 +2127,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2127
2127
ggml_vk_create_pipeline (device, device->pipeline_pad_f32 , " pad_f32" , pad_f32_len, pad_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
2128
2128
2129
2129
ggml_vk_create_pipeline (device, device->pipeline_repeat_f32 , " repeat_f32" , repeat_f32_len, repeat_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
2130
+ ggml_vk_create_pipeline (device, device->pipeline_repeat_back_f32 , " repeat_back_f32" , repeat_back_f32_len, repeat_back_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
2130
2131
2131
2132
ggml_vk_create_pipeline (device, device->pipeline_gelu_f32 , " gelu_f32" , gelu_f32_len, gelu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
2132
2133
ggml_vk_create_pipeline (device, device->pipeline_gelu_quick_f32 , " gelu_quick_f32" , gelu_quick_f32_len, gelu_quick_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
@@ -5201,6 +5202,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5201
5202
return ctx->device ->pipeline_repeat_f32 ;
5202
5203
}
5203
5204
return nullptr ;
5205
+ case GGML_OP_REPEAT_BACK:
5206
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5207
+ return ctx->device ->pipeline_repeat_back_f32 ;
5208
+ }
5209
+ return nullptr ;
5204
5210
case GGML_OP_CPY:
5205
5211
case GGML_OP_CONT:
5206
5212
case GGML_OP_DUP:
@@ -5365,6 +5371,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
5365
5371
case GGML_OP_CLAMP:
5366
5372
case GGML_OP_PAD:
5367
5373
case GGML_OP_REPEAT:
5374
+ case GGML_OP_REPEAT_BACK:
5368
5375
return true ;
5369
5376
default :
5370
5377
return false ;
@@ -5649,6 +5656,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5649
5656
case GGML_OP_CLAMP:
5650
5657
case GGML_OP_PAD:
5651
5658
case GGML_OP_REPEAT:
5659
+ case GGML_OP_REPEAT_BACK:
5652
5660
case GGML_OP_CPY:
5653
5661
case GGML_OP_CONCAT:
5654
5662
case GGML_OP_UPSCALE:
@@ -6182,6 +6190,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
6182
6190
}, dryrun);
6183
6191
}
6184
6192
6193
+ static void ggml_vk_repeat_back (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
6194
+ const uint32_t src0_type_size = ggml_type_size (src0->type );
6195
+ const uint32_t dst_type_size = ggml_type_size (dst->type );
6196
+
6197
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_REPEAT_BACK, {
6198
+ (uint32_t )ggml_nelements (dst),
6199
+ (uint32_t )src0->ne [0 ], (uint32_t )src0->ne [1 ], (uint32_t )src0->ne [2 ], (uint32_t )src0->ne [3 ], (uint32_t )src0->nb [0 ] / src0_type_size, (uint32_t )src0->nb [1 ] / src0_type_size, (uint32_t )src0->nb [2 ] / src0_type_size, (uint32_t )src0->nb [3 ] / src0_type_size,
6200
+ (uint32_t ) dst->ne [0 ], (uint32_t ) dst->ne [1 ], (uint32_t ) dst->ne [2 ], (uint32_t ) dst->ne [3 ], (uint32_t ) dst->nb [0 ] / dst_type_size, (uint32_t ) dst->nb [1 ] / dst_type_size, (uint32_t ) dst->nb [2 ] / dst_type_size, (uint32_t ) dst->nb [3 ] / dst_type_size,
6201
+ 0 ,
6202
+ 0 .0f , 0 .0f ,
6203
+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
6204
+ }, dryrun);
6205
+ }
6206
+
6185
6207
static void ggml_vk_cpy (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
6186
6208
const uint32_t src0_type_size = ggml_type_size (src0->type );
6187
6209
const uint32_t dst_type_size = ggml_type_size (dst->type );
@@ -7177,6 +7199,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7177
7199
}
7178
7200
break ;
7179
7201
case GGML_OP_REPEAT:
7202
+ case GGML_OP_REPEAT_BACK:
7180
7203
case GGML_OP_GET_ROWS:
7181
7204
case GGML_OP_ADD:
7182
7205
case GGML_OP_ACC:
@@ -7234,6 +7257,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7234
7257
} else {
7235
7258
switch (node->op ) {
7236
7259
case GGML_OP_REPEAT:
7260
+ case GGML_OP_REPEAT_BACK:
7237
7261
case GGML_OP_ACC:
7238
7262
case GGML_OP_GET_ROWS:
7239
7263
case GGML_OP_ADD:
@@ -7283,6 +7307,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7283
7307
case GGML_OP_REPEAT:
7284
7308
ggml_vk_repeat (ctx, compute_ctx, src0, node, dryrun);
7285
7309
7310
+ break ;
7311
+ case GGML_OP_REPEAT_BACK:
7312
+ ggml_vk_repeat_back (ctx, compute_ctx, src0, node, dryrun);
7313
+
7286
7314
break ;
7287
7315
case GGML_OP_ACC:
7288
7316
ggml_vk_acc (ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7528,6 +7556,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7528
7556
case GGML_OP_RWKV_WKV6:
7529
7557
case GGML_OP_LEAKY_RELU:
7530
7558
case GGML_OP_REPEAT:
7559
+ case GGML_OP_REPEAT_BACK:
7531
7560
case GGML_OP_OPT_STEP_ADAMW:
7532
7561
buf = tensor->buffer ;
7533
7562
@@ -8420,6 +8449,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8420
8449
} break ;
8421
8450
case GGML_OP_REPEAT:
8422
8451
return ggml_type_size (op->type ) == sizeof (float ) && ggml_type_size (op->src [0 ]->type ) == sizeof (float );
8452
+ case GGML_OP_REPEAT_BACK:
8453
+ return op->type == GGML_TYPE_F32 && op->src [0 ]->type == GGML_TYPE_F32;
8423
8454
case GGML_OP_ROPE:
8424
8455
{
8425
8456
const int mode = ((const int32_t *) op->op_params )[2 ];
@@ -8830,6 +8861,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
8830
8861
tensor_clone = ggml_pad (ggml_ctx, src_clone[0 ], tensor->ne [0 ] - src_clone[0 ]->ne [0 ], tensor->ne [1 ] - src_clone[0 ]->ne [1 ], tensor->ne [2 ] - src_clone[0 ]->ne [2 ], tensor->ne [3 ] - src_clone[0 ]->ne [3 ]);
8831
8862
} else if (tensor->op == GGML_OP_REPEAT) {
8832
8863
tensor_clone = ggml_repeat (ggml_ctx, src_clone[0 ], tensor);
8864
+ } else if (tensor->op == GGML_OP_REPEAT_BACK) {
8865
+ tensor_clone = ggml_repeat_back (ggml_ctx, src_clone[0 ], tensor);
8833
8866
} else if (tensor->op == GGML_OP_ADD) {
8834
8867
tensor_clone = ggml_add (ggml_ctx, src_clone[0 ], src_clone[1 ]);
8835
8868
} else if (tensor->op == GGML_OP_ACC) {
0 commit comments