@@ -1476,26 +1476,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
1476
1476
// spec constants and tile sizes for quant matmul (non-Qi_K)
1477
1477
l_warptile_mmq = { 256 , 128 , 256 , 64 };
1478
1478
m_warptile_mmq = { 256 , 128 , 128 , 64 };
1479
- s_warptile_mmq = { 256 , 128 , 128 , 64 };
1479
+ s_warptile_mmq = { 256 , 32 , 64 , 128 };
1480
1480
l_mmq_wg_denoms = { 128 , 256 , 1 };
1481
1481
m_mmq_wg_denoms = { 128 , 128 , 1 };
1482
- s_mmq_wg_denoms = { 128 , 128 , 1 };
1482
+ s_mmq_wg_denoms = { 32 , 64 , 1 };
1483
1483
1484
1484
// spec constants and tile sizes for quant matmul (Qi_K)
1485
- l_warptile_mmq_k = { 256 , 128 , 512 , 16 };
1486
- m_warptile_mmq_k = { 256 , 128 , 256 , 16 };
1487
- s_warptile_mmq_k = { 256 , 32 , 128 , 64 };
1488
- l_mmq_wg_denoms_k = { 128 , 512 , 1 };
1489
- m_mmq_wg_denoms_k = { 128 , 256 , 1 };
1490
- s_mmq_wg_denoms_k = { 32 , 128 , 1 };
1485
+ l_warptile_mmq_k = { 256 , 64 , 128 , 64 };
1486
+ m_warptile_mmq_k = { 256 , 32 , 64 , 64 };
1487
+ s_warptile_mmq_k = { 256 , 32 , 32 , 128 };
1488
+ l_mmq_wg_denoms_k = { 64 , 128 , 1 };
1489
+ m_mmq_wg_denoms_k = { 32 , 64 , 1 };
1490
+ s_mmq_wg_denoms_k = { 32 , 32 , 1 };
1491
1491
1492
1492
// spec constants and tile sizes for quant matmul_id
1493
- l_warptile_mmqid = { 256 , 128 , 128 , 16 };
1493
+ l_warptile_mmqid = { 256 , 128 , 64 , 16 };
1494
1494
m_warptile_mmqid = { 256 , 128 , 64 , 16 };
1495
- s_warptile_mmqid = { 256 , 64 , 64 , 16 };
1496
- l_mmqid_wg_denoms = { 128 , 128 , 1 };
1495
+ s_warptile_mmqid = { 256 , 128 , 64 , 16 };
1496
+ l_mmqid_wg_denoms = { 128 , 64 , 1 };
1497
1497
m_mmqid_wg_denoms = { 128 , 64 , 1 };
1498
- s_mmqid_wg_denoms = { 64 , 64 , 1 };
1498
+ s_mmqid_wg_denoms = { 128 , 64 , 1 };
1499
1499
1500
1500
l_align = 128 ;
1501
1501
m_align = 64 ;
@@ -3850,10 +3850,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
3850
3850
VK_LOG_DEBUG (" ggml_vk_guess_matmul_pipeline(" << m << " , " << n << " , " << aligned << " , " << ggml_type_name (src0_type) << " )" );
3851
3851
3852
3852
if (ctx->device ->coopmat2 ) {
3853
- if ((ctx->device ->mul_mat_l [src0_type] && (m % mmp->l ->wg_denoms [0 ]) == 0 && (n % mmp->l ->wg_denoms [1 ]) == 0 ) || (!ctx->device ->mul_mat_m [src0_type] && !ctx->device ->mul_mat_s [src0_type])) {
3853
+ // Use large shader when the N dimension is greater than the medium shader's tile size
3854
+ uint32_t crossover_large = mmp->m ->wg_denoms [1 ];
3855
+ if ((ctx->device ->mul_mat_l [src0_type] && (n > crossover_large)) || (!ctx->device ->mul_mat_m [src0_type] && !ctx->device ->mul_mat_s [src0_type])) {
3854
3856
return aligned ? mmp->a_l : mmp->l ;
3855
3857
}
3856
- if ((ctx->device ->mul_mat_m [src0_type] && (m % mmp->m ->wg_denoms [0 ]) == 0 && (n % mmp->m ->wg_denoms [1 ]) == 0 ) || !ctx->device ->mul_mat_s [src0_type]) {
3858
+ // Use medium shader when the N dimension is greater than the small shader's tile size
3859
+ uint32_t crossover_medium = mmp->s ->wg_denoms [1 ];
3860
+ if ((ctx->device ->mul_mat_m [src0_type] && (n > crossover_medium)) || !ctx->device ->mul_mat_s [src0_type]) {
3857
3861
return aligned ? mmp->a_m : mmp->m ;
3858
3862
}
3859
3863
return aligned ? mmp->a_s : mmp->s ;
@@ -3898,13 +3902,17 @@ static void ggml_vk_matmul(
3898
3902
}
3899
3903
3900
3904
static vk_pipeline ggml_vk_guess_matmul_id_pipeline (ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
3901
- VK_LOG_DEBUG (" ggml_vk_guess_matmul_pipeline (" << m << " , " << n << " , " << aligned << " , " << ggml_type_name (src0_type) << " )" );
3905
+ VK_LOG_DEBUG (" ggml_vk_guess_matmul_id_pipeline (" << m << " , " << n << " , " << aligned << " , " << ggml_type_name (src0_type) << " )" );
3902
3906
3903
3907
if (ctx->device ->coopmat2 ) {
3904
- if ((ctx->device ->mul_mat_id_l [src0_type] && (m % mmp->l ->wg_denoms [0 ]) == 0 && (n % mmp->l ->wg_denoms [1 ]) == 0 ) || (!ctx->device ->mul_mat_id_m [src0_type] && !ctx->device ->mul_mat_id_s [src0_type])) {
3908
+ // Use large shader when the N dimension is greater than the medium shader's tile size
3909
+ uint32_t crossover_large = mmp->m ->wg_denoms [1 ];
3910
+ if ((ctx->device ->mul_mat_id_l [src0_type] && (n > crossover_large)) || (!ctx->device ->mul_mat_id_m [src0_type] && !ctx->device ->mul_mat_id_s [src0_type])) {
3905
3911
return aligned ? mmp->a_l : mmp->l ;
3906
3912
}
3907
- if ((ctx->device ->mul_mat_id_m [src0_type] && (m % mmp->m ->wg_denoms [0 ]) == 0 && (n % mmp->m ->wg_denoms [1 ]) == 0 ) || !ctx->device ->mul_mat_id_s [src0_type]) {
3913
+ // Use medium shader when the N dimension is greater than the small shader's tile size
3914
+ uint32_t crossover_medium = mmp->s ->wg_denoms [1 ];
3915
+ if ((ctx->device ->mul_mat_id_m [src0_type] && (n > crossover_medium)) || !ctx->device ->mul_mat_id_s [src0_type]) {
3908
3916
return aligned ? mmp->a_m : mmp->m ;
3909
3917
}
3910
3918
return aligned ? mmp->a_s : mmp->s ;
0 commit comments