Skip to content

Commit e8752d0

Browse files
pytorchbotSS-JIA
authored andcommitted
[ET-VK] Parse required extensions of shaders and check capabilities during dispatch (#7592)
## Context Now that we are using GLSL/SPIR-V extensions more heavily in our shaders, there is a risk that a particular shader uses an extension that is not supported by the physical device. It is tedious to manually check that all the extensions required by a shader is supported by the device; it would be much more convenient for developers if there was an automated way to perform this check. This diff provides a solution for this. Materially, this has manifested into an issue with our internal CI tests that run on Android emulator (which uses swiftshader under the hood). If the emulator tries to compile a shader that requires the `shaderInt16` feature, then the emulator will crash. ## Solution 1. Update `ShaderInfo` to have fields indicating whether certain extensions that require device support is required. 2. Update the `gen_vulkan_spv.py` shader compilation script to parse the GLSL code and log whether aforemention extensions are needed in the generated `ShaderInfo`. 3. Introduce a new exception class, `ShaderNotSupportedError`. 4. Before dispatching, check that all extensions required by the shader is supported by the device. If not, throw the new exception class. 4. In the generated operator correctness tests, skip the test if `ShaderNotSupportedError` is thrown. Differential Revision: [D67992067](https://our.internmc.facebook.com/intern/diff/D67992067/) ghstack-source-id: 260809479 Pull Request resolved: #7576 Co-authored-by: Stephen Jia <[email protected]>
1 parent e519e57 commit e8752d0

File tree

11 files changed

+128
-7
lines changed

11 files changed

+128
-7
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,27 @@ void Context::report_shader_dispatch_end() {
8787
}
8888
}
8989

90+
void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
91+
if (shader.requires_shader_int16) {
92+
if (!adapter_p_->supports_int16_shader_types()) {
93+
throw vkapi::ShaderNotSupportedError(
94+
shader.kernel_name, vkapi::VulkanExtension::SHADER_INT16);
95+
}
96+
}
97+
if (shader.requires_16bit_storage) {
98+
if (!adapter_p_->supports_16bit_storage_buffers()) {
99+
throw vkapi::ShaderNotSupportedError(
100+
shader.kernel_name, vkapi::VulkanExtension::INT16_STORAGE);
101+
}
102+
}
103+
if (shader.requires_8bit_storage) {
104+
if (!adapter_p_->supports_8bit_storage_buffers()) {
105+
throw vkapi::ShaderNotSupportedError(
106+
shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE);
107+
}
108+
}
109+
}
110+
90111
vkapi::DescriptorSet Context::get_descriptor_set(
91112
const vkapi::ShaderInfo& shader_descriptor,
92113
const utils::uvec3& local_workgroup_size,

backends/vulkan/runtime/api/Context.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ class Context final {
185185
}
186186
}
187187

188+
void check_device_capabilities(const vkapi::ShaderInfo& shader);
189+
188190
vkapi::DescriptorSet get_descriptor_set(
189191
const vkapi::ShaderInfo&,
190192
const utils::uvec3&,

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,10 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
720720
if "codegen-nosub" in input_text:
721721
return input_text
722722

723+
# Remove extension requirement so that generated ShaderInfo does not mark it
724+
input_text = input_text.replace(
725+
"#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require", ""
726+
)
723727
input_text = input_text.replace("u16vec", "ivec")
724728
input_text = input_text.replace("uint16_t", "int")
725729
return input_text
@@ -791,6 +795,9 @@ class ShaderInfo:
791795
weight_storage_type: str = ""
792796
bias_storage_type: str = ""
793797
register_for: Optional[Tuple[str, List[str]]] = None
798+
requires_shader_int16_ext: bool = False
799+
requires_16bit_storage_ext: bool = False
800+
requires_8bit_storage_ext: bool = False
794801

795802

796803
def getName(filePath: str) -> str:
@@ -858,6 +865,11 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
858865
return (matches_list[0], matches_list[1:])
859866

860867

868+
def isExtensionRequireLine(lineStr: str) -> bool:
869+
extension_require_id = r"^#extension ([A-Za-z0-9_]+)\s*:\s*require"
870+
return re.search(extension_require_id, lineStr) is not None
871+
872+
861873
typeIdMapping = {
862874
r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
863875
r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
@@ -889,6 +901,13 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
889901
shader_info.bias_storage_type = getBiasStorageType(line)
890902
if isRegisterForLine(line):
891903
shader_info.register_for = findRegisterFor(line)
904+
if isExtensionRequireLine(line):
905+
if "GL_EXT_shader_explicit_arithmetic_types_int16" in line:
906+
shader_info.requires_shader_int16_ext = True
907+
if "GL_EXT_shader_16bit_storage" in line:
908+
shader_info.requires_16bit_storage_ext = True
909+
if "GL_EXT_shader_8bit_storage" in line:
910+
shader_info.requires_8bit_storage_ext = True
892911

893912
return shader_info
894913

@@ -952,12 +971,18 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) ->
952971

953972
shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
954973

974+
def to_cpp_str(val: bool):
975+
return "true" if val else "false"
976+
955977
shader_info_args = [
956978
f'"{name}"',
957979
f"{name}_bin",
958980
str(sizeBytes),
959981
shader_info_layouts,
960982
tile_size,
983+
to_cpp_str(shader_info.requires_shader_int16_ext),
984+
to_cpp_str(shader_info.requires_16bit_storage_ext),
985+
to_cpp_str(shader_info.requires_8bit_storage_ext),
961986
]
962987

963988
shader_info_str = textwrap.indent(

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ void DispatchNode::encode(ComputeGraph* graph) {
5858
api::Context* const context = graph->context();
5959
vkapi::PipelineBarrier pipeline_barrier{};
6060

61+
context->check_device_capabilities(shader_);
62+
6163
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
6264

6365
std::array<uint8_t, kMaxPushConstantSize> push_constants_data;

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3636

3737
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3838

39-
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
40-
4139
/*
4240
* Computes a depthwise convolution. Each shader invocation calculates the
4341
* output at a single output location.

backends/vulkan/runtime/vk_api/Adapter.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ std::string Adapter::stringize() const {
256256
ss << " deviceType: " << device_type << std::endl;
257257
ss << " deviceName: " << properties.deviceName << std::endl;
258258

259+
#define PRINT_BOOL(value, name) \
260+
ss << " " << std::left << std::setw(36) << #name << value << std::endl;
261+
259262
#define PRINT_PROP(struct, name) \
260263
ss << " " << std::left << std::setw(36) << #name << struct.name \
261264
<< std::endl;
@@ -298,12 +301,13 @@ std::string Adapter::stringize() const {
298301
ss << " }" << std::endl;
299302
#endif /* VK_KHR_8bit_storage */
300303

301-
#ifdef VK_KHR_shader_float16_int8
302304
ss << " Shader 16bit and 8bit Features {" << std::endl;
305+
PRINT_BOOL(physical_device_.supports_int16_shader_types, shaderInt16)
306+
#ifdef VK_KHR_shader_float16_int8
303307
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
304308
PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
305-
ss << " }" << std::endl;
306309
#endif /* VK_KHR_shader_float16_int8 */
310+
ss << " }" << std::endl;
307311

308312
const VkPhysicalDeviceMemoryProperties& mem_props =
309313
physical_device_.memory_properties;

backends/vulkan/runtime/vk_api/Exception.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,36 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg)
7777
what_ = oss.str();
7878
}
7979

80+
//
81+
// ShaderNotSupportedError
82+
//
83+
84+
std::ostream& operator<<(std::ostream& out, const VulkanExtension result) {
85+
switch (result) {
86+
case VulkanExtension::SHADER_INT16:
87+
out << "shaderInt16";
88+
break;
89+
case VulkanExtension::INT16_STORAGE:
90+
out << "VK_KHR_16bit_storage";
91+
break;
92+
case VulkanExtension::INT8_STORAGE:
93+
out << "VK_KHR_8bit_storage";
94+
break;
95+
}
96+
return out;
97+
}
98+
99+
ShaderNotSupportedError::ShaderNotSupportedError(
100+
std::string shader_name,
101+
VulkanExtension extension)
102+
: shader_name_(std::move(shader_name)), extension_{extension} {
103+
std::ostringstream oss;
104+
oss << "Shader " << shader_name_ << " ";
105+
oss << "not compatible with device. ";
106+
oss << "Missing support for extension or physical device feature: ";
107+
oss << extension_;
108+
what_ = oss.str();
109+
}
110+
80111
} // namespace vkapi
81112
} // namespace vkcompute

backends/vulkan/runtime/vk_api/Exception.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,26 @@ class Error : public std::exception {
7878
}
7979
};
8080

81+
enum class VulkanExtension : uint8_t {
82+
SHADER_INT16,
83+
INT16_STORAGE,
84+
INT8_STORAGE,
85+
};
86+
87+
class ShaderNotSupportedError : public std::exception {
88+
public:
89+
ShaderNotSupportedError(std::string shader_name, VulkanExtension extension);
90+
91+
private:
92+
std::string shader_name_;
93+
VulkanExtension extension_;
94+
std::string what_;
95+
96+
public:
97+
const char* what() const noexcept override {
98+
return what_.c_str();
99+
}
100+
};
101+
81102
} // namespace vkapi
82103
} // namespace vkcompute

backends/vulkan/runtime/vk_api/Shader.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,20 @@ ShaderInfo::ShaderInfo(
2828
const uint32_t* const spirv_bin,
2929
const uint32_t size,
3030
std::vector<VkDescriptorType> layout,
31-
const utils::uvec3 tile_size)
31+
const utils::uvec3 tile_size,
32+
const bool requires_shader_int16_ext,
33+
const bool requires_16bit_storage_ext,
34+
const bool requires_8bit_storage_ext)
3235
: src_code{
3336
spirv_bin,
3437
size,
3538
},
3639
kernel_name{std::move(name)},
3740
kernel_layout{std::move(layout)},
38-
out_tile_size(tile_size) {
41+
out_tile_size(tile_size),
42+
requires_shader_int16(requires_shader_int16_ext),
43+
requires_16bit_storage(requires_16bit_storage_ext),
44+
requires_8bit_storage(requires_8bit_storage_ext) {
3945
}
4046

4147
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {

backends/vulkan/runtime/vk_api/Shader.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ struct ShaderInfo final {
6262

6363
// Shader Metadata
6464
utils::uvec3 out_tile_size{1u, 1u, 1u};
65+
bool requires_shader_int16 = false;
66+
bool requires_16bit_storage = false;
67+
bool requires_8bit_storage = false;
6568

6669
explicit ShaderInfo();
6770

@@ -70,7 +73,10 @@ struct ShaderInfo final {
7073
const uint32_t*,
7174
const uint32_t,
7275
std::vector<VkDescriptorType>,
73-
const utils::uvec3 tile_size);
76+
const utils::uvec3 tile_size,
77+
const bool requires_shader_int16_ext,
78+
const bool requires_16bit_storage_ext,
79+
const bool requires_8bit_storage_ext);
7480

7581
operator bool() const {
7682
return src_code.bin != nullptr;

backends/vulkan/test/op_tests/utils/gen_correctness_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,13 @@ class GeneratedOpsTest_{op_name} : public ::testing::Test {{
4545
test_suite_template = """
4646
TEST_P(GeneratedOpsTest_{op_name}, {case_name}) {{
4747
{create_ref_data}
48+
try {{
4849
{create_and_check_out}
4950
}}
51+
catch (const vkcompute::vkapi::ShaderNotSupportedError& e) {{
52+
GTEST_SKIP() << e.what();
53+
}}
54+
}}
5055
"""
5156

5257

0 commit comments

Comments
 (0)