Skip to content

Commit bf58333

Browse files
committed
[ET-VK] Bump Vulkan API requirement to 1.1 and enable 16 bit and 8 bit types in buffer storage
## Context Enable use of explicit fp16 and int8 types in GPU storage buffers via the following extensions: * [VK_KHR_16bit_storage](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_16bit_storage.html) * [VK_KHR_8bit_storage](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_8bit_storage.html) * [VK_KHR_shader_float16_int8](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_shader_float16_int8.html) The first two enable the use of 16-bit and 8-bit types in storage buffers, while the last one enables using those types in arithmetic operations. By enabling these extensions and checking that the device supports the required features, explicit fp16 and int8 types can be used in compute shaders, as demonstrated by the added test. Vulkan 1.1 is required in order to access `vkGetPhysicalDeviceFeatures2`, which is needed to query whether the device supports 16-bit and 8-bit types. This should be a fairly straightforward version bump, as Vulkan 1.1 is supported by the vast majority of Android devices. Differential Revision: [D56164239](https://our.internmc.facebook.com/intern/diff/D56164239/) [ghstack-poisoned]
1 parent 458d743 commit bf58333

File tree

8 files changed

+262
-31
lines changed

8 files changed

+262
-31
lines changed

backends/vulkan/runtime/api/Adapter.cpp

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,33 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
2121
: handle(physical_device_handle),
2222
properties{},
2323
memory_properties{},
24+
shader_16bit_storage{
25+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES},
26+
shader_8bit_storage{
27+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES},
28+
shader_float16_int8_types{
29+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR},
2430
queue_families{},
2531
num_compute_queues(0),
2632
has_unified_memory(false),
2733
has_timestamps(properties.limits.timestampComputeAndGraphics),
28-
timestamp_period(properties.limits.timestampPeriod) {
34+
timestamp_period(properties.limits.timestampPeriod),
35+
extension_features(&shader_16bit_storage) {
2936
// Extract physical device properties
3037
vkGetPhysicalDeviceProperties(handle, &properties);
3138
vkGetPhysicalDeviceMemoryProperties(handle, &memory_properties);
3239

40+
VkPhysicalDeviceFeatures2 features2{
41+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
42+
43+
// Create linked list to query availability of extensions
44+
features2.pNext = &shader_16bit_storage;
45+
shader_16bit_storage.pNext = &shader_8bit_storage;
46+
shader_8bit_storage.pNext = &shader_float16_int8_types;
47+
shader_float16_int8_types.pNext = nullptr;
48+
49+
vkGetPhysicalDeviceFeatures2(handle, &features2);
50+
3351
// Check if there are any memory types that have both the HOST_VISIBLE and the
3452
// DEVICE_LOCAL property flags
3553
const VkMemoryPropertyFlags unified_memory_flags =
@@ -140,6 +158,9 @@ VkDevice create_logical_device(
140158
#ifdef VK_KHR_portability_subset
141159
VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
142160
#endif /* VK_KHR_portability_subset */
161+
VK_KHR_16BIT_STORAGE_EXTENSION_NAME,
162+
VK_KHR_8BIT_STORAGE_EXTENSION_NAME,
163+
VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME,
143164
};
144165

145166
std::vector<const char*> enabled_device_extensions;
@@ -148,7 +169,7 @@ VkDevice create_logical_device(
148169
enabled_device_extensions,
149170
requested_device_extensions);
150171

151-
const VkDeviceCreateInfo device_create_info{
172+
VkDeviceCreateInfo device_create_info{
152173
VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType
153174
nullptr, // pNext
154175
0u, // flags
@@ -162,6 +183,8 @@ VkDevice create_logical_device(
162183
nullptr, // pEnabledFeatures
163184
};
164185

186+
device_create_info.pNext = physical_device.extension_features;
187+
165188
VkDevice handle = nullptr;
166189
VK_CHECK(vkCreateDevice(
167190
physical_device.handle, &device_create_info, nullptr, &handle));
@@ -371,33 +394,53 @@ std::string Adapter::stringize() const {
371394
ss << " deviceType: " << device_type << std::endl;
372395
ss << " deviceName: " << properties.deviceName << std::endl;
373396

374-
#define PRINT_LIMIT_PROP(name) \
375-
ss << " " << std::left << std::setw(36) << #name << limits.name \
397+
#define PRINT_PROP(struct, name) \
398+
ss << " " << std::left << std::setw(36) << #name << struct.name \
376399
<< std::endl;
377400

378-
#define PRINT_LIMIT_PROP_VEC3(name) \
379-
ss << " " << std::left << std::setw(36) << #name << limits.name[0] \
380-
<< "," << limits.name[1] << "," << limits.name[2] << std::endl;
401+
#define PRINT_PROP_VEC3(struct, name) \
402+
ss << " " << std::left << std::setw(36) << #name << struct.name[0] \
403+
<< "," << struct.name[1] << "," << struct.name[2] << std::endl;
381404

382405
ss << " Physical Device Limits {" << std::endl;
383-
PRINT_LIMIT_PROP(maxImageDimension1D);
384-
PRINT_LIMIT_PROP(maxImageDimension2D);
385-
PRINT_LIMIT_PROP(maxImageDimension3D);
386-
PRINT_LIMIT_PROP(maxTexelBufferElements);
387-
PRINT_LIMIT_PROP(maxPushConstantsSize);
388-
PRINT_LIMIT_PROP(maxMemoryAllocationCount);
389-
PRINT_LIMIT_PROP(maxSamplerAllocationCount);
390-
PRINT_LIMIT_PROP(maxComputeSharedMemorySize);
391-
PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupCount);
392-
PRINT_LIMIT_PROP(maxComputeWorkGroupInvocations);
393-
PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupSize);
406+
PRINT_PROP(limits, maxImageDimension1D);
407+
PRINT_PROP(limits, maxImageDimension2D);
408+
PRINT_PROP(limits, maxImageDimension3D);
409+
PRINT_PROP(limits, maxTexelBufferElements);
410+
PRINT_PROP(limits, maxPushConstantsSize);
411+
PRINT_PROP(limits, maxMemoryAllocationCount);
412+
PRINT_PROP(limits, maxSamplerAllocationCount);
413+
PRINT_PROP(limits, maxComputeSharedMemorySize);
414+
PRINT_PROP_VEC3(limits, maxComputeWorkGroupCount);
415+
PRINT_PROP(limits, maxComputeWorkGroupInvocations);
416+
PRINT_PROP_VEC3(limits, maxComputeWorkGroupSize);
417+
ss << " }" << std::endl;
418+
419+
ss << " 16bit Storage Features {" << std::endl;
420+
PRINT_PROP(physical_device_.shader_16bit_storage, storageBuffer16BitAccess);
421+
PRINT_PROP(
422+
physical_device_.shader_16bit_storage,
423+
uniformAndStorageBuffer16BitAccess);
424+
PRINT_PROP(physical_device_.shader_16bit_storage, storagePushConstant16);
425+
PRINT_PROP(physical_device_.shader_16bit_storage, storageInputOutput16);
426+
ss << " }" << std::endl;
427+
428+
ss << " 8bit Storage Features {" << std::endl;
429+
PRINT_PROP(physical_device_.shader_8bit_storage, storageBuffer8BitAccess);
430+
PRINT_PROP(
431+
physical_device_.shader_8bit_storage, uniformAndStorageBuffer8BitAccess);
432+
PRINT_PROP(physical_device_.shader_8bit_storage, storagePushConstant8);
433+
ss << " }" << std::endl;
434+
435+
ss << " Shader 16bit and 8bit Features {" << std::endl;
436+
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
437+
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
394438
ss << " }" << std::endl;
395-
ss << " }" << std::endl;
396-
;
397439

398440
const VkPhysicalDeviceMemoryProperties& mem_props =
399441
physical_device_.memory_properties;
400442

443+
ss << " }" << std::endl;
401444
ss << " Memory Info {" << std::endl;
402445
ss << " Memory Types [" << std::endl;
403446
for (size_t i = 0; i < mem_props.memoryTypeCount; ++i) {
@@ -432,6 +475,9 @@ std::string Adapter::stringize() const {
432475
ss << " ]" << std::endl;
433476
ss << "}";
434477

478+
#undef PRINT_PROP
479+
#undef PRINT_PROP_VEC3
480+
435481
return ss.str();
436482
}
437483

backends/vulkan/runtime/api/Adapter.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ struct PhysicalDevice final {
3030
// Properties obtained from Vulkan
3131
VkPhysicalDeviceProperties properties;
3232
VkPhysicalDeviceMemoryProperties memory_properties;
33+
// Additional features available from extensions
34+
VkPhysicalDevice16BitStorageFeatures shader_16bit_storage;
35+
VkPhysicalDevice8BitStorageFeatures shader_8bit_storage;
36+
VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types;
37+
38+
// Available GPU queues
3339
std::vector<VkQueueFamilyProperties> queue_families;
3440

3541
// Metadata
@@ -38,6 +44,9 @@ struct PhysicalDevice final {
3844
bool has_timestamps;
3945
float timestamp_period;
4046

47+
// Head of the linked list of extensions to be requested
48+
void* extension_features{nullptr};
49+
4150
explicit PhysicalDevice(VkPhysicalDevice);
4251
};
4352

@@ -189,6 +198,34 @@ class Adapter final {
189198
return vma_;
190199
}
191200

201+
// Physical Device Features
202+
203+
inline bool has_16bit_storage() {
204+
return physical_device_.shader_16bit_storage.storageBuffer16BitAccess ==
205+
VK_TRUE;
206+
}
207+
208+
inline bool has_8bit_storage() {
209+
return physical_device_.shader_8bit_storage.storageBuffer8BitAccess ==
210+
VK_TRUE;
211+
}
212+
213+
inline bool has_16bit_compute() {
214+
return physical_device_.shader_float16_int8_types.shaderFloat16 == VK_TRUE;
215+
}
216+
217+
inline bool has_8bit_compute() {
218+
return physical_device_.shader_float16_int8_types.shaderInt8 == VK_TRUE;
219+
}
220+
221+
inline bool has_full_float16_buffers_support() {
222+
return has_16bit_storage() && has_16bit_compute();
223+
}
224+
225+
inline bool has_full_int8_buffers_support() {
226+
return has_8bit_storage() && has_8bit_compute();
227+
}
228+
192229
// Command Buffer Submission
193230

194231
void

backends/vulkan/runtime/api/Runtime.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ VkInstance create_instance(const RuntimeConfiguration& config) {
8585
0, // applicationVersion
8686
nullptr, // pEngineName
8787
0, // engineVersion
88-
VK_API_VERSION_1_0, // apiVersion
88+
VK_API_VERSION_1_1, // apiVersion
8989
};
9090

9191
std::vector<const char*> enabled_layers;

backends/vulkan/runtime/api/Types.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@
2323
#define VK_FORMAT_FLOAT4 VK_FORMAT_R32G32B32A32_SFLOAT
2424
#endif /* USE_VULKAN_FP16_INFERENCE */
2525

26-
#define VK_FORALL_SCALAR_TYPES(_) \
27-
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \
28-
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \
29-
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \
30-
_(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \
31-
_(float, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
32-
_(float, VK_FORMAT_FLOAT4, Float) \
33-
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \
34-
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \
26+
#define VK_FORALL_SCALAR_TYPES(_) \
27+
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \
28+
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \
29+
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \
30+
_(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \
31+
_(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
32+
_(float, VK_FORMAT_FLOAT4, Float) \
33+
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \
34+
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \
3535
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32)
3636

3737
namespace vkcompute {

backends/vulkan/runtime/api/gen_vulkan_spv.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,22 @@ def get_buffer_scalar_type(dtype: str) -> str:
100100
return dtype
101101

102102

103+
def get_buffer_gvec_type(dtype: str, n: int) -> str:
    """Return the GLSL type name for an ``n``-component buffer element.

    Args:
        dtype: Element dtype name; "float", "half", "int8" and "uint8" are
            handled here for vector types.
        n: Number of components. ``1`` defers to ``get_buffer_scalar_type``.

    Returns:
        The type name: the dtype's vector prefix followed by ``n``
        (for example ``f16vec4``), or the scalar type name when ``n == 1``.

    Raises:
        AssertionError: if ``n != 1`` and ``dtype`` is not one of the
            handled names.
    """
    if n == 1:
        return get_buffer_scalar_type(dtype)

    # (dtype name, vector type prefix) pairs, checked in order.
    vector_prefixes = (
        ("float", "vec"),
        ("half", "f16vec"),
        ("int8", "i8vec"),
        ("uint8", "u8vec"),
    )
    for known_dtype, prefix in vector_prefixes:
        if dtype == known_dtype:
            return f"{prefix}{n}"

    raise AssertionError(f"Invalid dtype: {dtype}")
117+
118+
103119
def get_texel_type(dtype: str) -> str:
104120
image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
105121
if image_format[-1] == "f":
@@ -134,6 +150,7 @@ def get_texel_component_type(dtype: str) -> str:
134150
2: lambda pos: f"{pos}.xy",
135151
},
136152
"buffer_scalar_type": get_buffer_scalar_type,
153+
"buffer_gvec_type": get_buffer_gvec_type,
137154
"texel_type": get_texel_type,
138155
"gvec_type": get_gvec_type,
139156
"texel_component_type": get_texel_component_type,
@@ -456,7 +473,7 @@ def generateSPV(self, output_dir: str) -> Dict[str, str]:
456473
glsl_out_path,
457474
"-o",
458475
spv_out_path,
459-
"--target-env=vulkan1.0",
476+
"--target-env=vulkan1.1",
460477
"-Werror",
461478
] + [
462479
arg

backends/vulkan/test/glsl/all_shaders.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ fill_texture__test:
3333
shader_variants:
3434
- NAME: fill_texture__test
3535

36+
idx_fill_buffer:
37+
parameter_names_with_default_values:
38+
DTYPE: float
39+
generate_variant_forall:
40+
DTYPE:
41+
- VALUE: float
42+
- VALUE: half
43+
- VALUE: int8
44+
shader_variants:
45+
- NAME: idx_fill_buffer
46+
3647
idx_fill_texture:
3748
parameter_names_with_default_values:
3849
DTYPE: float
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

// Extension directives are placed ahead of the include so that they come
// before anything the included header contributes.
$if DTYPE == "half":
  #extension GL_EXT_shader_16bit_storage : require
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
$elif DTYPE == "int8":
  #extension GL_EXT_shader_8bit_storage : require
  #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
$elif DTYPE == "uint8":
  #extension GL_EXT_shader_8bit_storage : require
  // The int8 extension enables both signed and unsigned 8-bit types; there is
  // no separate uint8 variant of GL_EXT_shader_explicit_arithmetic_types.
  #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

#include "indexing_utils.h"

layout(std430) buffer;

// Buffer written by this shader; each array element is one 4-component vector.
layout(set = 0, binding = 0) buffer PRECISION restrict writeonly Buffer {
  VEC4_T data[];
}
buffer_in;

// len bounds the writes: an invocation writes only while 4 * i < len.
// Presumably the number of scalar elements in the buffer; confirm with caller.
layout(set = 0, binding = 1) uniform PRECISION restrict Params {
  int len;
}
params;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Invocation i writes (4i, 4i + 1, 4i + 2, 4i + 3), converted to VEC4_T, into
// data[i]. Only the x component of the global invocation id is used.
void main() {
  const int i = ivec3(gl_GlobalInvocationID).x;

  const int base = 4 * i;
  if (base < params.len) {
    buffer_in.data[i] = VEC4_T(base, base + 1, base + 2, base + 3);
  }
}

0 commit comments

Comments
 (0)