Skip to content

vulkan: Adjust coopmat2 tile sizes and selection heuristic #12258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1476,26 +1476,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
// spec constants and tile sizes for quant matmul (non-Qi_K)
l_warptile_mmq = { 256, 128, 256, 64 };
m_warptile_mmq = { 256, 128, 128, 64 };
s_warptile_mmq = { 256, 128, 128, 64 };
s_warptile_mmq = { 256, 32, 64, 128 };
l_mmq_wg_denoms = { 128, 256, 1 };
m_mmq_wg_denoms = { 128, 128, 1 };
s_mmq_wg_denoms = { 128, 128, 1 };
s_mmq_wg_denoms = { 32, 64, 1 };

// spec constants and tile sizes for quant matmul (Qi_K)
l_warptile_mmq_k = { 256, 128, 512, 16 };
m_warptile_mmq_k = { 256, 128, 256, 16 };
s_warptile_mmq_k = { 256, 32, 128, 64 };
l_mmq_wg_denoms_k = { 128, 512, 1 };
m_mmq_wg_denoms_k = { 128, 256, 1 };
s_mmq_wg_denoms_k = { 32, 128, 1 };
l_warptile_mmq_k = { 256, 64, 128, 64 };
m_warptile_mmq_k = { 256, 32, 64, 64 };
s_warptile_mmq_k = { 256, 32, 32, 128 };
l_mmq_wg_denoms_k = { 64, 128, 1 };
m_mmq_wg_denoms_k = { 32, 64, 1 };
s_mmq_wg_denoms_k = { 32, 32, 1 };

// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 16 };
l_warptile_mmqid = { 256, 128, 64, 16 };
m_warptile_mmqid = { 256, 128, 64, 16 };
s_warptile_mmqid = { 256, 64, 64, 16 };
l_mmqid_wg_denoms = { 128, 128, 1 };
s_warptile_mmqid = { 256, 128, 64, 16 };
l_mmqid_wg_denoms = { 128, 64, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 64, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };

l_align = 128;
m_align = 64;
Expand Down Expand Up @@ -3850,10 +3850,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");

if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
// Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
// Use medium shader when the N dimension is greater than the small shader's tile size
uint32_t crossover_medium = mmp->s->wg_denoms[1];
if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
Expand Down Expand Up @@ -3898,13 +3902,17 @@ static void ggml_vk_matmul(
}

static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");

if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
// Use large shader when the N dimension is greater than the medium shader's tile size
uint32_t crossover_large = mmp->m->wg_denoms[1];
if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
// Use medium shader when the N dimension is greater than the small shader's tile size
uint32_t crossover_medium = mmp->s->wg_denoms[1];
if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
Expand Down
Loading