@@ -180,6 +180,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
     vk_pipeline pipeline_mul_f32;
     vk_pipeline pipeline_div_f32;
@@ -1687,6 +1688,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

@@ -3971,6 +3974,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_get_rows_f32[src0->type];
         }
         return nullptr;
+    case GGML_OP_ACC:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_acc_f32;
+        }
+        return nullptr;
     case GGML_OP_ADD:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_add_f32;
@@ -4463,6 +4471,28 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
     }, dryrun);
 }

+static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+    const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
+
+    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+    int offset = dst->op_params[3] / 4; // offset in bytes
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
+        d_offset,
+        0.0f, 0.0f, offset,
+    }, dryrun);
+}
+
 static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -5621,6 +5651,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_REPEAT:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_ACC:
    case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -5668,6 +5699,10 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_REPEAT:
         ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_ACC:
+        ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_GET_ROWS:
         ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -5808,6 +5843,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *

     switch (tensor->op) {
     case GGML_OP_ADD:
+    case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
@@ -6539,6 +6575,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_CONCAT:
@@ -6995,6 +7032,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_repeat(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_ADD) {
         tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
+    } else if (tensor->op == GGML_OP_ACC) {
+        tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
     } else if (tensor->op == GGML_OP_NORM) {
         tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_GROUP_NORM) {
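For context, a minimal sketch of how the new GGML_OP_ACC path could be exercised against the Vulkan backend through the public ggml API. The tensor shapes, the device index 0, and the buffer/graph setup below are illustrative assumptions and not part of this patch; ggml_acc() takes the view strides nb1/nb2/nb3 and an offset in bytes, which land in dst->op_params and are consumed by ggml_vk_acc() above.

// Sketch: accumulate a 4x4 block (b) into the top-left corner of an 8x8 tensor (a)
// on the Vulkan backend. Assumes the standard ggml backend API.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-vulkan.h"

int main() {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data lives in the backend buffer
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);

    // ggml_acc(a, b, nb1, nb2, nb3, offset): add b into a view of a described by
    // byte strides nb1/nb2/nb3 and a byte offset -- these become dst->op_params[0..3].
    struct ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    ggml_backend_t backend = ggml_backend_vk_init(0); // device 0, illustrative
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    // ... upload data with ggml_backend_tensor_set(a/b, ...), then run the graph,
    // which dispatches the new acc_f32 pipeline for the GGML_OP_ACC node:
    ggml_backend_graph_compute(backend, gf);

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}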