Skip to content

Commit cde9aad

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
[ATen-vulkan] Implement global shader registry (pytorch#121088)
Summary: X-link: pytorch/executorch#2222 ## Context This changeset updates Vulkan SPIR-V codegen to introduce a global SPIR-V shader registry and register shaders dynamically at static initialization time. This change makes it possible to define and link custom shader libraries to the ATen-Vulkan runtime. Before: * `gen_vulkan_spv.py` generated two files, `spv.h` and `spv.cpp` which would contain the definition and initialization of Vulkan shader registry variables. After: * Introduce the `ShaderRegistry` class in `api/`, which encapsulates functionality of the `ShaderRegistry` class previously defined in the generated `spv.h` file * Introduce a global shader registry (defined as a static variable in the `api::shader_registry() function` * Define a `ShaderRegisterInit` class (taking inspiration from `TorchLibraryInit`) that allows for dynamic shader registration * `gen_vulkan_spv.py` now only generates `spv.cpp`, which defines a static `ShaderRegisterInit` instance that triggers registration of the compiled shaders to the global shader registry. Benefits: * Cleaner code base; we no longer have `ShaderRegistry` defined in a generated file, and don't need a separate implementation file (`impl/Registry.*`) to handle shader lookup. All that logic now lives under `api/ShaderRegistry.*` * Makes it possible to compile and link separate shader libraries, providing similar flexibility as defining and linking custom ATen operators Test Plan: CI Differential Revision: D54447700
1 parent 5804720 commit cde9aad

File tree

10 files changed

+291
-225
lines changed

10 files changed

+291
-225
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include <ATen/native/vulkan/api/ShaderRegistry.h>
2+
3+
namespace at {
4+
namespace native {
5+
namespace vulkan {
6+
namespace api {
7+
8+
bool ShaderRegistry::has_shader(const std::string& shader_name) {
9+
const ShaderListing::const_iterator it = listings_.find(shader_name);
10+
return it != listings_.end();
11+
}
12+
13+
bool ShaderRegistry::has_dispatch(const std::string& op_name) {
14+
const Registry::const_iterator it = registry_.find(op_name);
15+
return it != registry_.end();
16+
}
17+
18+
void ShaderRegistry::register_shader(ShaderInfo&& shader_info) {
19+
if (has_shader(shader_info.kernel_name)) {
20+
VK_THROW(
21+
"Shader with name ", shader_info.kernel_name, "already registered");
22+
}
23+
listings_.emplace(shader_info.kernel_name, shader_info);
24+
}
25+
26+
void ShaderRegistry::register_op_dispatch(
27+
const std::string& op_name,
28+
const DispatchKey key,
29+
const std::string& shader_name) {
30+
if (!has_dispatch(op_name)) {
31+
registry_.emplace(op_name, Dispatcher());
32+
}
33+
const Dispatcher::const_iterator it = registry_[op_name].find(key);
34+
if (it != registry_[op_name].end()) {
35+
registry_[op_name][key] = shader_name;
36+
} else {
37+
registry_[op_name].emplace(key, shader_name);
38+
}
39+
}
40+
41+
const ShaderInfo& ShaderRegistry::get_shader_info(
42+
const std::string& shader_name) {
43+
const ShaderListing::const_iterator it = listings_.find(shader_name);
44+
45+
VK_CHECK_COND(
46+
it != listings_.end(),
47+
"Could not find ShaderInfo with name ",
48+
shader_name);
49+
50+
return it->second;
51+
}
52+
53+
ShaderRegistry& shader_registry() {
54+
static ShaderRegistry registry;
55+
return registry;
56+
}
57+
58+
} // namespace api
59+
} // namespace vulkan
60+
} // namespace native
61+
} // namespace at
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#pragma once
2+
3+
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4+
5+
#ifdef USE_VULKAN_API
6+
7+
#include <ATen/native/vulkan/api/Shader.h>
8+
9+
#include <string>
10+
#include <unordered_map>
11+
12+
#define VK_KERNEL(shader_name) \
13+
::at::native::vulkan::api::shader_registry().get_shader_info(#shader_name)
14+
15+
namespace at {
16+
namespace native {
17+
namespace vulkan {
18+
namespace api {
19+
20+
enum class DispatchKey : int8_t {
21+
CATCHALL,
22+
ADRENO,
23+
MALI,
24+
OVERRIDE,
25+
};
26+
27+
class ShaderRegistry final {
28+
using ShaderListing = std::unordered_map<std::string, ShaderInfo>;
29+
using Dispatcher = std::unordered_map<DispatchKey, std::string>;
30+
using Registry = std::unordered_map<std::string, Dispatcher>;
31+
32+
ShaderListing listings_;
33+
Dispatcher dispatcher_;
34+
Registry registry_;
35+
36+
public:
37+
/*
38+
* Check if the registry has a shader registered under the given name
39+
*/
40+
bool has_shader(const std::string& shader_name);
41+
42+
/*
43+
* Check if the registry has a dispatch registered under the given name
44+
*/
45+
bool has_dispatch(const std::string& op_name);
46+
47+
/*
48+
* Register a ShaderInfo to a given shader name
49+
*/
50+
void register_shader(ShaderInfo&& shader_info);
51+
52+
/*
53+
* Register a dispatch entry to the given op name
54+
*/
55+
void register_op_dispatch(
56+
const std::string& op_name,
57+
const DispatchKey key,
58+
const std::string& shader_name);
59+
60+
/*
61+
* Given a shader name, return the ShaderInfo which contains the SPIRV binary
62+
*/
63+
const ShaderInfo& get_shader_info(const std::string& shader_name);
64+
};
65+
66+
class ShaderRegisterInit final {
67+
using InitFn = void();
68+
69+
public:
70+
ShaderRegisterInit(InitFn* init_fn) {
71+
init_fn();
72+
};
73+
};
74+
75+
// The global shader registry is retrieved using this function, where it is
76+
// declared as a static local variable.
77+
ShaderRegistry& shader_registry();
78+
79+
} // namespace api
80+
} // namespace vulkan
81+
} // namespace native
82+
} // namespace at
83+
84+
#endif /* USE_VULKAN_API */

aten/src/ATen/native/vulkan/api/api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <ATen/native/vulkan/api/Resource.h>
1111
#include <ATen/native/vulkan/api/Runtime.h>
1212
#include <ATen/native/vulkan/api/Shader.h>
13+
#include <ATen/native/vulkan/api/ShaderRegistry.h>
1314
#include <ATen/native/vulkan/api/Tensor.h>
1415
#include <ATen/native/vulkan/api/Utils.h>
1516

aten/src/ATen/native/vulkan/impl/Common.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,6 @@
33
#ifdef USE_VULKAN_API
44

55
#include <ATen/native/vulkan/api/api.h>
6-
#include <ATen/native/vulkan/impl/Registry.h>
7-
8-
#define VK_KERNEL(shader_name) \
9-
::at::native::vulkan::get_shader_info(#shader_name)
10-
#define VK_LOOKUP_KERNEL(op_name) \
11-
::at::native::vulkan::look_up_shader_info(#op_name)
126

137
namespace at {
148
namespace native {

aten/src/ATen/native/vulkan/impl/Registry.cpp

Lines changed: 0 additions & 88 deletions
This file was deleted.

aten/src/ATen/native/vulkan/impl/Registry.h

Lines changed: 0 additions & 33 deletions
This file was deleted.

aten/src/ATen/native/vulkan/ops/Common.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,6 @@
1010
#include <ATen/native/vulkan/impl/Common.h>
1111
#include <ATen/native/vulkan/ops/Convert.h>
1212

13-
#define VK_KERNEL(shader_name) \
14-
::at::native::vulkan::get_shader_info(#shader_name)
15-
#define VK_LOOKUP_KERNEL(op_name) \
16-
::at::native::vulkan::look_up_shader_info(#op_name)
17-
1813
namespace at {
1914
namespace native {
2015
namespace vulkan {

aten/src/ATen/native/vulkan/ops/Convolution.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ static api::ShaderInfo get_shader(
319319

320320
switch (method) {
321321
case Conv2dSlidingWindow:
322-
shader = VK_LOOKUP_KERNEL(conv2d);
322+
shader = VK_KERNEL(conv2d);
323323
break;
324324
case Conv2dDepthwise:
325325
shader = VK_KERNEL(conv2d_dw);
@@ -335,7 +335,7 @@ static api::ShaderInfo get_shader(
335335
}
336336
break;
337337
case Conv2dPointwise:
338-
shader = VK_LOOKUP_KERNEL(conv2d_pw);
338+
shader = VK_KERNEL(conv2d_pw_output_tile_2x2);
339339
break;
340340
}
341341
return shader;

buckbuild.bzl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ def get_glsl_paths():
144144
],
145145
)
146146

147+
def spv_shader_library():
148+
pass
149+
147150
# @lint-ignore BUCKRESTRICTEDSYNTAX
148151
IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build
149152

@@ -700,6 +703,43 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple
700703
apple_sdks = apple_sdks,
701704
)
702705

706+
def vulkan_spv_shader_library(name, spv_filegroup):
707+
genrule_cmd = [
708+
"$(exe //xplat/caffe2/tools:gen_aten_vulkan_spv_bin)",
709+
"--glsl-paths $(location {})".format(spv_filegroup),
710+
"--output-path $OUT --env FLOAT_IMAGE_FORMAT={}".format(get_glsl_image_format()),
711+
"--glslc-path=$(exe //xplat/caffe2/fb/vulkan/dotslash:glslc)",
712+
"--tmp-dir-path=$TMP",
713+
]
714+
715+
genrule_name = "gen_{}_cpp".format(name)
716+
fb_xplat_genrule(
717+
name = "gen_{}_cpp".format(name),
718+
outs = {
719+
"{}.cpp".format(name): ["spv.cpp"],
720+
},
721+
cmd = " ".join(genrule_cmd),
722+
default_outs = ["."],
723+
labels = ["uses_dotslash"],
724+
)
725+
726+
fb_xplat_cxx_library(
727+
name = name,
728+
srcs = [
729+
":{}[{}.cpp]".format(genrule_name, name),
730+
],
731+
# Static initialization is used to register shaders to the global shader registry,
732+
# therefore link_whole must be True to make sure unused symbols are not discarded.
733+
# @lint-ignore BUCKLINT: Avoid `link_whole=True`
734+
link_whole = True,
735+
# Define a soname that can be used for dynamic loading in Java, Python, etc.
736+
soname = "lib{}.$(ext)".format(name),
737+
visibility = ["PUBLIC"],
738+
exported_deps = [
739+
"//xplat/caffe2:torch_vulkan_api",
740+
],
741+
)
742+
703743
def copy_metal(name, apple_sdks = None):
704744
cmd = []
705745
cmd_exe = []

0 commit comments

Comments
 (0)