Skip to content

Commit 1577cfd

Browse files
committed
vulkan: Adjust coopmat2 tile sizes and selection heuristic
1 parent ba76543 commit 1577cfd

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,26 +1476,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
14761476
// spec constants and tile sizes for quant matmul (non-Qi_K)
14771477
l_warptile_mmq = { 256, 128, 256, 64 };
14781478
m_warptile_mmq = { 256, 128, 128, 64 };
1479-
s_warptile_mmq = { 256, 128, 128, 64 };
1479+
s_warptile_mmq = { 256, 32, 64, 128 };
14801480
l_mmq_wg_denoms = { 128, 256, 1 };
14811481
m_mmq_wg_denoms = { 128, 128, 1 };
1482-
s_mmq_wg_denoms = { 128, 128, 1 };
1482+
s_mmq_wg_denoms = { 32, 64, 1 };
14831483

14841484
// spec constants and tile sizes for quant matmul (Qi_K)
1485-
l_warptile_mmq_k = { 256, 128, 512, 16 };
1486-
m_warptile_mmq_k = { 256, 128, 256, 16 };
1487-
s_warptile_mmq_k = { 256, 32, 128, 64 };
1488-
l_mmq_wg_denoms_k = { 128, 512, 1 };
1489-
m_mmq_wg_denoms_k = { 128, 256, 1 };
1490-
s_mmq_wg_denoms_k = { 32, 128, 1 };
1485+
l_warptile_mmq_k = { 256, 64, 128, 64 };
1486+
m_warptile_mmq_k = { 256, 32, 64, 64 };
1487+
s_warptile_mmq_k = { 256, 32, 32, 128 };
1488+
l_mmq_wg_denoms_k = { 64, 128, 1 };
1489+
m_mmq_wg_denoms_k = { 32, 64, 1 };
1490+
s_mmq_wg_denoms_k = { 32, 32, 1 };
14911491

14921492
// spec constants and tile sizes for quant matmul_id
1493-
l_warptile_mmqid = { 256, 128, 128, 16 };
1493+
l_warptile_mmqid = { 256, 128, 64, 16 };
14941494
m_warptile_mmqid = { 256, 128, 64, 16 };
1495-
s_warptile_mmqid = { 256, 64, 64, 16 };
1496-
l_mmqid_wg_denoms = { 128, 128, 1 };
1495+
s_warptile_mmqid = { 256, 128, 64, 16 };
1496+
l_mmqid_wg_denoms = { 128, 64, 1 };
14971497
m_mmqid_wg_denoms = { 128, 64, 1 };
1498-
s_mmqid_wg_denoms = { 64, 64, 1 };
1498+
s_mmqid_wg_denoms = { 128, 64, 1 };
14991499

15001500
l_align = 128;
15011501
m_align = 64;
@@ -3850,10 +3850,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
38503850
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
38513851

38523852
if (ctx->device->coopmat2) {
3853-
if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
3853+
// Use large shader when the N dimension is greater than the medium shader's tile size
3854+
uint32_t crossover_large = mmp->m->wg_denoms[1];
3855+
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
38543856
return aligned ? mmp->a_l : mmp->l;
38553857
}
3856-
if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
3858+
// Use medium shader when the N dimension is greater than the small shader's tile size
3859+
uint32_t crossover_medium = mmp->s->wg_denoms[1];
3860+
if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
38573861
return aligned ? mmp->a_m : mmp->m;
38583862
}
38593863
return aligned ? mmp->a_s : mmp->s;
@@ -3898,13 +3902,17 @@ static void ggml_vk_matmul(
38983902
}
38993903

39003904
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
3901-
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
3905+
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
39023906

39033907
if (ctx->device->coopmat2) {
3904-
if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
3908+
// Use large shader when the N dimension is greater than the medium shader's tile size
3909+
uint32_t crossover_large = mmp->m->wg_denoms[1];
3910+
if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
39053911
return aligned ? mmp->a_l : mmp->l;
39063912
}
3907-
if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
3913+
// Use medium shader when the N dimension is greater than the small shader's tile size
3914+
uint32_t crossover_medium = mmp->s->wg_denoms[1];
3915+
if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
39083916
return aligned ? mmp->a_m : mmp->m;
39093917
}
39103918
return aligned ? mmp->a_s : mmp->s;

0 commit comments

Comments
 (0)