@@ -221,6 +221,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
     vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+    vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
     vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
     vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -2100,6 +2101,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
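Note that both sub_f32 pipelines are created from the same SPIR-V blob (sub_f32_len/sub_f32_data, generated from a sub.comp shader that is not part of this hunk). Only the trailing specialization constant differs: {0} for the broadcast-capable variant, {1} for the norepeat variant, mirroring the existing mul_f32 and div_f32 pairs above and below it.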
@@ -5126,6 +5129,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SUB:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
+        }
+        return nullptr;
     case GGML_OP_MUL:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
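The _norepeat variant is selected when ggml_are_same_shape(src0, src1) holds, i.e. when src1 never has to be repeated across src0. A minimal C++ sketch of the indexing difference this implies (illustration only, not the shader source; sub_elem and its parameters are made up for the example):

#include <cstdint>
#include <cstdio>

// With broadcasting (spec constant {0}), src1 coordinates wrap with a
// modulo so a smaller src1 repeats across src0; the norepeat variant
// ({1}) can index src1 directly because the shapes are identical.
static float sub_elem(const float * a, const float * b, uint32_t i,
                      uint32_t b_ne0, bool norepeat) {
    const uint32_t j = norepeat ? i : i % b_ne0;
    return a[i] - b[j];
}

int main() {
    const float a[4] = {10, 20, 30, 40};
    const float b[2] = {1, 2};
    for (uint32_t i = 0; i < 4; ++i) {
        std::printf("%g ", sub_elem(a, b, i, 2, /*norepeat=*/false)); // 9 18 29 38
    }
    std::printf("\n");
    return 0;
}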
@@ -5330,6 +5338,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_CPY:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -5614,6 +5623,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { N * OC * OH * OW, 1, 1 };
         } break;
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
     case GGML_OP_SCALE:
@@ -5745,6 +5755,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
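For context, a minimal sketch of how a GGML_OP_SUB node reaches this dispatch path from the public ggml API (backend setup and buffer allocation elided; treat the surrounding plumbing as an assumption, the point is only ggml_sub itself):

#include "ggml.h"

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16u * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);

    // d->op == GGML_OP_SUB; on the Vulkan backend ggml_vk_build_graph routes
    // this node to ggml_vk_sub. Since a and b share a shape, the
    // sub_f32_norepeat pipeline would be selected.
    ggml_tensor * d = ggml_sub(ctx, a, b);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, d);
    // ... hand gf to a backend, e.g. ggml_backend_graph_compute(backend, gf);

    ggml_free(ctx);
    return 0;
}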
@@ -7029,6 +7054,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
     case GGML_OP_ACC:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7083,6 +7109,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7139,6 +7166,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ADD:
         ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_SUB:
+        ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_MUL:
         ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7323,6 +7354,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_ADD:
     case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -8271,6 +8303,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_ACC:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_CONCAT:
@@ -8762,6 +8795,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
         tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
+    } else if (tensor->op == GGML_OP_SUB) {
+        tensor_clone = ggml_sub(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_MUL) {
        tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_DIV) {
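This final hunk extends the debug validation path (compiled in with GGML_VULKAN_CHECK_RESULTS), which re-runs each node on the CPU through the regular ggml API, here ggml_sub, and compares the result against the Vulkan backend's output.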