Skip to content

Commit 35f6369

Browse files
committed
Force subgroup 32 on RX 5700 and subgroup 64 for im2col
1 parent 62733f2 commit 35f6369

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,11 +1543,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
15431543
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
15441544
}
15451545

1546+
vk::PhysicalDeviceProperties2 props2;
1547+
device->physical_device.getProperties2(&props2);
1548+
std::string device_name = props2.properties.deviceName.data();
15461549
std::vector<std::future<void>> compiles;
15471550
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
15481551
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
15491552
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
15501553

1554+
if (required_subgroup_size == 0) {
1555+
required_subgroup_size = (device_name.find("RX 5700") != std::string::npos) ? 32 : required_subgroup_size;
1556+
}
1557+
15511558
if (!pipeline) {
15521559
pipeline = std::make_shared<vk_pipeline_struct>();
15531560
pipeline->name = name;
@@ -1573,6 +1580,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
15731580
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
15741581
};
15751582

1583+
// New lambda for pipelines with subgroup size 64.
1584+
auto const &ggml_vk_create_pipeline_64 = [&](vk_device& device, vk_pipeline& pipeline,
1585+
const std::string &name, size_t spv_size, const void* spv_data,
1586+
const std::string &entrypoint, uint32_t parameter_count,
1587+
uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms,
1588+
const std::vector<uint32_t>& specialization_constants, uint32_t align,
1589+
bool disable_robustness = false, bool require_full_subgroups = false)
1590+
{
1591+
ggml_vk_create_pipeline(device, pipeline, name, spv_size, spv_data, entrypoint,
1592+
parameter_count, push_constant_size, wg_denoms,
1593+
specialization_constants, align, disable_robustness,
1594+
require_full_subgroups, 64);
1595+
};
1596+
15761597
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
15771598
if (device->coopmat2) {
15781599

@@ -2151,11 +2172,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
21512172

21522173
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
21532174

2154-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2155-
if (device->float_controls_rte_fp16) {
2156-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2175+
// Workaround needed to speedup im2col on RX 5700
2176+
if (device_name.find("RX 5700") != std::string::npos) {
2177+
ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2178+
if (device->float_controls_rte_fp16) {
2179+
ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2180+
} else {
2181+
ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2182+
}
21572183
} else {
2158-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2184+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2185+
if (device->float_controls_rte_fp16) {
2186+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2187+
} else {
2188+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2189+
}
21592190
}
21602191

21612192
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);

0 commit comments

Comments
 (0)