Skip to content

Commit 0815c2b

Browse files
SS-JIA authored and facebook-github-bot committed
Introduce SpecVarList to represent specialization constants (#3078)
Summary:
Pull Request resolved: #3078

## Context

Specialization constants are a useful tool to compile compute shaders with constants defined at runtime. The primary application of specialization constants is to define variables which may have an impact on how the code is compiled, for example:

* the number of elements of an array
* the range of a loop

Compared to the shader codegen system, which produces a complete copy of the shader and for which variants must be defined at build time, specialization constants can be defined at runtime when the compute pipeline is built.

Specialization constants are currently used to define local work group sizes in Vulkan, but the Compute API hard-codes the number of specialization constants accepted by the shader to 3. This changeset introduces the `SpecVar` and `SpecVarList` classes to manage specialization constants and enable additional specialization constants to be specified.

ghstack-source-id: 222903462
exported-using-ghexport

Reviewed By: copyrightly, jorgep31415

Differential Revision: D56225041

fbshipit-source-id: 88c94c09e380793c75edcb0a92c2987fac882431
1 parent 7e14c0e commit 0815c2b

File tree

5 files changed

+257
-36
lines changed

5 files changed

+257
-36
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,25 @@ Context::~Context() {
5959

6060
DescriptorSet Context::get_descriptor_set(
6161
const ShaderInfo& shader_descriptor,
62-
const utils::uvec3& local_workgroup_size) {
62+
const utils::uvec3& local_workgroup_size,
63+
const SpecVarList& additional_constants) {
6364
VkDescriptorSetLayout shader_layout =
6465
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
6566

6667
VkPipelineLayout pipeline_layout =
6768
pipeline_layout_cache().retrieve(shader_layout);
6869

70+
SpecVarList spec_constants = {
71+
SV(local_workgroup_size.data[0u]),
72+
SV(local_workgroup_size.data[1u]),
73+
SV(local_workgroup_size.data[2u])};
74+
75+
spec_constants.append(additional_constants);
76+
6977
VkPipeline pipeline = pipeline_cache().retrieve(
7078
{pipeline_layout_cache().retrieve(shader_layout),
7179
shader_cache().retrieve(shader_descriptor),
72-
local_workgroup_size});
80+
spec_constants});
7381

7482
cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
7583

backends/vulkan/runtime/api/Context.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,16 @@ class Context final {
172172
}
173173
}
174174

175-
DescriptorSet get_descriptor_set(const ShaderInfo&, const utils::uvec3&);
175+
DescriptorSet get_descriptor_set(
176+
const ShaderInfo&,
177+
const utils::uvec3&,
178+
const SpecVarList&);
179+
180+
// Convenience overload for callers that supply no extra specialization
// constants: forwards to the three-argument overload with an empty
// SpecVarList.
inline DescriptorSet get_descriptor_set(
    const ShaderInfo& shader_descriptor,
    const utils::uvec3& local_work_group_size) {
  const SpecVarList no_additional_constants{};
  return get_descriptor_set(
      shader_descriptor, local_work_group_size, no_additional_constants);
}
176185

177186
void register_shader_dispatch(
178187
const DescriptorSet&,

backends/vulkan/runtime/api/Pipeline.cpp

Lines changed: 102 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,101 @@ VkImageLayout vk_layout(
9898
return VK_IMAGE_LAYOUT_UNDEFINED;
9999
}
100100

101+
//
102+
// SpecVar
103+
//
104+
105+
SpecVar::SpecVar() : type(SpecVar::Type::INT) {
106+
value.as_int32 = 0;
107+
}
108+
109+
SpecVar::SpecVar(const float val) : type(SpecVar::Type::FLOAT) {
110+
value.as_float = val;
111+
}
112+
113+
SpecVar::SpecVar(const int32_t val) : type(SpecVar::Type::INT) {
114+
value.as_int32 = val;
115+
}
116+
117+
SpecVar::SpecVar(const uint32_t val) : type(SpecVar::Type::UINT) {
118+
value.as_uint32 = val;
119+
}
120+
121+
SpecVar::SpecVar(const bool val) : type(SpecVar::Type::BOOL) {
122+
value.as_bool = val;
123+
}
124+
125+
uint32_t SpecVar::val_size() const {
126+
switch (type) {
127+
case SpecVar::Type::FLOAT:
128+
return sizeof(float);
129+
case SpecVar::Type::INT:
130+
return sizeof(int32_t);
131+
case SpecVar::Type::UINT:
132+
return sizeof(uint32_t);
133+
case SpecVar::Type::BOOL:
134+
return sizeof(bool);
135+
}
136+
return 4;
137+
}
138+
139+
uint32_t SpecVar::val_offset() const {
140+
return api::utils::safe_downcast<uint32_t>(offsetof(SpecVar, value));
141+
}
142+
143+
// Two SpecVars are equal when they carry the same Type tag and the union
// member selected by that tag compares equal. Values of different types are
// never equal, even when numerically identical.
bool operator==(const SpecVar& lhs, const SpecVar& rhs) {
  bool same = false;
  if (lhs.type == rhs.type) {
    switch (lhs.type) {
      case SpecVar::Type::FLOAT:
        same = (lhs.value.as_float == rhs.value.as_float);
        break;
      case SpecVar::Type::INT:
        same = (lhs.value.as_int32 == rhs.value.as_int32);
        break;
      case SpecVar::Type::UINT:
        same = (lhs.value.as_uint32 == rhs.value.as_uint32);
        break;
      case SpecVar::Type::BOOL:
        same = (lhs.value.as_bool == rhs.value.as_bool);
        break;
    }
  }
  return same;
}
159+
160+
// An empty list: no specialization constants.
SpecVarList::SpecVarList() = default;
161+
162+
// Builds the list from a braced sequence of SpecVars, preserving order.
// The vector is constructed directly from the initializer list so elements
// are copied once, rather than being default-constructed and then overwritten.
SpecVarList::SpecVarList(std::initializer_list<SpecVar> init_list)
    : vars(init_list) {}
166+
167+
// Appends every SpecVar in `other` to the end of this list, in order.
//
// Elements are pushed one at a time rather than via a range insert:
// std::vector::insert has undefined behavior when the source iterators point
// into the destination vector, which is what `list.append(list)` would do.
// The element count is captured up front so self-append terminates after
// duplicating the original contents once.
void SpecVarList::append(const SpecVarList& other) {
  const size_t num_to_copy = other.vars.size();
  for (size_t i = 0; i < num_to_copy; ++i) {
    vars.push_back(other.vars[i]);
  }
}
170+
171+
std::vector<VkSpecializationMapEntry> SpecVarList::generate_map_entries()
172+
const {
173+
std::vector<VkSpecializationMapEntry> map_entries;
174+
map_entries.resize(vars.size());
175+
uint32_t cur_offset = 0u;
176+
for (uint32_t i = 0; i < vars.size(); ++i) {
177+
map_entries.at(i) = {
178+
i, cur_offset + vars.at(i).val_offset(), vars.at(i).val_size()};
179+
cur_offset += sizeof(SpecVar);
180+
}
181+
return map_entries;
182+
}
183+
184+
// Two lists are equal when they have the same length and every pair of
// corresponding SpecVars compares equal.
//
// The comparison is expressed purely in terms of SpecVar's operator==; only
// that operator is declared for SpecVar, so relying on `!=` would require
// C++20 operator rewriting to compile.
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
  if (lhs.size() != rhs.size()) {
    return false;
  }
  for (uint32_t i = 0; i < lhs.size(); ++i) {
    if (!(lhs.vars.at(i) == rhs.vars.at(i))) {
      return false;
    }
  }
  return true;
}
195+
101196
//
102197
// PipelineLayout
103198
//
@@ -154,33 +249,14 @@ ComputePipeline::ComputePipeline(
154249
const ComputePipeline::Descriptor& descriptor,
155250
VkPipelineCache pipeline_cache)
156251
: device_(device), handle_{VK_NULL_HANDLE} {
157-
// NOLINTNEXTLINE
158-
constexpr VkSpecializationMapEntry specialization_map_entries[3]{
159-
// X
160-
{
161-
0u,
162-
offsetof(utils::uvec3, data[0u]),
163-
sizeof(utils::uvec3::data[0u]),
164-
},
165-
// Y
166-
{
167-
1u,
168-
offsetof(utils::uvec3, data[1u]),
169-
sizeof(utils::uvec3::data[1u]),
170-
},
171-
// Z
172-
{
173-
2u,
174-
offsetof(utils::uvec3, data[2u]),
175-
sizeof(utils::uvec3::data[2u]),
176-
},
177-
};
252+
std::vector<VkSpecializationMapEntry> map_entries =
253+
descriptor.specialization_constants.generate_map_entries();
178254

179255
const VkSpecializationInfo specialization_info{
180-
3u, // mapEntryCount
181-
specialization_map_entries, // pMapEntries
182-
sizeof(descriptor.local_work_group), // dataSize
183-
&descriptor.local_work_group, // pData
256+
descriptor.specialization_constants.size(), // mapEntryCount
257+
map_entries.data(), // pMapEntries
258+
descriptor.specialization_constants.data_nbytes(), // dataSize
259+
descriptor.specialization_constants.data(), // pData
184260
};
185261

186262
const VkPipelineShaderStageCreateInfo shader_stage_create_info{
@@ -242,7 +318,7 @@ bool operator==(
242318
return (
243319
_1.pipeline_layout == _2.pipeline_layout &&
244320
_1.shader_module == _2.shader_module &&
245-
_1.local_work_group == _2.local_work_group);
321+
_1.specialization_constants == _2.specialization_constants);
246322
}
247323

248324
//

backends/vulkan/runtime/api/Pipeline.h

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,73 @@
1818
#include <mutex>
1919
#include <unordered_map>
2020

21+
// Shorthand for constructing a SpecVar from a value, e.g. SV(1.5f), SV(32),
// SV(78u) or SV(true); the argument's type selects the SpecVar constructor.
#define SV(x) ::vkcompute::api::SpecVar(x)
22+
2123
namespace vkcompute {
2224
namespace api {
2325

26+
// A single specialization constant: a tagged union holding one float, int32,
// uint32 or bool value.
//
// Member order matters: SpecVarList hands its contiguous SpecVar storage to
// Vulkan and addresses each constant via offsetof(SpecVar, value), so `value`
// and `type` must not be reordered casually.
struct SpecVar final {
  // Discriminator naming which member of `value` is active.
  enum class Type : uint8_t {
    FLOAT,
    INT,
    UINT,
    BOOL,
  };

  // Storage for the constant's payload.
  union Value {
    int32_t as_int32;
    uint32_t as_uint32;
    float as_float;
    bool as_bool;
  };

  Value value;
  Type type;

  // Default: INT-typed, value 0.
  SpecVar();
  // Converting constructors; the argument type determines `type`.
  SpecVar(float val);
  SpecVar(int32_t val);
  SpecVar(uint32_t val);
  SpecVar(bool val);

  // Size in bytes of the active payload type.
  uint32_t val_size() const;
  // Byte offset of `value` within a SpecVar.
  uint32_t val_offset() const;
};
53+
54+
bool operator==(const SpecVar& lhs, const SpecVar& rhs);
55+
56+
class SpecVarList final {
57+
std::vector<SpecVar> vars;
58+
59+
public:
60+
SpecVarList();
61+
SpecVarList(std::initializer_list<SpecVar> init_list);
62+
63+
inline const SpecVar& at(const size_t index) const {
64+
return vars.at(index);
65+
}
66+
67+
inline const SpecVar* data() const {
68+
return vars.data();
69+
}
70+
71+
inline uint32_t size() const {
72+
return api::utils::safe_downcast<uint32_t>(vars.size());
73+
}
74+
75+
inline uint32_t data_nbytes() const {
76+
return vars.size() * sizeof(SpecVar);
77+
}
78+
79+
void append(const SpecVarList& other);
80+
81+
std::vector<VkSpecializationMapEntry> generate_map_entries() const;
82+
83+
friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
84+
};
85+
86+
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
87+
2488
struct PipelineBarrier final {
2589
struct Stages final {
2690
VkPipelineStageFlags src;
@@ -83,7 +147,7 @@ class ComputePipeline final {
83147
struct Descriptor final {
84148
VkPipelineLayout pipeline_layout;
85149
VkShaderModule shader_module;
86-
utils::uvec3 local_work_group;
150+
SpecVarList specialization_constants;
87151
};
88152

89153
explicit ComputePipeline(
@@ -171,12 +235,29 @@ class ComputePipelineCache final {
171235
seed, std::hash<VkPipelineLayout>()(descriptor.pipeline_layout));
172236
seed = utils::hash_combine(
173237
seed, std::hash<VkShaderModule>()(descriptor.shader_module));
174-
seed = utils::hash_combine(
175-
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[0u]));
176-
seed = utils::hash_combine(
177-
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[1u]));
178-
seed = utils::hash_combine(
179-
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[2u]));
238+
239+
const SpecVarList& spec_vars = descriptor.specialization_constants;
240+
seed = utils::hash_combine(seed, std::hash<uint32_t>()(spec_vars.size()));
241+
242+
for (int i = 0; i < spec_vars.size(); ++i) {
243+
const SpecVar& spec_var = spec_vars.at(i);
244+
size_t new_seed = 0;
245+
switch (spec_var.type) {
246+
case SpecVar::Type::FLOAT:
247+
new_seed = std::hash<float>()(spec_var.value.as_float);
248+
break;
249+
case SpecVar::Type::INT:
250+
new_seed = std::hash<int32_t>()(spec_var.value.as_int32);
251+
break;
252+
case SpecVar::Type::UINT:
253+
new_seed = std::hash<uint32_t>()(spec_var.value.as_uint32);
254+
break;
255+
case SpecVar::Type::BOOL:
256+
new_seed = std::hash<bool>()(spec_var.value.as_bool);
257+
break;
258+
}
259+
seed = utils::hash_combine(seed, new_seed);
260+
}
180261

181262
return seed;
182263
}

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,53 @@ TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) {
5050
ASSERT_TRUE(kernel.kernel_name == "test_shader");
5151
}
5252

53+
// Exercises SpecVar / SpecVarList: equality, construction, append, the raw
// memory layout exposed through data(), the generated map entries, and copy
// construction.
TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
  // Check equality operator: equal only when both type and value match.
  ASSERT_TRUE(SV(1.5f) == SV(1.5f));
  ASSERT_FALSE(SV(15.0f) == SV(15));
  ASSERT_FALSE(SV(1u) == SV(true));

  size_t sv_size = sizeof(api::SpecVar);

  api::SpecVarList spec_vars = {};
  ASSERT_TRUE(spec_vars.size() == 0);
  spec_vars = {SV(1.1f), SV(32), SV(45)};
  ASSERT_TRUE(spec_vars.size() == 3);
  api::SpecVarList spec_vars_other = {SV(2.6f), SV(true), SV(78u), SV(5.5f)};
  spec_vars.append(spec_vars_other);
  ASSERT_TRUE(spec_vars.size() == 7);

  // Check validity of the data. The payload is read through a pointer to the
  // start of each SpecVar, mirroring how the buffer is consumed via the map
  // entries (the value sits at the start of each element).
  const api::SpecVar* data = spec_vars.data();
  ASSERT_TRUE(*(reinterpret_cast<const float*>(data + 3)) == 2.6f);
  ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 1)) == 32);
  // Index 5 was constructed from 78u, so read it back as uint32_t.
  ASSERT_TRUE(*(reinterpret_cast<const uint32_t*>(data + 5)) == 78u);

  // Check validity of the map entries
  std::vector<VkSpecializationMapEntry> entries =
      spec_vars.generate_map_entries();

  for (size_t i = 0; i < spec_vars.size(); ++i) {
    ASSERT_TRUE(entries[i].constantID == i);
    ASSERT_TRUE(entries[i].offset == sv_size * i);
    // Index 4 is the only bool in the list; every other constant is 4 bytes.
    if (i != 4) {
      ASSERT_TRUE(entries[i].size == 4);
    } else {
      ASSERT_TRUE(entries[i].size == 1);
    }
  }

  // Check copy
  api::SpecVarList spec_vars_copy(spec_vars);
  ASSERT_TRUE(spec_vars_copy.size() == 7);

  // Check validity of the copied data
  const api::SpecVar* copy_data = spec_vars_copy.data();
  ASSERT_TRUE(*(reinterpret_cast<const bool*>(copy_data + 4)) == true);
  ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(copy_data + 2)) == 45);
  ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);
}
99+
53100
TEST_F(VulkanComputeAPITest, update_params_between_submit) {
54101
api::context()->set_cmd(/*reusable = */ true);
55102
std::vector<int64_t> sizes = {4, 4, 2};

0 commit comments

Comments
 (0)