@@ -1543,11 +1543,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
1543
1543
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1544
1544
}
1545
1545
1546
+ vk::PhysicalDeviceProperties2 props2;
1547
+ device->physical_device .getProperties2 (&props2);
1548
+ std::string device_name = props2.properties .deviceName .data ();
1546
1549
std::vector<std::future<void >> compiles;
1547
1550
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void * spv_data, const std::string &entrypoint,
1548
1551
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms, const std::vector<uint32_t >& specialization_constants,
1549
1552
uint32_t align, bool disable_robustness = false , bool require_full_subgroups = false , uint32_t required_subgroup_size = 0 ) {
1550
1553
1554
+ if (required_subgroup_size == 0 ) {
1555
+ required_subgroup_size = (device_name.find (" RX 5700" ) != std::string::npos) ? 32 : required_subgroup_size;
1556
+ }
1557
+
1551
1558
if (!pipeline) {
1552
1559
pipeline = std::make_shared<vk_pipeline_struct>();
1553
1560
pipeline->name = name;
@@ -1573,6 +1580,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
1573
1580
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
1574
1581
};
1575
1582
1583
+ // Thin wrapper around ggml_vk_create_pipeline: forwards every argument unchanged and passes 64 as required_subgroup_size. Because that value is nonzero, the RX 5700 fallback to 32 (applied only when required_subgroup_size == 0) is not taken.
1584
+ auto const &ggml_vk_create_pipeline_64 = [&](vk_device& device, vk_pipeline& pipeline,
1585
+ const std::string &name, size_t spv_size, const void * spv_data,
1586
+ const std::string &entrypoint, uint32_t parameter_count,
1587
+ uint32_t push_constant_size, std::array<uint32_t , 3 > wg_denoms,
1588
+ const std::vector<uint32_t >& specialization_constants, uint32_t align,
1589
+ bool disable_robustness = false , bool require_full_subgroups = false )
1590
+ {
1591
+ ggml_vk_create_pipeline (device, pipeline, name, spv_size, spv_data, entrypoint,
1592
+ parameter_count, push_constant_size, wg_denoms,
1593
+ specialization_constants, align, disable_robustness,
1594
+ require_full_subgroups, 64 ); // trailing 64 is the required_subgroup_size argument
1595
+ };
1596
+
1576
1597
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
1577
1598
if (device->coopmat2 ) {
1578
1599
@@ -2151,11 +2172,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
2151
2172
2152
2173
ggml_vk_create_pipeline (device, device->pipeline_sum_rows_f32 , " sum_rows_f32" , sum_rows_f32_len, sum_rows_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
2153
2174
2154
- ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2155
- if (device->float_controls_rte_fp16 ) {
2156
- ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2175
+ // Workaround to speed up im2col on RX 5700: create the im2col pipelines with a required subgroup size of 64
2176
+ if (device_name.find (" RX 5700" ) != std::string::npos) {
2177
+ ggml_vk_create_pipeline_64 (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2178
+ if (device->float_controls_rte_fp16 ) {
2179
+ ggml_vk_create_pipeline_64 (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2180
+ } else {
2181
+ ggml_vk_create_pipeline_64 (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_len, im2col_f32_f16_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2182
+ }
2157
2183
} else {
2158
- ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_len, im2col_f32_f16_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2184
+ ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2185
+ if (device->float_controls_rte_fp16 ) {
2186
+ ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2187
+ } else {
2188
+ ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_len, im2col_f32_f16_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
2189
+ }
2159
2190
}
2160
2191
2161
2192
ggml_vk_create_pipeline (device, device->pipeline_timestep_embedding_f32 , " timestep_embedding_f32" , timestep_embedding_f32_len, timestep_embedding_f32_data, " main" , 2 , sizeof (vk_op_timestep_embedding_push_constants), {256 , 1 , 1 }, {}, 1 );
0 commit comments