@@ -149,6 +149,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);
149
149
150
150
static constexpr uint32_t mul_mat_vec_max_cols = 8 ;
151
151
152
+ enum vk_device_architecture {
153
+ OTHER,
154
+ AMD_GCN,
155
+ AMD_RDNA1,
156
+ AMD_RDNA2,
157
+ AMD_RDNA3,
158
+ };
159
+
160
+ static vk_device_architecture get_device_architecture (const vk::PhysicalDevice& device) {
161
+ vk::PhysicalDeviceProperties props = device.getProperties ();
162
+
163
+ if (props.vendorID == VK_VENDOR_ID_AMD) {
164
+ const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties ();
165
+
166
+ bool amd_shader_core_properties = false ;
167
+ bool integer_dot_product = false ;
168
+ bool subgroup_size_control = false ;
169
+
170
+ for (const auto & properties : ext_props) {
171
+ if (strcmp (" VK_AMD_shader_core_properties" , properties.extensionName ) == 0 ) {
172
+ amd_shader_core_properties = true ;
173
+ } else if (strcmp (" VK_KHR_shader_integer_dot_product" , properties.extensionName ) == 0 ) {
174
+ integer_dot_product = true ;
175
+ } else if (strcmp (" VK_EXT_subgroup_size_control" , properties.extensionName ) == 0 ) {
176
+ subgroup_size_control = true ;
177
+ }
178
+ }
179
+
180
+ if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
181
+ return vk_device_architecture::OTHER;
182
+ }
183
+
184
+ vk::PhysicalDeviceProperties2 props2;
185
+ vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
186
+ vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
187
+ vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
188
+
189
+ props2.pNext = &shader_core_props_amd;
190
+ shader_core_props_amd.pNext = &integer_dot_props;
191
+ integer_dot_props.pNext = &subgroup_size_control_props;
192
+
193
+ device.getProperties2 (&props2);
194
+
195
+ if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64 ) {
196
+ return vk_device_architecture::AMD_GCN;
197
+ }
198
+ if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32 ) {
199
+ // RDNA
200
+ if (shader_core_props_amd.wavefrontsPerSimd == 20 ) {
201
+ return vk_device_architecture::AMD_RDNA1;
202
+ }
203
+ if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated ) {
204
+ return vk_device_architecture::AMD_RDNA3;
205
+ }
206
+ return vk_device_architecture::AMD_RDNA2;
207
+ }
208
+ }
209
+ return vk_device_architecture::OTHER;
210
+ }
211
+
152
212
struct vk_device_struct {
153
213
std::mutex mutex;
154
214
@@ -161,6 +221,7 @@ struct vk_device_struct {
161
221
bool pipeline_robustness;
162
222
vk::Device device;
163
223
uint32_t vendor_id;
224
+ vk_device_architecture architecture;
164
225
vk_queue compute_queue;
165
226
vk_queue transfer_queue;
166
227
bool single_queue;
@@ -2219,7 +2280,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2219
2280
device->need_compiles = false ;
2220
2281
}
2221
2282
2222
- static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
2283
+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch );
2223
2284
2224
2285
static vk_device ggml_vk_get_device (size_t idx) {
2225
2286
VK_LOG_DEBUG (" ggml_vk_get_device(" << idx << " )" );
@@ -2248,6 +2309,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
2248
2309
device->physical_device = physical_devices[dev_num];
2249
2310
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device .enumerateDeviceExtensionProperties ();
2250
2311
2312
+ device->architecture = get_device_architecture (device->physical_device );
2313
+
2251
2314
bool fp16_storage = false ;
2252
2315
bool fp16_compute = false ;
2253
2316
bool maintenance4_support = false ;
@@ -2257,7 +2320,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
2257
2320
bool coopmat2_support = false ;
2258
2321
device->coopmat_support = false ;
2259
2322
2260
- // Check if maintenance4 is supported
2261
2323
for (const auto & properties : ext_props) {
2262
2324
if (strcmp (" VK_KHR_maintenance4" , properties.extensionName ) == 0 ) {
2263
2325
maintenance4_support = true ;
@@ -2370,7 +2432,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2370
2432
2371
2433
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
2372
2434
2373
- if (!ggml_vk_khr_cooperative_matrix_support (device->properties , driver_props)) {
2435
+ if (!ggml_vk_khr_cooperative_matrix_support (device->properties , driver_props, device-> architecture )) {
2374
2436
device->coopmat_support = false ;
2375
2437
}
2376
2438
@@ -2776,7 +2838,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2776
2838
}
2777
2839
}
2778
2840
2779
- if (!ggml_vk_khr_cooperative_matrix_support (props2.properties , driver_props)) {
2841
+ const vk_device_architecture device_architecture = get_device_architecture (physical_device);
2842
+
2843
+ if (!ggml_vk_khr_cooperative_matrix_support (props2.properties , driver_props, device_architecture)) {
2780
2844
coopmat_support = false ;
2781
2845
}
2782
2846
@@ -8435,18 +8499,15 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
8435
8499
UNUSED (instance_extensions);
8436
8500
}
8437
8501
8438
- static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
8502
+ static bool ggml_vk_khr_cooperative_matrix_support (const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch ) {
8439
8503
switch (props.vendorID ) {
8440
8504
case VK_VENDOR_ID_INTEL:
8441
8505
// Intel drivers don't support coopmat properly yet
8442
8506
return false ;
8443
8507
case VK_VENDOR_ID_AMD:
8444
8508
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
8445
8509
// Workaround for AMD proprietary driver reporting support on all GPUs
8446
- const std::string name = props.deviceName ;
8447
- return name.rfind (" AMD Radeon RX 7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) RX 7" , 0 ) == 0 || // RDNA 3 consumer GPUs
8448
- name.rfind (" AMD Radeon PRO W7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) PRO W7" , 0 ) == 0 || // RDNA 3 workstation GPUs
8449
- name.rfind (" AMD Radeon 7" , 0 ) == 0 || name.rfind (" AMD Radeon(TM) 7" , 0 ) == 0 ; // RDNA 3 APUs
8510
+ return arch == vk_device_architecture::AMD_RDNA3;
8450
8511
}
8451
8512
return true ;
8452
8513
default :
0 commit comments