@@ -3147,7 +3147,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3147
3147
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3148
3148
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3149
3149
3150
- if (mmp == nullptr ) {
3150
+ if (qx_needs_dequant ) {
3151
3151
// Fall back to dequant + f16 mulmat
3152
3152
mmp = ggml_vk_get_mul_mat_mat_pipeline (ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
3153
3153
}
@@ -3630,9 +3630,19 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
3630
3630
3631
3631
static void ggml_vk_mul_mat (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
3632
3632
VK_LOG_DEBUG (" ggml_vk_mul_mat(" << src0 << " , " << src1 << " , " << dst << " )" );
3633
- if (src0->type == GGML_TYPE_F16 && ggml_is_permuted (src0) && ggml_is_permuted (src1) && dst->ne [1 ] == 1 ) {
3633
+ if (src0->type == GGML_TYPE_F16 && ggml_is_permuted (src0) && ggml_is_permuted (src1) && dst->ne [1 ] == 1 &&
3634
+ // detect 0213 permutation, and batch size of 1
3635
+ src0->nb [0 ] <= src0->nb [2 ] &&
3636
+ src0->nb [2 ] <= src0->nb [1 ] &&
3637
+ src0->nb [1 ] <= src0->nb [3 ] &&
3638
+ src1->nb [0 ] <= src1->nb [2 ] &&
3639
+ src1->nb [2 ] <= src1->nb [1 ] &&
3640
+ src1->nb [1 ] <= src1->nb [3 ] &&
3641
+ src0->ne [3 ] == 1 &&
3642
+ src1->ne [3 ] == 1 ) {
3634
3643
ggml_vk_mul_mat_vec_p021_f16_f32 (ctx, subctx, src0, src1, dst, dryrun);
3635
- } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous (src0) && !ggml_is_transposed (src1) && dst->ne [1 ] == 1 ) {
3644
+ } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous (src0) && !ggml_is_transposed (src1) && dst->ne [1 ] == 1 &&
3645
+ !ggml_is_permuted (src0) && !ggml_is_permuted (src1)) {
3636
3646
ggml_vk_mul_mat_vec_nc_f16_f32 (ctx, subctx, src0, src1, dst, dryrun);
3637
3647
} else if (dst->ne [1 ] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized (src0->type ))) {
3638
3648
ggml_vk_mul_mat_vec_q_f16 (ctx, subctx, src0, src1, dst, dryrun);
@@ -3708,7 +3718,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
3708
3718
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3709
3719
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3710
3720
3711
- if (mmp == nullptr ) {
3721
+ if (qx_needs_dequant ) {
3712
3722
GGML_ABORT (" fatal error" );
3713
3723
}
3714
3724
@@ -4470,7 +4480,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
4470
4480
const uint32_t OH = is_2D ? dst->ne [2 ] : 1 ;
4471
4481
const uint32_t OW = dst->ne [1 ];
4472
4482
4473
- const uint32_t batch = src1->ne [3 ];
4483
+ const uint32_t batch = src1->ne [is_2D ? 3 : 2 ];
4474
4484
4475
4485
elements = { OW * KW * KH, OH, batch * IC };
4476
4486
} break ;
@@ -4915,7 +4925,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
4915
4925
const uint32_t OW = dst->ne [1 ];
4916
4926
4917
4927
const uint32_t offset_delta = src1->nb [is_2D ? 2 : 1 ] / 4 ; // nb is byte offset, src is type float32
4918
- const uint32_t batch_offset = src1->nb [3 ] / 4 ; // nb is byte offset, src is type float32
4928
+ const uint32_t batch_offset = src1->nb [is_2D ? 3 : 2 ] / 4 ; // nb is byte offset, src is type float32
4919
4929
4920
4930
const uint32_t pelements = OW * KW * KH;
4921
4931
@@ -6804,6 +6814,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
6804
6814
if (a->ne [3 ] != b->ne [3 ]) {
6805
6815
return false ;
6806
6816
}
6817
+ if (!(ggml_vk_dim01_contiguous (op->src [0 ]) || op->src [0 ]->type == GGML_TYPE_F32 || op->src [0 ]->type == GGML_TYPE_F16) ||
6818
+ !(ggml_vk_dim01_contiguous (op->src [1 ]) || op->src [1 ]->type == GGML_TYPE_F32 || op->src [1 ]->type == GGML_TYPE_F16)) {
6819
+ return false ;
6820
+ }
6821
+
6807
6822
return true ;
6808
6823
} break ;
6809
6824
case GGML_OP_GET_ROWS:
0 commit comments