@@ -183,9 +183,10 @@ struct vk_device_struct {
183
183
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
184
184
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
185
185
vk_pipeline pipeline_acc_f32;
186
- vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
187
- vk_pipeline pipeline_mul_f32;
188
- vk_pipeline pipeline_div_f32;
186
+ vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
187
+ vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
188
+ vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
189
+ vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
189
190
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
190
191
vk_pipeline pipeline_upscale_f32;
191
192
vk_pipeline pipeline_scale_f32;
@@ -1759,13 +1760,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
1759
1760
ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_f16 , " cpy_f32_f16" , cpy_f32_f16_len, cpy_f32_f16_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
1760
1761
ggml_vk_create_pipeline (device, device->pipeline_cpy_f16_f16 , " cpy_f16_f16" , cpy_f16_f16_len, cpy_f16_f16_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
1761
1762
1762
- ggml_vk_create_pipeline (device, device->pipeline_add_f32 , " add_f32" , add_f32_len, add_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1763
- ggml_vk_create_pipeline (device, device->pipeline_add_f16_f32_f16 , " add_f16_f32_f16" , add_f16_f32_f16_len, add_f16_f32_f16_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1763
+ ggml_vk_create_pipeline (device, device->pipeline_add_f32 , " add_f32" , add_f32_len, add_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
1764
+ ggml_vk_create_pipeline (device, device->pipeline_add_f32_norepeat , " add_f32_norepeat" , add_f32_len, add_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {1 }, 1 );
1765
+ ggml_vk_create_pipeline (device, device->pipeline_add_f16_f32_f16 , " add_f16_f32_f16" , add_f16_f32_f16_len, add_f16_f32_f16_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
1766
+ ggml_vk_create_pipeline (device, device->pipeline_add_f16_f32_f16_norepeat , " add_f16_f32_f16_norepeat" , add_f16_f32_f16_len, add_f16_f32_f16_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {1 }, 1 );
1764
1767
1765
1768
ggml_vk_create_pipeline (device, device->pipeline_acc_f32 , " acc_f32" , acc_f32_len, acc_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1766
1769
1767
- ggml_vk_create_pipeline (device, device->pipeline_mul_f32 , " mul_f32" , mul_f32_len, mul_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1768
- ggml_vk_create_pipeline (device, device->pipeline_div_f32 , " div_f32" , div_f32_len, div_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1770
+ ggml_vk_create_pipeline (device, device->pipeline_mul_f32 , " mul_f32" , mul_f32_len, mul_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
1771
+ ggml_vk_create_pipeline (device, device->pipeline_mul_f32_norepeat , " mul_f32_norepeat" , mul_f32_len, mul_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {1 }, 1 );
1772
+ ggml_vk_create_pipeline (device, device->pipeline_div_f32 , " div_f32" , div_f32_len, div_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
1773
+ ggml_vk_create_pipeline (device, device->pipeline_div_f32_norepeat , " div_f32_norepeat" , div_f32_len, div_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {1 }, 1 );
1769
1774
1770
1775
ggml_vk_create_pipeline (device, device->pipeline_concat_f32 , " concat_f32" , concat_f32_len, concat_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
1771
1776
ggml_vk_create_pipeline (device, device->pipeline_concat_f16 , " concat_f16" , concat_f16_len, concat_f16_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {}, 1 );
@@ -4078,20 +4083,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
4078
4083
return nullptr ;
4079
4084
case GGML_OP_ADD:
4080
4085
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4081
- return ctx->device ->pipeline_add_f32 ;
4086
+ return ggml_are_same_shape (src0, src1) ? ctx-> device -> pipeline_add_f32_norepeat : ctx->device ->pipeline_add_f32 ;
4082
4087
}
4083
4088
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
4084
- return ctx->device ->pipeline_add_f16_f32_f16 ;
4089
+ return ggml_are_same_shape (src0, src1) ? ctx-> device -> pipeline_add_f16_f32_f16_norepeat : ctx->device ->pipeline_add_f16_f32_f16 ;
4085
4090
}
4086
4091
return nullptr ;
4087
4092
case GGML_OP_MUL:
4088
4093
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4089
- return ctx->device ->pipeline_mul_f32 ;
4094
+ return ggml_are_same_shape (src0, src1) ? ctx-> device -> pipeline_mul_f32_norepeat : ctx->device ->pipeline_mul_f32 ;
4090
4095
}
4091
4096
return nullptr ;
4092
4097
case GGML_OP_DIV:
4093
4098
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4094
- return ctx->device ->pipeline_div_f32 ;
4099
+ return ggml_are_same_shape (src0, src1) ? ctx-> device -> pipeline_div_f32_norepeat : ctx->device ->pipeline_div_f32 ;
4095
4100
}
4096
4101
return nullptr ;
4097
4102
case GGML_OP_CONCAT:
0 commit comments