Skip to content

Commit 160687b

Browse files
authored
vulkan: Fix newly added tests for permuted mul_mat and 1D im2col (#10226)
1 parent 6423c65 commit 160687b

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

ggml/src/ggml-vulkan.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3147,7 +3147,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
31473147
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
31483148
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
31493149

3150-
if (mmp == nullptr) {
3150+
if (qx_needs_dequant) {
31513151
// Fall back to dequant + f16 mulmat
31523152
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
31533153
}
@@ -3630,9 +3630,19 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
36303630

36313631
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
36323632
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
3633-
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
3633+
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
3634+
// detect 0213 permutation, and batch size of 1
3635+
src0->nb[0] <= src0->nb[2] &&
3636+
src0->nb[2] <= src0->nb[1] &&
3637+
src0->nb[1] <= src0->nb[3] &&
3638+
src1->nb[0] <= src1->nb[2] &&
3639+
src1->nb[2] <= src1->nb[1] &&
3640+
src1->nb[1] <= src1->nb[3] &&
3641+
src0->ne[3] == 1 &&
3642+
src1->ne[3] == 1) {
36343643
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
3635-
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) {
3644+
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
3645+
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
36363646
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
36373647
} else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
36383648
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -3708,7 +3718,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
37083718
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
37093719
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
37103720

3711-
if (mmp == nullptr) {
3721+
if (qx_needs_dequant) {
37123722
GGML_ABORT("fatal error");
37133723
}
37143724

@@ -4470,7 +4480,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
44704480
const uint32_t OH = is_2D ? dst->ne[2] : 1;
44714481
const uint32_t OW = dst->ne[1];
44724482

4473-
const uint32_t batch = src1->ne[3];
4483+
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
44744484

44754485
elements = { OW * KW * KH, OH, batch * IC };
44764486
} break;
@@ -4915,7 +4925,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
49154925
const uint32_t OW = dst->ne[1];
49164926

49174927
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
4918-
const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
4928+
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
49194929

49204930
const uint32_t pelements = OW * KW * KH;
49214931

@@ -6804,6 +6814,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
68046814
if (a->ne[3] != b->ne[3]) {
68056815
return false;
68066816
}
6817+
if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
6818+
!(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
6819+
return false;
6820+
}
6821+
68076822
return true;
68086823
} break;
68096824
case GGML_OP_GET_ROWS:

0 commit comments

Comments
 (0)