Skip to content

Commit 14ea4fa

Browse files
committed
Helper function to set subgroup size
1 parent 293edef commit 14ea4fa

File tree

1 file changed

+36
-28
lines changed

1 file changed

+36
-28
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 36 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1423,6 +1423,36 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
14231423
return supported;
14241424
}
14251425

1426+
// Per-GPU overrides for the subgroup size requested when a pipeline is created.
// Outer key: substring to look for in the Vulkan device name (e.g. "RX 5700").
// Inner map: pipeline name -> required subgroup size (0 means "no override").
static std::unordered_map<std::string, std::unordered_map<std::string, uint32_t>> gpu_pipeline_config = {
    {"RX 5700", {
        {"im2col_f32",     64},
        {"im2col_f32_f16", 64}
    }}
};

// Returns the subgroup size override configured for `pipeline_name` on the GPU
// whose device name is `device_name`, or 0 when nothing is configured.
//
// A GPU entry applies when its key occurs anywhere inside `device_name`.
// If several keys match, the longest (most specific) one is used, with ties
// broken lexicographically, so the result never depends on the unspecified
// iteration order of std::unordered_map. Empty keys are ignored because they
// would match every device.
static uint32_t get_subgroup_size(const std::string &pipeline_name, const std::string &device_name) {
    using gpu_entry = decltype(gpu_pipeline_config)::value_type;

    const gpu_entry * best = nullptr;
    for (const auto & entry : gpu_pipeline_config) {
        const std::string & gpu_key = entry.first;
        if (gpu_key.empty() || device_name.find(gpu_key) == std::string::npos) {
            continue;
        }
        const bool more_specific =
            best == nullptr ||
            gpu_key.size() > best->first.size() ||
            (gpu_key.size() == best->first.size() && gpu_key < best->first);
        if (more_specific) {
            best = &entry;
        }
    }

    if (best == nullptr) {
        return 0;
    }

    // Read the inner map through the entry we already hold: no second hash
    // lookup and no accidental insertion via operator[].
    const auto pipeline_it = best->second.find(pipeline_name);
    return pipeline_it != best->second.end() ? pipeline_it->second : 0;
}
1455+
14261456
static void ggml_vk_load_shaders(vk_device& device) {
14271457
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
14281458

@@ -1546,11 +1576,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
15461576
vk::PhysicalDeviceProperties2 props2;
15471577
device->physical_device.getProperties2(&props2);
15481578
std::string device_name = props2.properties.deviceName.data();
1579+
15491580
std::vector<std::future<void>> compiles;
15501581
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
15511582
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
15521583
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
15531584

1585+
required_subgroup_size = get_subgroup_size(name, device_name);
15541586
if (required_subgroup_size == 0) {
15551587
required_subgroup_size = (device_name.find("RX 5700") != std::string::npos) ? 32 : required_subgroup_size;
15561588
}
@@ -1580,20 +1612,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
15801612
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
15811613
};
15821614

1583-
// New lambda for pipelines with subgroup size 64.
1584-
auto const &ggml_vk_create_pipeline_64 = [&](vk_device& device, vk_pipeline& pipeline,
1585-
const std::string &name, size_t spv_size, const void* spv_data,
1586-
const std::string &entrypoint, uint32_t parameter_count,
1587-
uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms,
1588-
const std::vector<uint32_t>& specialization_constants, uint32_t align,
1589-
bool disable_robustness = false, bool require_full_subgroups = false)
1590-
{
1591-
ggml_vk_create_pipeline(device, pipeline, name, spv_size, spv_data, entrypoint,
1592-
parameter_count, push_constant_size, wg_denoms,
1593-
specialization_constants, align, disable_robustness,
1594-
require_full_subgroups, 64);
1595-
};
1596-
15971615
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
15981616
if (device->coopmat2) {
15991617

@@ -2172,21 +2190,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
21722190

21732191
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
21742192

2175-
// Workaround needed to speedup im2col on RX 5700
2176-
if (device_name.find("RX 5700") != std::string::npos) {
2177-
ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2178-
if (device->float_controls_rte_fp16) {
2179-
ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2180-
} else {
2181-
ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2182-
}
2193+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2194+
if (device->float_controls_rte_fp16) {
2195+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
21832196
} else {
2184-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2185-
if (device->float_controls_rte_fp16) {
2186-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2187-
} else {
2188-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2189-
}
2197+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
21902198
}
21912199

21922200
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);

0 commit comments

Comments
 (0)