@@ -1371,7 +1371,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1371
1371
// Needs to be kept up to date on shader changes
1372
1372
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1 ;
1373
1373
const uint32_t type_size = device->fp16 ? sizeof (ggml_fp16_t ) : sizeof (float );
1374
- const uint32_t warps = warptile[0 ] / device-> subgroup_size ;
1374
+ const uint32_t warps = warptile[0 ] / warptile[ 10 ] ;
1375
1375
1376
1376
const uint32_t load_bufs = (warptile[1 ] + warptile[2 ]) * (warptile[3 ] + bank_conflict_offset) * type_size;
1377
1377
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof (uint32_t ) : 0 ;
@@ -1385,8 +1385,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
1385
1385
1386
1386
std::cerr << " ggml_vulkan: Compiling shaders" ;
1387
1387
1388
- // some shaders require the subgroup size to be 16 or larger
1388
+ // some shaders have a minimum subgroup size
1389
1389
const uint32_t subgroup_size_16 = std::max (device->subgroup_size , 16u );
1390
+ const uint32_t subgroup_size_32 = std::max (device->subgroup_size , 32u );
1390
1391
1391
1392
// mulmat
1392
1393
std::vector<uint32_t > l_warptile, m_warptile, s_warptile,
@@ -1453,7 +1454,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1453
1454
1454
1455
l_warptile_mmq = { 128 , 128 , 128 , 32 , device->subgroup_size * 2 , 64 , 2 , tm_l, tn_l, tk_l, device->subgroup_size };
1455
1456
m_warptile_mmq = { 128 , 64 , 64 , 32 , device->subgroup_size , 32 , 2 , tm_m, tn_m, tk_m, device->subgroup_size };
1456
- s_warptile_mmq = { subgroup_size_16 , 32 , 32 , 32 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device->subgroup_size };
1457
+ s_warptile_mmq = { subgroup_size_32 , 32 , 32 , 32 , 32 , 32 , 2 , tm_s, tn_s, tk_s, device->subgroup_size };
1457
1458
1458
1459
l_mmq_wg_denoms = l_wg_denoms = {128 , 128 , 1 };
1459
1460
m_mmq_wg_denoms = m_wg_denoms = { 64 , 64 , 1 };
@@ -1872,7 +1873,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1872
1873
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_Q4_K], " mul_mat_vec_q4_k_f32_f32" , mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1873
1874
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_Q5_K], " mul_mat_vec_q5_k_f32_f32" , mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1874
1875
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_Q6_K], " mul_mat_vec_q6_k_f32_f32" , mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1875
- ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_iq4_nl_f32_f32" , mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 *rm, 1 , 1 }, {device-> subgroup_size , 2 *rm}, 1 , true );
1876
+ ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f32_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_iq4_nl_f32_f32" , mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 *rm, 1 , 1 }, {subgroup_size_16 , 2 *rm}, 1 , true );
1876
1877
1877
1878
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_F32 ], " mul_mat_vec_f32_f16_f32" , mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 , 1 , 1 }, {device->subgroup_size , 2 }, 1 );
1878
1879
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_F16 ], " mul_mat_vec_f16_f16_f32" , mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 , 1 , 1 }, {device->subgroup_size , 2 }, 1 );
@@ -1886,7 +1887,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1886
1887
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_Q4_K], " mul_mat_vec_q4_k_f16_f32" , mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1887
1888
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_Q5_K], " mul_mat_vec_q5_k_f16_f32" , mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1888
1889
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_Q6_K], " mul_mat_vec_q6_k_f16_f32" , mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1889
- ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_iq4_nl_f16_f32" , mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 *rm, 1 , 1 }, {device-> subgroup_size , 2 *rm}, 1 , true );
1890
+ ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_f16_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_iq4_nl_f16_f32" , mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, " main" , 3 , sizeof (vk_mat_vec_push_constants), {2 *rm, 1 , 1 }, {subgroup_size_16 , 2 *rm}, 1 , true );
1890
1891
1891
1892
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_F32 ], " mul_mat_vec_id_f32_f32" , mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {2 , 1 , 1 }, {device->subgroup_size , 2 }, 1 );
1892
1893
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_F16 ], " mul_mat_vec_id_f16_f32" , mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {2 , 1 , 1 }, {device->subgroup_size , 2 }, 1 );
@@ -1900,7 +1901,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1900
1901
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_Q4_K], " mul_mat_vec_id_q4_k_f32" , mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1901
1902
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_Q5_K], " mul_mat_vec_id_q5_k_f32" , mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1902
1903
ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_Q6_K], " mul_mat_vec_id_q6_k_f32" , mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {1 , 1 , 1 }, {subgroup_size_16}, 1 , true );
1903
- ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_id_iq4_nl_f32" , mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {2 *rm, 1 , 1 }, {device-> subgroup_size , 2 *rm}, 1 , true );
1904
+ ggml_vk_create_pipeline (device, device->pipeline_dequant_mul_mat_vec_id_f32 [GGML_TYPE_IQ4_NL], " mul_mat_vec_id_iq4_nl_f32" , mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, " main" , 4 , sizeof (vk_mat_vec_id_push_constants), {2 *rm, 1 , 1 }, {subgroup_size_16 , 2 *rm}, 1 , true );
1904
1905
1905
1906
// dequant shaders
1906
1907
ggml_vk_create_pipeline (device, device->pipeline_dequant [GGML_TYPE_F32 ], " f32_to_f16" , dequant_f32_len, dequant_f32_data, " main" , 2 , 5 * sizeof (uint32_t ), {256 * 16 , 1 , 1 }, {}, 1 );
0 commit comments