Skip to content

Commit 9205de6

Browse files
committed
Helper function to set subgroup size
1 parent 293edef commit 9205de6

File tree

1 file changed

+35
-17
lines changed

1 file changed

+35
-17
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,6 +1423,36 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
14231423
return supported;
14241424
}
14251425

1426+
// Define a configuration map per GPU.
1427+
// Outer key: GPU identifier (e.g. "RX 5700").
1428+
// Inner map: key is pipeline name; value is the subgroup size.
1429+
static std::unordered_map<std::string, std::unordered_map<std::string, uint32_t>> gpu_pipeline_config = {
1430+
{"RX 5700", {
1431+
{"im2col_f32", 64},
1432+
{"im2col_f32_f16", 64}
1433+
}}
1434+
};
1435+
1436+
// Helper function defined at namespace scope.
1437+
static uint32_t get_subgroup_size(const std::string &pipeline_name, const std::string &device_name) {
1438+
std::string foundKey;
1439+
for (const auto &entry : gpu_pipeline_config) {
1440+
if (device_name.find(entry.first) != std::string::npos) {
1441+
foundKey = entry.first;
1442+
break;
1443+
}
1444+
}
1445+
if (!foundKey.empty()) {
1446+
auto &pipelineMap = gpu_pipeline_config[foundKey];
1447+
auto pipIt = pipelineMap.find(pipeline_name);
1448+
if (pipIt != pipelineMap.end() && pipIt->second != 0) {
1449+
return pipIt->second;
1450+
}
1451+
}
1452+
// If not defined, return 0.
1453+
return 0;
1454+
}
1455+
14261456
static void ggml_vk_load_shaders(vk_device& device) {
14271457
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
14281458

@@ -1546,11 +1576,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
15461576
vk::PhysicalDeviceProperties2 props2;
15471577
device->physical_device.getProperties2(&props2);
15481578
std::string device_name = props2.properties.deviceName.data();
1579+
15491580
std::vector<std::future<void>> compiles;
15501581
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
15511582
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
15521583
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
15531584

1585+
required_subgroup_size = get_subgroup_size(name, device_name);
15541586
if (required_subgroup_size == 0) {
15551587
required_subgroup_size = (device_name.find("RX 5700") != std::string::npos) ? 32 : required_subgroup_size;
15561588
}
@@ -1580,20 +1612,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
15801612
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
15811613
};
15821614

1583-
// New lambda for pipelines with subgroup size 64.
1584-
auto const &ggml_vk_create_pipeline_64 = [&](vk_device& device, vk_pipeline& pipeline,
1585-
const std::string &name, size_t spv_size, const void* spv_data,
1586-
const std::string &entrypoint, uint32_t parameter_count,
1587-
uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms,
1588-
const std::vector<uint32_t>& specialization_constants, uint32_t align,
1589-
bool disable_robustness = false, bool require_full_subgroups = false)
1590-
{
1591-
ggml_vk_create_pipeline(device, pipeline, name, spv_size, spv_data, entrypoint,
1592-
parameter_count, push_constant_size, wg_denoms,
1593-
specialization_constants, align, disable_robustness,
1594-
require_full_subgroups, 64);
1595-
};
1596-
15971615
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
15981616
if (device->coopmat2) {
15991617

@@ -2174,11 +2192,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
21742192

21752193
// Workaround needed to speedup im2col on RX 5700
21762194
if (device_name.find("RX 5700") != std::string::npos) {
2177-
ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2195+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
21782196
if (device->float_controls_rte_fp16) {
2179-
ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2197+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
21802198
} else {
2181-
ggml_vk_create_pipeline_64(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2199+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
21822200
}
21832201
} else {
21842202
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);

0 commit comments

Comments
 (0)