Skip to content

Commit ecc93d0

Browse files
authored
vulkan: compile a test shader in cmake to check for coopmat2 support (ggml-org#10713)
1 parent 62e84d9 commit ecc93d0

File tree

4 files changed

+36
-8
lines changed

4 files changed

+36
-8
lines changed

ggml/src/ggml-vulkan/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,20 @@ if (Vulkan_FOUND)
88
../../include/ggml-vulkan.h
99
)
1010

11+
# Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
12+
# If it's not, there will be an error to stderr.
13+
# If it's supported, set a define to indicate that we should compile those shaders
14+
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
15+
OUTPUT_VARIABLE glslc_output
16+
ERROR_VARIABLE glslc_error)
17+
18+
if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
19+
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
20+
else()
21+
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
22+
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
23+
endif()
24+
1125
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
1226
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
1327

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,7 +1513,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
15131513
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
15141514
};
15151515

1516-
#if defined(VK_NV_cooperative_matrix2)
1516+
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
15171517
if (device->coopmat2) {
15181518

15191519
auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
@@ -1611,7 +1611,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
16111611
#undef CREATE_MM
16121612
#undef CREATE_MM2
16131613
} else
1614-
#endif // defined(VK_NV_cooperative_matrix2)
1614+
#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
16151615
if (device->coopmat_support) {
16161616
// Create 6 variants, {s,m,l}x{unaligned,aligned}
16171617
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -2153,7 +2153,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
21532153
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
21542154

21552155
if (coopmat2_support) {
2156-
#if defined(VK_NV_cooperative_matrix2)
2156+
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
21572157
if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
21582158
coopmat2_features.cooperativeMatrixFlexibleDimensions &&
21592159
coopmat2_features.cooperativeMatrixReductions &&
@@ -2414,14 +2414,19 @@ static void ggml_vk_print_gpu_info(size_t idx) {
24142414
bool fp16_storage = false;
24152415
bool fp16_compute = false;
24162416
bool coopmat_support = false;
2417+
bool coopmat2_support = false;
24172418

24182419
for (auto properties : ext_props) {
24192420
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
24202421
fp16_storage = true;
24212422
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
24222423
fp16_compute = true;
2423-
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
2424+
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
2425+
!getenv("GGML_VK_DISABLE_COOPMAT")) {
24242426
coopmat_support = true;
2427+
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
2428+
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
2429+
coopmat2_support = true;
24252430
}
24262431
}
24272432

@@ -2472,9 +2477,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
24722477

24732478
coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
24742479

2480+
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
2481+
24752482
std::string device_name = props2.properties.deviceName.data();
2476-
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %d\n",
2477-
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, coopmat_support);
2483+
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
2484+
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str());
24782485

24792486
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
24802487
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#version 460
2+
3+
#extension GL_NV_cooperative_matrix2 : require
4+
5+
void main()
6+
{
7+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,14 +342,14 @@ void process_shaders() {
342342
matmul_shaders(true, matmul_id, true, false, false);
343343
matmul_shaders(true, matmul_id, true, false, true);
344344

345-
#if defined(VK_NV_cooperative_matrix2)
345+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
346346
// Coopmat2, fp32acc and fp16acc
347347
matmul_shaders(true, matmul_id, false, true, false);
348348
matmul_shaders(true, matmul_id, false, true, true);
349349
#endif
350350
}
351351

352-
#if defined(VK_NV_cooperative_matrix2)
352+
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
353353
// flash attention
354354
for (const auto& f16acc : {false, true}) {
355355
std::string acctype = f16acc ? "float16_t" : "float";

0 commit comments

Comments
 (0)