Skip to content

Commit 595c1a7

Browse files
committed
Vulkan: Add VK_EXT_subgroup_size_control support to ensure full subgroups for coopmats
1 parent ecc93d0 commit 595c1a7

File tree

1 file changed

+79
-23
lines changed

1 file changed

+79
-23
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,19 @@ struct vk_device_struct {
168168
uint32_t subgroup_size;
169169
uint32_t shader_core_count;
170170
bool uma;
171-
bool coopmat2;
171+
172+
bool subgroup_size_control;
173+
uint32_t subgroup_min_size;
174+
uint32_t subgroup_max_size;
175+
bool subgroup_require_full_support;
172176

173177
bool coopmat_support;
174178
bool coopmat_acc_f32_support;
175179
bool coopmat_acc_f16_support;
176180
uint32_t coopmat_m;
177181
uint32_t coopmat_n;
178182
uint32_t coopmat_k;
183+
bool coopmat2;
179184

180185
size_t idx;
181186

@@ -753,8 +758,12 @@ static uint32_t compile_count = 0;
753758
static std::mutex compile_count_mutex;
754759
static std::condition_variable compile_count_cond;
755760

756-
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align, bool disable_robustness) {
757-
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
761+
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
762+
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
763+
uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
764+
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size <<
765+
", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align <<
766+
", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
758767
GGML_ASSERT(parameter_count > 0);
759768
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
760769

@@ -813,14 +822,28 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
813822
specialization_constants.data()
814823
);
815824

825+
vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
826+
827+
if (device->subgroup_require_full_support && require_full_subgroups) {
828+
pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
829+
}
830+
816831
vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
817-
vk::PipelineShaderStageCreateFlags(),
832+
pipeline_shader_stage_create_flags,
818833
vk::ShaderStageFlagBits::eCompute,
819834
pipeline->shader_module,
820835
entrypoint.c_str(),
821836
&specialization_info);
837+
838+
vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
839+
pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
840+
if (device->subgroup_size_control && required_subgroup_size > 0) {
841+
GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
842+
pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
843+
}
844+
822845
vk::ComputePipelineCreateInfo compute_pipeline_create_info(
823-
vk::PipelineCreateFlags(),
846+
vk::PipelineCreateFlags{},
824847
pipeline_shader_create_info,
825848
pipeline->layout);
826849

@@ -1500,7 +1523,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
15001523
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
15011524

15021525
std::vector<std::future<void>> compiles;
1503-
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
1526+
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
1527+
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
1528+
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
15041529
{
15051530
// wait until fewer than N compiles are in progress
15061531
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -1510,7 +1535,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
15101535
}
15111536
compile_count++;
15121537
}
1513-
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
1538+
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
1539+
parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
15141540
};
15151541

15161542
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -1616,17 +1642,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
16161642
// Create 6 variants, {s,m,l}x{unaligned,aligned}
16171643
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
16181644
if (device->mul_mat ## ID ## _l) \
1619-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1645+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
16201646
if (device->mul_mat ## ID ## _m) \
1621-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1647+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
16221648
if (device->mul_mat ## ID ## _s) \
1623-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1649+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
16241650
if (device->mul_mat ## ID ## _l) \
1625-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1651+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
16261652
if (device->mul_mat ## ID ## _m) \
1627-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1653+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
16281654
if (device->mul_mat ## ID ## _s) \
1629-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1655+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
16301656

16311657
// Create 2 variants, {f16,f32} accumulator
16321658
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -1993,6 +2019,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
19932019
amd_shader_core_properties2 = true;
19942020
} else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
19952021
pipeline_robustness = true;
2022+
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
2023+
device->subgroup_size_control = true;
19962024
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
19972025
!getenv("GGML_VK_DISABLE_COOPMAT")) {
19982026
device->coopmat_support = true;
@@ -2012,6 +2040,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
20122040
vk::PhysicalDeviceDriverProperties driver_props;
20132041
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
20142042
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2043+
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
2044+
20152045
props2.pNext = &props3;
20162046
props3.pNext = &subgroup_props;
20172047
subgroup_props.pNext = &driver_props;
@@ -2030,6 +2060,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
20302060
last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
20312061
last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
20322062
}
2063+
if (device->subgroup_size_control) {
2064+
last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
2065+
last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
2066+
}
20332067

20342068
#if defined(VK_NV_cooperative_matrix2)
20352069
vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
@@ -2067,11 +2101,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
20672101

20682102
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
20692103

2070-
if (device->vendor_id == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2071-
// Intel drivers don't support coopmat properly yet
2072-
// Only RADV supports coopmat properly on AMD
2073-
device->coopmat_support = false;
2074-
}
2104+
// if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2105+
// // Intel drivers don't support coopmat properly yet
2106+
// // Only RADV supports coopmat properly on AMD
2107+
// device->coopmat_support = false;
2108+
// }
20752109

20762110
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
20772111

@@ -2123,6 +2157,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
21232157
device_extensions.push_back("VK_EXT_pipeline_robustness");
21242158
}
21252159

2160+
VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
2161+
subgroup_size_control_features.pNext = nullptr;
2162+
subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
2163+
subgroup_size_control_features.computeFullSubgroups = false;
2164+
subgroup_size_control_features.subgroupSizeControl = false;
2165+
2166+
if (device->subgroup_size_control) {
2167+
last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
2168+
last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
2169+
}
2170+
21262171
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
21272172
coopmat_features.pNext = nullptr;
21282173
coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
@@ -2150,6 +2195,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
21502195

21512196
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
21522197

2198+
device->subgroup_size_control = device->subgroup_size_control &&
2199+
(!(subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) ||
2200+
!subgroup_size_control_features.subgroupSizeControl);
2201+
2202+
if (device->subgroup_size_control) {
2203+
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
2204+
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
2205+
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
2206+
device_extensions.push_back("VK_EXT_subgroup_size_control");
2207+
}
2208+
21532209
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
21542210

21552211
if (coopmat2_support) {
@@ -2430,11 +2486,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
24302486
}
24312487
}
24322488

2433-
if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2434-
// Intel drivers don't support coopmat properly yet
2435-
// Only RADV supports coopmat properly on AMD
2436-
coopmat_support = false;
2437-
}
2489+
// if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2490+
// // Intel drivers don't support coopmat properly yet
2491+
// // Only RADV supports coopmat properly on AMD
2492+
// coopmat_support = false;
2493+
// }
24382494

24392495
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
24402496
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;

0 commit comments

Comments
 (0)