Skip to content

[ET-VK] Bump Vulkan API requirement to 1.1 and enable 16 bit and 8 bit types in buffer storage #3058

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 68 additions & 20 deletions backends/vulkan/runtime/api/Adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
* LICENSE file in the root directory of this source tree.
*/

// @lint-ignore-every CLANGTIDY clang-diagnostic-missing-field-initializers

#include <executorch/backends/vulkan/runtime/api/Adapter.h>

#include <bitset>
Expand All @@ -21,15 +23,33 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
: handle(physical_device_handle),
properties{},
memory_properties{},
shader_16bit_storage{
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES},
shader_8bit_storage{
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES},
shader_float16_int8_types{
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR},
queue_families{},
num_compute_queues(0),
has_unified_memory(false),
has_timestamps(properties.limits.timestampComputeAndGraphics),
timestamp_period(properties.limits.timestampPeriod) {
timestamp_period(properties.limits.timestampPeriod),
extension_features(&shader_16bit_storage) {
// Extract physical device properties
vkGetPhysicalDeviceProperties(handle, &properties);
vkGetPhysicalDeviceMemoryProperties(handle, &memory_properties);

VkPhysicalDeviceFeatures2 features2{
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};

// Create linked list to query availability of extensions
features2.pNext = &shader_16bit_storage;
shader_16bit_storage.pNext = &shader_8bit_storage;
shader_8bit_storage.pNext = &shader_float16_int8_types;
shader_float16_int8_types.pNext = nullptr;

vkGetPhysicalDeviceFeatures2(handle, &features2);

// Check if there are any memory types that have both the HOST_VISIBLE and
// the DEVICE_LOCAL property flags
const VkMemoryPropertyFlags unified_memory_flags =
Expand Down Expand Up @@ -140,6 +160,9 @@ VkDevice create_logical_device(
#ifdef VK_KHR_portability_subset
VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
#endif /* VK_KHR_portability_subset */
VK_KHR_16BIT_STORAGE_EXTENSION_NAME,
VK_KHR_8BIT_STORAGE_EXTENSION_NAME,
VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME,
};

std::vector<const char*> enabled_device_extensions;
Expand All @@ -148,7 +171,7 @@ VkDevice create_logical_device(
enabled_device_extensions,
requested_device_extensions);

const VkDeviceCreateInfo device_create_info{
VkDeviceCreateInfo device_create_info{
VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType
nullptr, // pNext
0u, // flags
Expand All @@ -162,6 +185,8 @@ VkDevice create_logical_device(
nullptr, // pEnabledFeatures
};

device_create_info.pNext = physical_device.extension_features;

VkDevice handle = nullptr;
VK_CHECK(vkCreateDevice(
physical_device.handle, &device_create_info, nullptr, &handle));
Expand Down Expand Up @@ -371,33 +396,53 @@ std::string Adapter::stringize() const {
ss << " deviceType: " << device_type << std::endl;
ss << " deviceName: " << properties.deviceName << std::endl;

#define PRINT_LIMIT_PROP(name) \
ss << " " << std::left << std::setw(36) << #name << limits.name \
#define PRINT_PROP(struct, name) \
ss << " " << std::left << std::setw(36) << #name << struct.name \
<< std::endl;

#define PRINT_LIMIT_PROP_VEC3(name) \
ss << " " << std::left << std::setw(36) << #name << limits.name[0] \
<< "," << limits.name[1] << "," << limits.name[2] << std::endl;
#define PRINT_PROP_VEC3(struct, name) \
ss << " " << std::left << std::setw(36) << #name << struct.name[0] \
<< "," << struct.name[1] << "," << struct.name[2] << std::endl;

ss << " Physical Device Limits {" << std::endl;
PRINT_LIMIT_PROP(maxImageDimension1D);
PRINT_LIMIT_PROP(maxImageDimension2D);
PRINT_LIMIT_PROP(maxImageDimension3D);
PRINT_LIMIT_PROP(maxTexelBufferElements);
PRINT_LIMIT_PROP(maxPushConstantsSize);
PRINT_LIMIT_PROP(maxMemoryAllocationCount);
PRINT_LIMIT_PROP(maxSamplerAllocationCount);
PRINT_LIMIT_PROP(maxComputeSharedMemorySize);
PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupCount);
PRINT_LIMIT_PROP(maxComputeWorkGroupInvocations);
PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupSize);
PRINT_PROP(limits, maxImageDimension1D);
PRINT_PROP(limits, maxImageDimension2D);
PRINT_PROP(limits, maxImageDimension3D);
PRINT_PROP(limits, maxTexelBufferElements);
PRINT_PROP(limits, maxPushConstantsSize);
PRINT_PROP(limits, maxMemoryAllocationCount);
PRINT_PROP(limits, maxSamplerAllocationCount);
PRINT_PROP(limits, maxComputeSharedMemorySize);
PRINT_PROP_VEC3(limits, maxComputeWorkGroupCount);
PRINT_PROP(limits, maxComputeWorkGroupInvocations);
PRINT_PROP_VEC3(limits, maxComputeWorkGroupSize);
ss << " }" << std::endl;

ss << " 16bit Storage Features {" << std::endl;
PRINT_PROP(physical_device_.shader_16bit_storage, storageBuffer16BitAccess);
PRINT_PROP(
physical_device_.shader_16bit_storage,
uniformAndStorageBuffer16BitAccess);
PRINT_PROP(physical_device_.shader_16bit_storage, storagePushConstant16);
PRINT_PROP(physical_device_.shader_16bit_storage, storageInputOutput16);
ss << " }" << std::endl;

ss << " 8bit Storage Features {" << std::endl;
PRINT_PROP(physical_device_.shader_8bit_storage, storageBuffer8BitAccess);
PRINT_PROP(
physical_device_.shader_8bit_storage, uniformAndStorageBuffer8BitAccess);
PRINT_PROP(physical_device_.shader_8bit_storage, storagePushConstant8);
ss << " }" << std::endl;

ss << " Shader 16bit and 8bit Features {" << std::endl;
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
ss << " }" << std::endl;
ss << " }" << std::endl;
;

const VkPhysicalDeviceMemoryProperties& mem_props =
physical_device_.memory_properties;

ss << " }" << std::endl;
ss << " Memory Info {" << std::endl;
ss << " Memory Types [" << std::endl;
for (size_t i = 0; i < mem_props.memoryTypeCount; ++i) {
Expand Down Expand Up @@ -432,6 +477,9 @@ std::string Adapter::stringize() const {
ss << " ]" << std::endl;
ss << "}";

#undef PRINT_PROP
#undef PRINT_PROP_VEC3

return ss.str();
}

Expand Down
37 changes: 37 additions & 0 deletions backends/vulkan/runtime/api/Adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ struct PhysicalDevice final {
// Properties obtained from Vulkan
VkPhysicalDeviceProperties properties;
VkPhysicalDeviceMemoryProperties memory_properties;
// Additional features available from extensions
VkPhysicalDevice16BitStorageFeatures shader_16bit_storage;
VkPhysicalDevice8BitStorageFeatures shader_8bit_storage;
VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types;

// Available GPU queues
std::vector<VkQueueFamilyProperties> queue_families;

// Metadata
Expand All @@ -38,6 +44,9 @@ struct PhysicalDevice final {
bool has_timestamps;
float timestamp_period;

// Head of the linked list of extensions to be requested
void* extension_features{nullptr};

explicit PhysicalDevice(VkPhysicalDevice);
};

Expand Down Expand Up @@ -189,6 +198,34 @@ class Adapter final {
return vma_;
}

// Physical Device Features

inline bool has_16bit_storage() {
return physical_device_.shader_16bit_storage.storageBuffer16BitAccess ==
VK_TRUE;
}

inline bool has_8bit_storage() {
return physical_device_.shader_8bit_storage.storageBuffer8BitAccess ==
VK_TRUE;
}

inline bool has_16bit_compute() {
return physical_device_.shader_float16_int8_types.shaderFloat16 == VK_TRUE;
}

inline bool has_8bit_compute() {
return physical_device_.shader_float16_int8_types.shaderInt8 == VK_TRUE;
}

inline bool has_full_float16_buffers_support() {
return has_16bit_storage() && has_16bit_compute();
}

inline bool has_full_int8_buffers_support() {
return has_8bit_storage() && has_8bit_compute();
}

// Command Buffer Submission

void
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/api/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ VkInstance create_instance(const RuntimeConfiguration& config) {
0, // applicationVersion
nullptr, // pEngineName
0, // engineVersion
VK_API_VERSION_1_0, // apiVersion
VK_API_VERSION_1_1, // apiVersion
};

std::vector<const char*> enabled_layers;
Expand Down
9 changes: 8 additions & 1 deletion backends/vulkan/runtime/api/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,14 @@ vTensor::vTensor(
memory_layout_,
gpu_sizes_,
dtype_,
allocate_memory)) {}
allocate_memory)) {
if (dtype == api::kHalf) {
VK_CHECK_COND(
api::context()->adapter_ptr()->has_16bit_storage(),
"Half dtype is only available if the physical device supports float16 "
"storage buffers!");
}
}

vTensor::vTensor(
api::Context* const context,
Expand Down
18 changes: 9 additions & 9 deletions backends/vulkan/runtime/api/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
#define VK_FORMAT_FLOAT4 VK_FORMAT_R32G32B32A32_SFLOAT
#endif /* USE_VULKAN_FP16_INFERENCE */

#define VK_FORALL_SCALAR_TYPES(_) \
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \
_(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \
_(float, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
_(float, VK_FORMAT_FLOAT4, Float) \
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \
#define VK_FORALL_SCALAR_TYPES(_) \
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \
_(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \
_(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \
_(float, VK_FORMAT_FLOAT4, Float) \
_(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \
_(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \
_(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32)

namespace vkcompute {
Expand Down
19 changes: 18 additions & 1 deletion backends/vulkan/runtime/api/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,22 @@ def get_buffer_scalar_type(dtype: str) -> str:
return dtype


def get_buffer_gvec_type(dtype: str, n: int) -> str:
    """Return the GLSL type name for an ``n``-component buffer element.

    For ``n == 1`` the result is delegated to ``get_buffer_scalar_type``.
    Otherwise the vector type name is the dtype-specific prefix followed by
    ``n`` (for example ``f16vec4`` for ``("half", 4)``).

    Raises:
        AssertionError: if ``dtype`` has no known vector type prefix.
    """
    if n == 1:
        return get_buffer_scalar_type(dtype)

    # (dtype name, GLSL vector type prefix) pairs, checked in order.
    vector_prefixes = (
        ("float", "vec"),
        ("half", "f16vec"),
        ("int8", "i8vec"),
        ("uint8", "u8vec"),
    )
    for known_dtype, prefix in vector_prefixes:
        if dtype == known_dtype:
            return f"{prefix}{n}"

    raise AssertionError(f"Invalid dtype: {dtype}")


def get_texel_type(dtype: str) -> str:
image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype]
if image_format[-1] == "f":
Expand Down Expand Up @@ -134,6 +150,7 @@ def get_texel_component_type(dtype: str) -> str:
2: lambda pos: f"{pos}.xy",
},
"buffer_scalar_type": get_buffer_scalar_type,
"buffer_gvec_type": get_buffer_gvec_type,
"texel_type": get_texel_type,
"gvec_type": get_gvec_type,
"texel_component_type": get_texel_component_type,
Expand Down Expand Up @@ -456,7 +473,7 @@ def generateSPV(self, output_dir: str) -> Dict[str, str]:
glsl_out_path,
"-o",
spv_out_path,
"--target-env=vulkan1.0",
"--target-env=vulkan1.1",
"-Werror",
] + [
arg
Expand Down
11 changes: 11 additions & 0 deletions backends/vulkan/test/glsl/all_shaders.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ fill_texture__test:
shader_variants:
- NAME: fill_texture__test

# Shader entry named idx_fill_buffer. DTYPE defaults to float, and float, half
# and int8 are the DTYPE values listed for variant generation.
# NOTE(review): the idx_fill_buffer.glsl template also has a uint8 branch, but
# no uint8 value is listed here — confirm whether that is intentional.
idx_fill_buffer:
parameter_names_with_default_values:
DTYPE: float
generate_variant_forall:
DTYPE:
- VALUE: float
- VALUE: half
- VALUE: int8
shader_variants:
- NAME: idx_fill_buffer

idx_fill_texture:
parameter_names_with_default_values:
DTYPE: float
Expand Down
48 changes: 48 additions & 0 deletions backends/vulkan/test/glsl/idx_fill_buffer.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#include "indexing_utils.h"

// Enable the GLSL extensions needed to declare and operate on reduced-width
// buffer element types for the selected DTYPE.
// Note: the uint8 branch requires the _int8 arithmetic-types extension, which
// is the one that enables both int8_t and uint8_t (and their vector types);
// GL_EXT_shader_explicit_arithmetic_types has no separate _uint8 name.
$if DTYPE == "half":
  #extension GL_EXT_shader_16bit_storage : require
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
$elif DTYPE == "int8":
  #extension GL_EXT_shader_8bit_storage : require
  #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
$elif DTYPE == "uint8":
  #extension GL_EXT_shader_8bit_storage : require
  #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require

// Use std430 as the default layout for the buffer blocks declared below.
layout(std430) buffer;

// Binding 0: write-only storage buffer viewed as an array of VEC4_T. main()
// stores one VEC4_T per invocation.
// NOTE(review): despite the instance name buffer_in, this shader only writes
// to this block.
layout(set = 0, binding = 0) buffer PRECISION restrict writeonly Buffer {
VEC4_T data[];
}
buffer_in;

// Binding 1: uniform parameters. main() compares len against 4 * the
// invocation index; presumably it is the number of scalar elements to fill —
// confirm with the caller.
layout(set = 0, binding = 1) uniform PRECISION restrict Params {
int len;
}
params;

// Workgroup dimensions are taken from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Each invocation along x owns one VEC4_T slot and fills it with the four
// consecutive scalar indices that slot covers: (4i, 4i+1, 4i+2, 4i+3).
void main() {
  const int vec_idx = int(gl_GlobalInvocationID.x);
  const int first_elem = vec_idx * 4;

  // Skip invocations whose slot starts at or beyond the requested length.
  if (first_elem >= params.len) {
    return;
  }

  buffer_in.data[vec_idx] =
      VEC4_T(first_elem, first_elem + 1, first_elem + 2, first_elem + 3);
}
Loading