Skip to content

Commit d481c11

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Bump Vulkan API requirement to 1.1 and enable 16 bit and 8 bit types in buffer storage (#3058)
Summary: Pull Request resolved: #3058 ## Context Enable use of explicit fp16 and int8 types in GPU storage buffers via the following extensions: * [VK_KHR_16bit_storage](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_16bit_storage.html) * [VK_KHR_8bit_storage](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_8bit_storage.html) * [VK_KHR_shader_float16_int8](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_shader_float16_int8.html) The first two enable usage of 16-bit and 8-bit types in storage buffers, while the last one enables using those types in arithmetic operations. By enabling these extensions and checking that the device supports the required features, explicit fp16 and int8 types can be used in compute shaders, as demonstrated by the added test. Vulkan 1.1 is required in order to access `vkGetPhysicalDeviceFeatures2`, which is required to query whether the device supports 16-bit and 8-bit types. This should be a fairly straightforward version bump as Vulkan 1.1 is supported by the vast majority of Android devices. ghstack-source-id: 222727208 exported-using-ghexport Reviewed By: jorgep31415 Differential Revision: D56164239 fbshipit-source-id: 879804567ff08201933a220c9f168f435af80019
1 parent 473c98c commit d481c11

File tree

9 files changed

+272
-32
lines changed

9 files changed

+272
-32
lines changed

backends/vulkan/runtime/api/Adapter.cpp

Lines changed: 68 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
// @lint-ignore-every CLANGTIDY clang-diagnostic-missing-field-initializers
10+
911
#include <executorch/backends/vulkan/runtime/api/Adapter.h>
1012

1113
#include <bitset>
@@ -21,15 +23,33 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
2123
: handle(physical_device_handle),
2224
properties{},
2325
memory_properties{},
26+
shader_16bit_storage{
27+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES},
28+
shader_8bit_storage{
29+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES},
30+
shader_float16_int8_types{
31+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR},
2432
queue_families{},
2533
num_compute_queues(0),
2634
has_unified_memory(false),
2735
has_timestamps(properties.limits.timestampComputeAndGraphics),
28-
timestamp_period(properties.limits.timestampPeriod) {
36+
timestamp_period(properties.limits.timestampPeriod),
37+
extension_features(&shader_16bit_storage) {
2938
// Extract physical device properties
3039
vkGetPhysicalDeviceProperties(handle, &properties);
3140
vkGetPhysicalDeviceMemoryProperties(handle, &memory_properties);
3241

42+
VkPhysicalDeviceFeatures2 features2{
43+
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
44+
45+
// Create linked list to query availability of extensions
46+
features2.pNext = &shader_16bit_storage;
47+
shader_16bit_storage.pNext = &shader_8bit_storage;
48+
shader_8bit_storage.pNext = &shader_float16_int8_types;
49+
shader_float16_int8_types.pNext = nullptr;
50+
51+
vkGetPhysicalDeviceFeatures2(handle, &features2);
52+
3353
// Check if there are any memory types have both the HOST_VISIBLE and the
3454
// DEVICE_LOCAL property flags
3555
const VkMemoryPropertyFlags unified_memory_flags =
@@ -140,6 +160,9 @@ VkDevice create_logical_device(
140160
#ifdef VK_KHR_portability_subset
141161
VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
142162
#endif /* VK_KHR_portability_subset */
163+
VK_KHR_16BIT_STORAGE_EXTENSION_NAME,
164+
VK_KHR_8BIT_STORAGE_EXTENSION_NAME,
165+
VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME,
143166
};
144167

145168
std::vector<const char*> enabled_device_extensions;
@@ -148,7 +171,7 @@ VkDevice create_logical_device(
148171
enabled_device_extensions,
149172
requested_device_extensions);
150173

151-
const VkDeviceCreateInfo device_create_info{
174+
VkDeviceCreateInfo device_create_info{
152175
VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType
153176
nullptr, // pNext
154177
0u, // flags
@@ -162,6 +185,8 @@ VkDevice create_logical_device(
162185
nullptr, // pEnabledFeatures
163186
};
164187

188+
device_create_info.pNext = physical_device.extension_features;
189+
165190
VkDevice handle = nullptr;
166191
VK_CHECK(vkCreateDevice(
167192
physical_device.handle, &device_create_info, nullptr, &handle));
@@ -371,33 +396,53 @@ std::string Adapter::stringize() const {
371396
ss << " deviceType: " << device_type << std::endl;
372397
ss << " deviceName: " << properties.deviceName << std::endl;
373398

374-
#define PRINT_LIMIT_PROP(name) \
375-
ss << " " << std::left << std::setw(36) << #name << limits.name \
399+
#define PRINT_PROP(struct, name) \
400+
ss << " " << std::left << std::setw(36) << #name << struct.name \
376401
<< std::endl;
377402

378-
#define PRINT_LIMIT_PROP_VEC3(name) \
379-
ss << " " << std::left << std::setw(36) << #name << limits.name[0] \
380-
<< "," << limits.name[1] << "," << limits.name[2] << std::endl;
403+
#define PRINT_PROP_VEC3(struct, name) \
404+
ss << " " << std::left << std::setw(36) << #name << struct.name[0] \
405+
<< "," << struct.name[1] << "," << struct.name[2] << std::endl;
381406

382407
ss << " Physical Device Limits {" << std::endl;
383-
PRINT_LIMIT_PROP(maxImageDimension1D);
384-
PRINT_LIMIT_PROP(maxImageDimension2D);
385-
PRINT_LIMIT_PROP(maxImageDimension3D);
386-
PRINT_LIMIT_PROP(maxTexelBufferElements);
387-
PRINT_LIMIT_PROP(maxPushConstantsSize);
388-
PRINT_LIMIT_PROP(maxMemoryAllocationCount);
389-
PRINT_LIMIT_PROP(maxSamplerAllocationCount);
390-
PRINT_LIMIT_PROP(maxComputeSharedMemorySize);
391-
PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupCount);
392-
PRINT_LIMIT_PROP(maxComputeWorkGroupInvocations);
393-
PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupSize);
408+
PRINT_PROP(limits, maxImageDimension1D);
409+
PRINT_PROP(limits, maxImageDimension2D);
410+
PRINT_PROP(limits, maxImageDimension3D);
411+
PRINT_PROP(limits, maxTexelBufferElements);
412+
PRINT_PROP(limits, maxPushConstantsSize);
413+
PRINT_PROP(limits, maxMemoryAllocationCount);
414+
PRINT_PROP(limits, maxSamplerAllocationCount);
415+
PRINT_PROP(limits, maxComputeSharedMemorySize);
416+
PRINT_PROP_VEC3(limits, maxComputeWorkGroupCount);
417+
PRINT_PROP(limits, maxComputeWorkGroupInvocations);
418+
PRINT_PROP_VEC3(limits, maxComputeWorkGroupSize);
419+
ss << " }" << std::endl;
420+
421+
ss << " 16bit Storage Features {" << std::endl;
422+
PRINT_PROP(physical_device_.shader_16bit_storage, storageBuffer16BitAccess);
423+
PRINT_PROP(
424+
physical_device_.shader_16bit_storage,
425+
uniformAndStorageBuffer16BitAccess);
426+
PRINT_PROP(physical_device_.shader_16bit_storage, storagePushConstant16);
427+
PRINT_PROP(physical_device_.shader_16bit_storage, storageInputOutput16);
428+
ss << " }" << std::endl;
429+
430+
ss << " 8bit Storage Features {" << std::endl;
431+
PRINT_PROP(physical_device_.shader_8bit_storage, storageBuffer8BitAccess);
432+
PRINT_PROP(
433+
physical_device_.shader_8bit_storage, uniformAndStorageBuffer8BitAccess);
434+
PRINT_PROP(physical_device_.shader_8bit_storage, storagePushConstant8);
435+
ss << " }" << std::endl;
436+
437+
ss << " Shader 16bit and 8bit Features {" << std::endl;
438+
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
439+
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
394440
ss << " }" << std::endl;
395-
ss << " }" << std::endl;
396-
;
397441

398442
const VkPhysicalDeviceMemoryProperties& mem_props =
399443
physical_device_.memory_properties;
400444

445+
ss << " }" << std::endl;
401446
ss << " Memory Info {" << std::endl;
402447
ss << " Memory Types [" << std::endl;
403448
for (size_t i = 0; i < mem_props.memoryTypeCount; ++i) {
@@ -432,6 +477,9 @@ std::string Adapter::stringize() const {
432477
ss << " ]" << std::endl;
433478
ss << "}";
434479

480+
#undef PRINT_PROP
481+
#undef PRINT_PROP_VEC3
482+
435483
return ss.str();
436484
}
437485

backends/vulkan/runtime/api/Adapter.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ struct PhysicalDevice final {
3030
// Properties obtained from Vulkan
3131
VkPhysicalDeviceProperties properties;
3232
VkPhysicalDeviceMemoryProperties memory_properties;
33+
// Additional features available from extensions
34+
VkPhysicalDevice16BitStorageFeatures shader_16bit_storage;
35+
VkPhysicalDevice8BitStorageFeatures shader_8bit_storage;
36+
VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types;
37+
38+
// Available GPU queues
3339
std::vector<VkQueueFamilyProperties> queue_families;
3440

3541
// Metadata
@@ -38,6 +44,9 @@ struct PhysicalDevice final {
3844
bool has_timestamps;
3945
float timestamp_period;
4046

47+
// Head of the linked list of extensions to be requested
48+
void* extension_features{nullptr};
49+
4150
explicit PhysicalDevice(VkPhysicalDevice);
4251
};
4352

@@ -189,6 +198,34 @@ class Adapter final {
189198
return vma_;
190199
}
191200

201+
// Physical Device Features
202+
203+
inline bool has_16bit_storage() {
204+
return physical_device_.shader_16bit_storage.storageBuffer16BitAccess ==
205+
VK_TRUE;
206+
}
207+
208+
inline bool has_8bit_storage() {
209+
return physical_device_.shader_8bit_storage.storageBuffer8BitAccess ==
210+
VK_TRUE;
211+
}
212+
213+
inline bool has_16bit_compute() {
214+
return physical_device_.shader_float16_int8_types.shaderFloat16 == VK_TRUE;
215+
}
216+
217+
inline bool has_8bit_compute() {
218+
return physical_device_.shader_float16_int8_types.shaderInt8 == VK_TRUE;
219+
}
220+
221+
inline bool has_full_float16_buffers_support() {
222+
return has_16bit_storage() && has_16bit_compute();
223+
}
224+
225+
inline bool has_full_int8_buffers_support() {
226+
return has_8bit_storage() && has_8bit_compute();
227+
}
228+
192229
// Command Buffer Submission
193230

194231
void

backends/vulkan/runtime/api/Runtime.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ VkInstance create_instance(const RuntimeConfiguration& config) {
8585
0, // applicationVersion
8686
nullptr, // pEngineName
8787
0, // engineVersion
88-
VK_API_VERSION_1_0, // apiVersion
88+
VK_API_VERSION_1_1, // apiVersion
8989
};
9090

9191
std::vector<const char*> enabled_layers;

backends/vulkan/runtime/api/Tensor.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,14 @@ vTensor::vTensor(
228228
memory_layout_,
229229
gpu_sizes_,
230230
dtype_,
231-
allocate_memory)) {}
231+
allocate_memory)) {
232+
if (dtype == api::kHalf) {
233+
VK_CHECK_COND(
234+
api::context()->adapter_ptr()->has_16bit_storage(),
235+
"Half dtype is only available if the physical device supports float16 "
236+
"storage buffers!");
237+
}
238+
}
232239

233240
vTensor::vTensor(
234241
api::Context* const context,

backends/vulkan/runtime/api/Types.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@
2323
#define VK_FORMAT_FLOAT4 VK_FORMAT_R32G32B32A32_SFLOAT
2424
#endif /* USE_VULKAN_FP16_INFERENCE */
2525

26-
#define VK_FORALL_SCALAR_TYPES(_) \
27-
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \
28-
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \
29-
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \
30-
_(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \
31-
_(float, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
32-
_(float, VK_FORMAT_FLOAT4, Float) \
33-
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \
34-
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \
26+
#define VK_FORALL_SCALAR_TYPES(_) \
27+
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \
28+
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \
29+
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \
30+
_(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \
31+
_(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
32+
_(float, VK_FORMAT_FLOAT4, Float) \
33+
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \
34+
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \
3535
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32)
3636

3737
namespace vkcompute {

backends/vulkan/runtime/api/gen_vulkan_spv.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,22 @@ def get_buffer_scalar_type(dtype: str) -> str:
100100
return dtype
101101

102102

103+
def get_buffer_gvec_type(dtype: str, n: int) -> str:
    """Return the GLSL type name for an ``n``-component buffer element.

    For ``n == 1`` the result is delegated to ``get_buffer_scalar_type``.
    Otherwise the component count is appended to the vector prefix
    associated with ``dtype``.

    Raises:
        AssertionError: if ``dtype`` has no known vector prefix.
    """
    if n == 1:
        return get_buffer_scalar_type(dtype)

    # Vector type prefix for each supported dtype.
    vec_prefix_by_dtype = {
        "float": "vec",
        "half": "f16vec",
        "int8": "i8vec",
        "uint8": "u8vec",
    }

    prefix = vec_prefix_by_dtype.get(dtype)
    if prefix is None:
        raise AssertionError(f"Invalid dtype: {dtype}")

    return f"{prefix}{n}"
117+
118+
103119
def get_texel_type(dtype: str) -> str:
104120
image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
105121
if image_format[-1] == "f":
@@ -134,6 +150,7 @@ def get_texel_component_type(dtype: str) -> str:
134150
2: lambda pos: f"{pos}.xy",
135151
},
136152
"buffer_scalar_type": get_buffer_scalar_type,
153+
"buffer_gvec_type": get_buffer_gvec_type,
137154
"texel_type": get_texel_type,
138155
"gvec_type": get_gvec_type,
139156
"texel_component_type": get_texel_component_type,
@@ -456,7 +473,7 @@ def generateSPV(self, output_dir: str) -> Dict[str, str]:
456473
glsl_out_path,
457474
"-o",
458475
spv_out_path,
459-
"--target-env=vulkan1.0",
476+
"--target-env=vulkan1.1",
460477
"-Werror",
461478
] + [
462479
arg

backends/vulkan/test/glsl/all_shaders.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ fill_texture__test:
3333
shader_variants:
3434
- NAME: fill_texture__test
3535

36+
idx_fill_buffer:
37+
parameter_names_with_default_values:
38+
DTYPE: float
39+
generate_variant_forall:
40+
DTYPE:
41+
- VALUE: float
42+
- VALUE: half
43+
- VALUE: int8
44+
shader_variants:
45+
- NAME: idx_fill_buffer
46+
3647
idx_fill_texture:
3748
parameter_names_with_default_values:
3849
DTYPE: float
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

// Test shader template: each invocation i writes the four consecutive values
// (4*i, 4*i + 1, 4*i + 2, 4*i + 3) into element i of the output buffer.

#define PRECISION ${PRECISION}

// 4-component vector type matching the buffer dtype selected by DTYPE.
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#include "indexing_utils.h"

$if DTYPE == "half":
  #extension GL_EXT_shader_16bit_storage : require
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
$elif DTYPE == "int8":
  #extension GL_EXT_shader_8bit_storage : require
  #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
$elif DTYPE == "uint8":
  #extension GL_EXT_shader_8bit_storage : require
  // GL_EXT_shader_explicit_arithmetic_types defines no "_uint8" name; the
  // "_int8" sub-extension enables both int8_t and uint8_t.
  #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

layout(std430) buffer;

// Output buffer; declared writeonly, so it is never read by this shader.
layout(set = 0, binding = 0) buffer PRECISION restrict writeonly Buffer {
  VEC4_T data[];
}
buffer_in;

layout(set = 0, binding = 1) uniform PRECISION restrict Params {
  // Upper bound compared against the scalar index 4*i below; presumably the
  // number of scalar elements to fill -- confirm against the caller.
  int len;
}
params;

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const int i = ivec3(gl_GlobalInvocationID).x;

  // Index of the first scalar covered by vec4 element i.
  const int base = 4 * i;
  if (base < params.len) {
    buffer_in.data[i] = VEC4_T(base, base + 1, base + 2, base + 3);
  }
}

0 commit comments

Comments
 (0)