Skip to content

Commit 2dc175f

Browse files
committed
Vulkan: Add VK_EXT_subgroup_size_control support to ensure full subgroups for coopmats
1 parent dafae66 commit 2dc175f

File tree

1 file changed

+79
-23
lines changed

1 file changed

+79
-23
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,19 @@ struct vk_device_struct {
162162
uint32_t subgroup_size;
163163
uint32_t shader_core_count;
164164
bool uma;
165-
bool coopmat2;
165+
166+
bool subgroup_size_control;
167+
uint32_t subgroup_min_size;
168+
uint32_t subgroup_max_size;
169+
bool subgroup_require_full_support;
166170

167171
bool coopmat_support;
168172
bool coopmat_acc_f32_support;
169173
bool coopmat_acc_f16_support;
170174
uint32_t coopmat_m;
171175
uint32_t coopmat_n;
172176
uint32_t coopmat_k;
177+
bool coopmat2;
173178

174179
size_t idx;
175180

@@ -748,8 +753,12 @@ static uint32_t compile_count = 0;
748753
static std::mutex compile_count_mutex;
749754
static std::condition_variable compile_count_cond;
750755

751-
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align, bool disable_robustness) {
752-
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
756+
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
757+
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
758+
uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
759+
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size <<
760+
", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align <<
761+
", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
753762
GGML_ASSERT(parameter_count > 0);
754763
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
755764

@@ -808,14 +817,28 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
808817
specialization_constants.data()
809818
);
810819

820+
vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
821+
822+
if (device->subgroup_require_full_support && require_full_subgroups) {
823+
pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
824+
}
825+
811826
vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
812-
vk::PipelineShaderStageCreateFlags(),
827+
pipeline_shader_stage_create_flags,
813828
vk::ShaderStageFlagBits::eCompute,
814829
pipeline->shader_module,
815830
entrypoint.c_str(),
816831
&specialization_info);
832+
833+
vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
834+
pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
835+
if (device->subgroup_size_control && required_subgroup_size > 0) {
836+
GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
837+
pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
838+
}
839+
817840
vk::ComputePipelineCreateInfo compute_pipeline_create_info(
818-
vk::PipelineCreateFlags(),
841+
vk::PipelineCreateFlags{},
819842
pipeline_shader_create_info,
820843
pipeline->layout);
821844

@@ -1495,7 +1518,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
14951518
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
14961519

14971520
std::vector<std::future<void>> compiles;
1498-
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
1521+
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
1522+
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
1523+
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
14991524
{
15001525
// wait until fewer than N compiles are in progress
15011526
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -1505,7 +1530,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
15051530
}
15061531
compile_count++;
15071532
}
1508-
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
1533+
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
1534+
parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
15091535
};
15101536

15111537
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -1611,17 +1637,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
16111637
// Create 6 variants, {s,m,l}x{unaligned,aligned}
16121638
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
16131639
if (device->mul_mat ## ID ## _l) \
1614-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1640+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
16151641
if (device->mul_mat ## ID ## _m) \
1616-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1642+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
16171643
if (device->mul_mat ## ID ## _s) \
1618-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1644+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
16191645
if (device->mul_mat ## ID ## _l) \
1620-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1646+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
16211647
if (device->mul_mat ## ID ## _m) \
1622-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1648+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
16231649
if (device->mul_mat ## ID ## _s) \
1624-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1650+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
16251651

16261652
// Create 2 variants, {f16,f32} accumulator
16271653
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -1988,6 +2014,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
19882014
amd_shader_core_properties2 = true;
19892015
} else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
19902016
pipeline_robustness = true;
2017+
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
2018+
device->subgroup_size_control = true;
19912019
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
19922020
!getenv("GGML_VK_DISABLE_COOPMAT")) {
19932021
device->coopmat_support = true;
@@ -2007,6 +2035,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
20072035
vk::PhysicalDeviceDriverProperties driver_props;
20082036
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
20092037
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2038+
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
2039+
20102040
props2.pNext = &props3;
20112041
props3.pNext = &subgroup_props;
20122042
subgroup_props.pNext = &driver_props;
@@ -2025,6 +2055,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
20252055
last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
20262056
last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
20272057
}
2058+
if (device->subgroup_size_control) {
2059+
last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
2060+
last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
2061+
}
20282062

20292063
#if defined(VK_NV_cooperative_matrix2)
20302064
vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
@@ -2062,11 +2096,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
20622096

20632097
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
20642098

2065-
if (device->vendor_id == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2066-
// Intel drivers don't support coopmat properly yet
2067-
// Only RADV supports coopmat properly on AMD
2068-
device->coopmat_support = false;
2069-
}
2099+
// if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2100+
// // Intel drivers don't support coopmat properly yet
2101+
// // Only RADV supports coopmat properly on AMD
2102+
// device->coopmat_support = false;
2103+
// }
20702104

20712105
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
20722106

@@ -2118,6 +2152,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
21182152
device_extensions.push_back("VK_EXT_pipeline_robustness");
21192153
}
21202154

2155+
VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
2156+
subgroup_size_control_features.pNext = nullptr;
2157+
subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
2158+
subgroup_size_control_features.computeFullSubgroups = false;
2159+
subgroup_size_control_features.subgroupSizeControl = false;
2160+
2161+
if (device->subgroup_size_control) {
2162+
last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
2163+
last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
2164+
}
2165+
21212166
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
21222167
coopmat_features.pNext = nullptr;
21232168
coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
@@ -2145,6 +2190,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
21452190

21462191
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
21472192

2193+
device->subgroup_size_control = device->subgroup_size_control &&
2194+
(!(subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) ||
2195+
!subgroup_size_control_features.subgroupSizeControl);
2196+
2197+
if (device->subgroup_size_control) {
2198+
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
2199+
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
2200+
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
2201+
device_extensions.push_back("VK_EXT_subgroup_size_control");
2202+
}
2203+
21482204
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
21492205

21502206
if (coopmat2_support) {
@@ -2427,11 +2483,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
24272483
}
24282484
}
24292485

2430-
if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2431-
// Intel drivers don't support coopmat properly yet
2432-
// Only RADV supports coopmat properly on AMD
2433-
coopmat_support = false;
2434-
}
2486+
// if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
2487+
// // Intel drivers don't support coopmat properly yet
2488+
// // Only RADV supports coopmat properly on AMD
2489+
// coopmat_support = false;
2490+
// }
24352491

24362492
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
24372493
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;

0 commit comments

Comments
 (0)