6
6
* LICENSE file in the root directory of this source tree.
7
7
*/
8
8
9
+ // @lint-ignore-every CLANGTIDY clang-diagnostic-missing-field-initializers
10
+
9
11
#include < executorch/backends/vulkan/runtime/api/Adapter.h>
10
12
11
13
#include < bitset>
@@ -21,15 +23,33 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
21
23
: handle(physical_device_handle),
22
24
properties{},
23
25
memory_properties{},
26
+ shader_16bit_storage{
27
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES},
28
+ shader_8bit_storage{
29
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES},
30
+ shader_float16_int8_types{
31
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR},
24
32
queue_families{},
25
33
num_compute_queues (0 ),
26
34
has_unified_memory (false ),
27
35
has_timestamps (properties.limits.timestampComputeAndGraphics),
28
- timestamp_period (properties.limits.timestampPeriod) {
36
+ timestamp_period (properties.limits.timestampPeriod),
37
+ extension_features (&shader_16bit_storage) {
29
38
// Extract physical device properties
30
39
vkGetPhysicalDeviceProperties (handle, &properties);
31
40
vkGetPhysicalDeviceMemoryProperties (handle, &memory_properties);
32
41
42
+ VkPhysicalDeviceFeatures2 features2{
43
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
44
+
45
+ // Create linked list to query availability of extensions
46
+ features2.pNext = &shader_16bit_storage;
47
+ shader_16bit_storage.pNext = &shader_8bit_storage;
48
+ shader_8bit_storage.pNext = &shader_float16_int8_types;
49
+ shader_float16_int8_types.pNext = nullptr ;
50
+
51
+ vkGetPhysicalDeviceFeatures2 (handle, &features2);
52
+
33
53
// Check if there are any memory types have both the HOST_VISIBLE and the
34
54
// DEVICE_LOCAL property flags
35
55
const VkMemoryPropertyFlags unified_memory_flags =
@@ -140,6 +160,9 @@ VkDevice create_logical_device(
140
160
#ifdef VK_KHR_portability_subset
141
161
VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
142
162
#endif /* VK_KHR_portability_subset */
163
+ VK_KHR_16BIT_STORAGE_EXTENSION_NAME,
164
+ VK_KHR_8BIT_STORAGE_EXTENSION_NAME,
165
+ VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME,
143
166
};
144
167
145
168
std::vector<const char *> enabled_device_extensions;
@@ -148,7 +171,7 @@ VkDevice create_logical_device(
148
171
enabled_device_extensions,
149
172
requested_device_extensions);
150
173
151
- const VkDeviceCreateInfo device_create_info{
174
+ VkDeviceCreateInfo device_create_info{
152
175
VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType
153
176
nullptr , // pNext
154
177
0u , // flags
@@ -162,6 +185,8 @@ VkDevice create_logical_device(
162
185
nullptr , // pEnabledFeatures
163
186
};
164
187
188
+ device_create_info.pNext = physical_device.extension_features ;
189
+
165
190
VkDevice handle = nullptr ;
166
191
VK_CHECK (vkCreateDevice (
167
192
physical_device.handle , &device_create_info, nullptr , &handle));
@@ -371,33 +396,53 @@ std::string Adapter::stringize() const {
371
396
ss << " deviceType: " << device_type << std::endl;
372
397
ss << " deviceName: " << properties.deviceName << std::endl;
373
398
374
- #define PRINT_LIMIT_PROP ( name ) \
375
- ss << " " << std::left << std::setw (36 ) << #name << limits .name \
399
+ #define PRINT_PROP ( struct, name ) \
400
+ ss << " " << std::left << std::setw (36 ) << #name << struct .name \
376
401
<< std::endl;
377
402
378
- #define PRINT_LIMIT_PROP_VEC3 ( name ) \
379
- ss << " " << std::left << std::setw (36 ) << #name << limits .name [0 ] \
380
- << " ," << limits .name [1 ] << " ," << limits .name [2 ] << std::endl;
403
+ #define PRINT_PROP_VEC3 ( struct, name ) \
404
+ ss << " " << std::left << std::setw(36 ) << #name << struct .name[0 ] \
405
+ << " ," << struct .name[1 ] << " ," << struct .name[2 ] << std::endl;
381
406
382
407
ss << " Physical Device Limits {" << std::endl;
383
- PRINT_LIMIT_PROP (maxImageDimension1D);
384
- PRINT_LIMIT_PROP (maxImageDimension2D);
385
- PRINT_LIMIT_PROP (maxImageDimension3D);
386
- PRINT_LIMIT_PROP (maxTexelBufferElements);
387
- PRINT_LIMIT_PROP (maxPushConstantsSize);
388
- PRINT_LIMIT_PROP (maxMemoryAllocationCount);
389
- PRINT_LIMIT_PROP (maxSamplerAllocationCount);
390
- PRINT_LIMIT_PROP (maxComputeSharedMemorySize);
391
- PRINT_LIMIT_PROP_VEC3 (maxComputeWorkGroupCount);
392
- PRINT_LIMIT_PROP (maxComputeWorkGroupInvocations);
393
- PRINT_LIMIT_PROP_VEC3 (maxComputeWorkGroupSize);
408
+ PRINT_PROP (limits, maxImageDimension1D);
409
+ PRINT_PROP (limits, maxImageDimension2D);
410
+ PRINT_PROP (limits, maxImageDimension3D);
411
+ PRINT_PROP (limits, maxTexelBufferElements);
412
+ PRINT_PROP (limits, maxPushConstantsSize);
413
+ PRINT_PROP (limits, maxMemoryAllocationCount);
414
+ PRINT_PROP (limits, maxSamplerAllocationCount);
415
+ PRINT_PROP (limits, maxComputeSharedMemorySize);
416
+ PRINT_PROP_VEC3 (limits, maxComputeWorkGroupCount);
417
+ PRINT_PROP (limits, maxComputeWorkGroupInvocations);
418
+ PRINT_PROP_VEC3 (limits, maxComputeWorkGroupSize);
419
+ ss << " }" << std::endl;
420
+
421
+ ss << " 16bit Storage Features {" << std::endl;
422
+ PRINT_PROP (physical_device_.shader_16bit_storage, storageBuffer16BitAccess);
423
+ PRINT_PROP (
424
+ physical_device_.shader_16bit_storage,
425
+ uniformAndStorageBuffer16BitAccess);
426
+ PRINT_PROP (physical_device_.shader_16bit_storage, storagePushConstant16);
427
+ PRINT_PROP (physical_device_.shader_16bit_storage, storageInputOutput16);
428
+ ss << " }" << std::endl;
429
+
430
+ ss << " 8bit Storage Features {" << std::endl;
431
+ PRINT_PROP (physical_device_.shader_8bit_storage, storageBuffer8BitAccess);
432
+ PRINT_PROP (
433
+ physical_device_.shader_8bit_storage, uniformAndStorageBuffer8BitAccess);
434
+ PRINT_PROP (physical_device_.shader_8bit_storage, storagePushConstant8);
435
+ ss << " }" << std::endl;
436
+
437
+ ss << " Shader 16bit and 8bit Features {" << std::endl;
438
+ PRINT_PROP (physical_device_.shader_float16_int8_types, shaderFloat16);
439
+ PRINT_PROP (physical_device_.shader_float16_int8_types, shaderInt8);
394
440
ss << " }" << std::endl;
395
- ss << " }" << std::endl;
396
- ;
397
441
398
442
const VkPhysicalDeviceMemoryProperties& mem_props =
399
443
physical_device_.memory_properties;
400
444
445
+ ss << " }" << std::endl;
401
446
ss << " Memory Info {" << std::endl;
402
447
ss << " Memory Types [" << std::endl;
403
448
for (size_t i = 0 ; i < mem_props.memoryTypeCount; ++i) {
@@ -432,6 +477,9 @@ std::string Adapter::stringize() const {
432
477
ss << " ]" << std::endl;
433
478
ss << " }" ;
434
479
480
+ #undef PRINT_PROP
481
+ #undef PRINT_PROP_VEC3
482
+
435
483
return ss.str();
436
484
}
437
485
0 commit comments