@@ -1430,6 +1430,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1430
1430
VK_LOG_DEBUG (" ggml_vk_load_shaders(" << device->name << " )" );
1431
1431
1432
1432
// some shaders have a minimum subgroup size
1433
+ const uint32_t subgroup_size_8 = std::max (device->subgroup_size , 8u );
1433
1434
const uint32_t subgroup_size_16 = std::max (device->subgroup_size , 16u );
1434
1435
const uint32_t subgroup_size_32 = std::max (device->subgroup_size , 32u );
1435
1436
@@ -1492,13 +1493,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
1492
1493
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1 ;
1493
1494
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1 ;
1494
1495
1495
- l_warptile = { 128 , 128 , 128 , 16 , device-> subgroup_size * 2 , 64 , 2 , tm_l, tn_l, tk_l, device-> subgroup_size };
1496
- m_warptile = { 128 , 64 , 64 , 16 , device-> subgroup_size , 32 , 2 , tm_m, tn_m, tk_m, device-> subgroup_size };
1497
- s_warptile = { subgroup_size_16, 32 , 32 , 16 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device-> subgroup_size };
1496
+ l_warptile = { 128 , 128 , 128 , 16 , subgroup_size_8 * 2 , 64 , 2 , tm_l, tn_l, tk_l, subgroup_size_8 };
1497
+ m_warptile = { 128 , 64 , 64 , 16 , subgroup_size_8 , 32 , 2 , tm_m, tn_m, tk_m, subgroup_size_8 };
1498
+ s_warptile = { subgroup_size_16, 32 , 32 , 16 , 32 , 32 , 2 , tm_s, tn_s, tk_s, subgroup_size_8 };
1498
1499
1499
- l_warptile_mmq = { 128 , 128 , 128 , 32 , device-> subgroup_size * 2 , 64 , 2 , tm_l, tn_l, tk_l, device-> subgroup_size };
1500
- m_warptile_mmq = { 128 , 64 , 64 , 32 , device-> subgroup_size , 32 , 2 , tm_m, tn_m, tk_m, device-> subgroup_size };
1501
- s_warptile_mmq = { subgroup_size_32, 32 , 32 , 32 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device-> subgroup_size };
1500
+ l_warptile_mmq = { 128 , 128 , 128 , 32 , subgroup_size_8 * 2 , 64 , 2 , tm_l, tn_l, tk_l, subgroup_size_8 };
1501
+ m_warptile_mmq = { 128 , 64 , 64 , 32 , subgroup_size_8 , 32 , 2 , tm_m, tn_m, tk_m, subgroup_size_8 };
1502
+ s_warptile_mmq = { subgroup_size_32, 32 , 32 , 32 , 32 , 32 , 2 , tm_s, tn_s, tk_s, subgroup_size_8 };
1502
1503
1503
1504
l_mmq_wg_denoms = l_wg_denoms = {128 , 128 , 1 };
1504
1505
m_mmq_wg_denoms = m_wg_denoms = { 64 , 64 , 1 };
0 commit comments