Skip to content

Commit b51b69e

Browse files
committed
Refactor validation and enumeration platform checks into functions to clean up ggml_vk_instance_init()
1 parent c143105 commit b51b69e

File tree

1 file changed

+61
-36
lines changed

1 file changed

+61
-36
lines changed

ggml-vulkan.cpp

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
10791079
}
10801080
}
10811081

1082+
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
1083+
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
1084+
10821085
void ggml_vk_instance_init() {
10831086
if (vk_instance_initialized) {
10841087
return;
@@ -1090,54 +1093,40 @@ void ggml_vk_instance_init() {
10901093
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
10911094

10921095
const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
1093-
#ifdef __APPLE__
1094-
bool portability_enumeration_ext = false;
1095-
// Check for portability enumeration extension for MoltenVK support
1096-
for (const auto& properties : instance_extensions) {
1097-
if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
1098-
portability_enumeration_ext = true;
1099-
break;
1100-
}
1096+
const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
1097+
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
1098+
1099+
std::vector<const char*> layers;
1100+
1101+
if (validation_ext) {
1102+
layers.push_back("VK_LAYER_KHRONOS_validation");
11011103
}
1102-
if (!portability_enumeration_ext) {
1103-
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
1104+
std::vector<const char*> extensions;
1105+
if (validation_ext) {
1106+
extensions.push_back("VK_EXT_validation_features");
11041107
}
1105-
#endif
1106-
1107-
std::vector<const char*> layers = {
1108-
#ifdef GGML_VULKAN_VALIDATE
1109-
"VK_LAYER_KHRONOS_validation",
1110-
#endif
1111-
};
1112-
std::vector<const char*> extensions = {
1113-
#ifdef GGML_VULKAN_VALIDATE
1114-
"VK_EXT_validation_features",
1115-
#endif
1116-
};
1117-
#ifdef __APPLE__
11181108
if (portability_enumeration_ext) {
11191109
extensions.push_back("VK_KHR_portability_enumeration");
11201110
}
1121-
#endif
11221111
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
1123-
#ifdef __APPLE__
11241112
if (portability_enumeration_ext) {
11251113
instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
11261114
}
1127-
#endif
11281115

1116+
std::vector<vk::ValidationFeatureEnableEXT> features_enable;
1117+
vk::ValidationFeaturesEXT validation_features;
11291118

1130-
#ifdef GGML_VULKAN_VALIDATE
1131-
const std::vector<vk::ValidationFeatureEnableEXT> features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
1132-
vk::ValidationFeaturesEXT validation_features = {
1133-
features_enable,
1134-
{},
1135-
};
1136-
validation_features.setPNext(nullptr);
1137-
instance_create_info.setPNext(&validation_features);
1119+
if (validation_ext) {
1120+
features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
1121+
validation_features = {
1122+
features_enable,
1123+
{},
1124+
};
1125+
validation_features.setPNext(nullptr);
1126+
instance_create_info.setPNext(&validation_features);
11381127

1139-
std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
1140-
#endif
1128+
std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
1129+
}
11411130
vk_instance.instance = vk::createInstance(instance_create_info);
11421131

11431132
memset(vk_instance.initialized, 0, sizeof(bool) * GGML_VK_MAX_DEVICES);
@@ -5227,6 +5216,42 @@ GGML_CALL int ggml_backend_vk_reg_devices() {
52275216
return vk_instance.device_indices.size();
52285217
}
52295218

5219+
// Extension availability
5220+
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
5221+
#ifdef GGML_VULKAN_VALIDATE
5222+
bool portability_enumeration_ext = false;
5223+
// Check for portability enumeration extension for MoltenVK support
5224+
for (const auto& properties : instance_extensions) {
5225+
if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
5226+
return true;
5227+
}
5228+
}
5229+
if (!portability_enumeration_ext) {
5230+
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
5231+
}
5232+
#endif
5233+
return false;
5234+
5235+
UNUSED(instance_extensions);
5236+
}
5237+
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
5238+
#ifdef __APPLE__
5239+
bool portability_enumeration_ext = false;
5240+
// Check for portability enumeration extension for MoltenVK support
5241+
for (const auto& properties : instance_extensions) {
5242+
if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
5243+
return true;
5244+
}
5245+
}
5246+
if (!portability_enumeration_ext) {
5247+
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
5248+
}
5249+
#endif
5250+
return false;
5251+
5252+
UNUSED(instance_extensions);
5253+
}
5254+
52305255
// checks
52315256

52325257
#ifdef GGML_VULKAN_CHECK_RESULTS

0 commit comments

Comments
 (0)