Skip to content

Commit b086901

Browse files
committed
[ET-VK] Introduce SpecVarList to represent specialization constants
Pull Request resolved: #3078 ## Context Specialization constants are a useful tool to compile compute shaders with constants defined at runtime. The primary application of specialization constants is to define variables which may have an impact on how the code is compiled, for example: * the number of elements of an array * the range of a loop Compared to the shader codegen system, which produces a complete copy of the shader and for which variants must be defined at build time, specialization constants can be defined at runtime when the compute pipeline is built. Specialization constants are currently used to define local work group sizes in Vulkan, but the Compute API hard-codes the number of specialization constants accepted by the shader to 3. This changeset introduces the `SpecVar` and `SpecVarList` classes to manage specialization constants and enable additional specialization constants to be specified. ghstack-source-id: 222863753 @exported-using-ghexport Differential Revision: [D56225041](https://our.internmc.facebook.com/intern/diff/D56225041/)
1 parent 4b6d2c3 commit b086901

File tree

5 files changed

+254
-36
lines changed

5 files changed

+254
-36
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,25 @@ Context::~Context() {
5959

6060
DescriptorSet Context::get_descriptor_set(
6161
const ShaderInfo& shader_descriptor,
62-
const utils::uvec3& local_workgroup_size) {
62+
const utils::uvec3& local_workgroup_size,
63+
const SpecVarList& additional_constants) {
6364
VkDescriptorSetLayout shader_layout =
6465
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
6566

6667
VkPipelineLayout pipeline_layout =
6768
pipeline_layout_cache().retrieve(shader_layout);
6869

70+
SpecVarList spec_constants = {
71+
SV(local_workgroup_size.data[0u]),
72+
SV(local_workgroup_size.data[1u]),
73+
SV(local_workgroup_size.data[2u])};
74+
75+
spec_constants.append(additional_constants);
76+
6977
VkPipeline pipeline = pipeline_cache().retrieve(
7078
{pipeline_layout_cache().retrieve(shader_layout),
7179
shader_cache().retrieve(shader_descriptor),
72-
local_workgroup_size});
80+
spec_constants});
7381

7482
cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
7583

backends/vulkan/runtime/api/Context.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,16 @@ class Context final {
172172
}
173173
}
174174

175-
DescriptorSet get_descriptor_set(const ShaderInfo&, const utils::uvec3&);
175+
DescriptorSet get_descriptor_set(
176+
const ShaderInfo&,
177+
const utils::uvec3&,
178+
const SpecVarList&);
179+
180+
inline DescriptorSet get_descriptor_set(
181+
const ShaderInfo& shader_descriptor,
182+
const utils::uvec3& local_work_group_size) {
183+
return get_descriptor_set(shader_descriptor, local_work_group_size, {});
184+
}
176185

177186
void register_shader_dispatch(
178187
const DescriptorSet&,

backends/vulkan/runtime/api/Pipeline.cpp

Lines changed: 104 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,103 @@ VkImageLayout vk_layout(
9898
return VK_IMAGE_LAYOUT_UNDEFINED;
9999
}
100100

101+
//
102+
// SpecVar
103+
//
104+
105+
SpecVar::SpecVar() : type(SpecVar::Type::INT) {
106+
value.as_int32 = 0;
107+
}
108+
109+
SpecVar::SpecVar(const float val) : type(SpecVar::Type::FLOAT) {
110+
value.as_float = val;
111+
}
112+
113+
SpecVar::SpecVar(const int32_t val) : type(SpecVar::Type::INT) {
114+
value.as_int32 = val;
115+
}
116+
117+
SpecVar::SpecVar(const uint32_t val) : type(SpecVar::Type::UINT) {
118+
value.as_uint32 = val;
119+
}
120+
121+
SpecVar::SpecVar(const bool val) : type(SpecVar::Type::BOOL) {
122+
value.as_bool = val;
123+
}
124+
125+
uint32_t SpecVar::val_size() const {
126+
switch (type) {
127+
case SpecVar::Type::FLOAT:
128+
return sizeof(float);
129+
case SpecVar::Type::INT:
130+
return sizeof(int32_t);
131+
case SpecVar::Type::UINT:
132+
return sizeof(uint32_t);
133+
case SpecVar::Type::BOOL:
134+
return sizeof(bool);
135+
}
136+
return 4;
137+
}
138+
139+
uint32_t SpecVar::val_offset() const {
140+
return api::utils::safe_downcast<uint32_t>(offsetof(SpecVar, value));
141+
}
142+
143+
bool operator==(const SpecVar& lhs, const SpecVar& rhs) {
144+
if (lhs.type != rhs.type) {
145+
return false;
146+
}
147+
switch (lhs.type) {
148+
case SpecVar::Type::FLOAT:
149+
return lhs.value.as_float == rhs.value.as_float;
150+
case SpecVar::Type::INT:
151+
return lhs.value.as_int32 == rhs.value.as_int32;
152+
case SpecVar::Type::UINT:
153+
return lhs.value.as_uint32 == rhs.value.as_uint32;
154+
case SpecVar::Type::BOOL:
155+
return lhs.value.as_bool == rhs.value.as_bool;
156+
}
157+
return false;
158+
}
159+
160+
SpecVarList::SpecVarList() {
161+
vars.reserve(8);
162+
}
163+
164+
SpecVarList::SpecVarList(std::initializer_list<SpecVar> init_list) {
165+
vars.resize(init_list.size());
166+
std::copy(init_list.begin(), init_list.end(), vars.begin());
167+
}
168+
169+
void SpecVarList::append(const SpecVarList& other) {
170+
vars.insert(vars.end(), other.vars.begin(), other.vars.end());
171+
}
172+
173+
std::vector<VkSpecializationMapEntry> SpecVarList::generate_map_entries()
174+
const {
175+
std::vector<VkSpecializationMapEntry> map_entries;
176+
map_entries.resize(vars.size());
177+
uint32_t cur_offset = 0u;
178+
for (uint32_t i = 0; i < vars.size(); ++i) {
179+
map_entries[i] = {
180+
i, cur_offset + vars.at(i).val_offset(), vars.at(i).val_size()};
181+
cur_offset += sizeof(SpecVar);
182+
}
183+
return map_entries;
184+
}
185+
186+
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
187+
if (lhs.size() != rhs.size()) {
188+
return false;
189+
}
190+
for (uint32_t i = 0; i < lhs.size(); ++i) {
191+
if (lhs.vars.at(i) != rhs.vars.at(i)) {
192+
return false;
193+
}
194+
}
195+
return true;
196+
}
197+
101198
//
102199
// PipelineLayout
103200
//
@@ -154,33 +251,14 @@ ComputePipeline::ComputePipeline(
154251
const ComputePipeline::Descriptor& descriptor,
155252
VkPipelineCache pipeline_cache)
156253
: device_(device), handle_{VK_NULL_HANDLE} {
157-
// NOLINTNEXTLINE
158-
constexpr VkSpecializationMapEntry specialization_map_entries[3]{
159-
// X
160-
{
161-
0u,
162-
offsetof(utils::uvec3, data[0u]),
163-
sizeof(utils::uvec3::data[0u]),
164-
},
165-
// Y
166-
{
167-
1u,
168-
offsetof(utils::uvec3, data[1u]),
169-
sizeof(utils::uvec3::data[1u]),
170-
},
171-
// Z
172-
{
173-
2u,
174-
offsetof(utils::uvec3, data[2u]),
175-
sizeof(utils::uvec3::data[2u]),
176-
},
177-
};
254+
std::vector<VkSpecializationMapEntry> map_entries =
255+
descriptor.specialization_constants.generate_map_entries();
178256

179257
const VkSpecializationInfo specialization_info{
180-
3u, // mapEntryCount
181-
specialization_map_entries, // pMapEntries
182-
sizeof(descriptor.local_work_group), // dataSize
183-
&descriptor.local_work_group, // pData
258+
descriptor.specialization_constants.size(), // mapEntryCount
259+
map_entries.data(), // pMapEntries
260+
descriptor.specialization_constants.data_nbytes(), // dataSize
261+
descriptor.specialization_constants.data(), // pData
184262
};
185263

186264
const VkPipelineShaderStageCreateInfo shader_stage_create_info{
@@ -242,7 +320,7 @@ bool operator==(
242320
return (
243321
_1.pipeline_layout == _2.pipeline_layout &&
244322
_1.shader_module == _2.shader_module &&
245-
_1.local_work_group == _2.local_work_group);
323+
_1.specialization_constants == _2.specialization_constants);
246324
}
247325

248326
//

backends/vulkan/runtime/api/Pipeline.h

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,68 @@
1818
#include <mutex>
1919
#include <unordered_map>
2020

21+
#define SV(x) ::vkcompute::api::SpecVar(x)
22+
2123
namespace vkcompute {
2224
namespace api {
2325

26+
struct SpecVar final {
27+
enum class Type : uint8_t {
28+
FLOAT,
29+
INT,
30+
UINT,
31+
BOOL,
32+
};
33+
34+
union Value {
35+
int32_t as_int32;
36+
uint32_t as_uint32;
37+
float as_float;
38+
bool as_bool;
39+
};
40+
41+
Value value;
42+
Type type;
43+
44+
SpecVar();
45+
SpecVar(const float val);
46+
SpecVar(const int32_t val);
47+
SpecVar(const uint32_t val);
48+
SpecVar(const bool val);
49+
50+
uint32_t val_size() const;
51+
uint32_t val_offset() const;
52+
};
53+
54+
bool operator==(const SpecVar& lhs, const SpecVar& rhs);
55+
56+
struct SpecVarList final {
57+
std::vector<SpecVar> vars;
58+
59+
SpecVarList();
60+
SpecVarList(std::initializer_list<SpecVar> init_list);
61+
62+
inline const SpecVar* data() const {
63+
return vars.data();
64+
}
65+
66+
inline uint32_t size() const {
67+
return api::utils::safe_downcast<uint32_t>(vars.size());
68+
}
69+
70+
inline uint32_t data_nbytes() const {
71+
return vars.size() * sizeof(SpecVar);
72+
}
73+
74+
void append(const SpecVarList& other);
75+
76+
std::vector<VkSpecializationMapEntry> generate_map_entries() const;
77+
78+
friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
79+
};
80+
81+
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
82+
2483
struct PipelineBarrier final {
2584
struct Stages final {
2685
VkPipelineStageFlags src;
@@ -83,7 +142,7 @@ class ComputePipeline final {
83142
struct Descriptor final {
84143
VkPipelineLayout pipeline_layout;
85144
VkShaderModule shader_module;
86-
utils::uvec3 local_work_group;
145+
SpecVarList specialization_constants;
87146
};
88147

89148
explicit ComputePipeline(
@@ -171,12 +230,29 @@ class ComputePipelineCache final {
171230
seed, std::hash<VkPipelineLayout>()(descriptor.pipeline_layout));
172231
seed = utils::hash_combine(
173232
seed, std::hash<VkShaderModule>()(descriptor.shader_module));
174-
seed = utils::hash_combine(
175-
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[0u]));
176-
seed = utils::hash_combine(
177-
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[1u]));
178-
seed = utils::hash_combine(
179-
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[2u]));
233+
234+
const SpecVarList& spec_vars = descriptor.specialization_constants;
235+
seed = utils::hash_combine(seed, std::hash<uint32_t>()(spec_vars.size()));
236+
237+
for (int i = 0; i < spec_vars.size(); ++i) {
238+
const SpecVar& spec_var = spec_vars.vars.at(i);
239+
size_t new_seed = 0;
240+
switch (spec_var.type) {
241+
case SpecVar::Type::FLOAT:
242+
new_seed = std::hash<float>()(spec_var.value.as_float);
243+
break;
244+
case SpecVar::Type::INT:
245+
new_seed = std::hash<int32_t>()(spec_var.value.as_int32);
246+
break;
247+
case SpecVar::Type::UINT:
248+
new_seed = std::hash<uint32_t>()(spec_var.value.as_uint32);
249+
break;
250+
case SpecVar::Type::BOOL:
251+
new_seed = std::hash<bool>()(spec_var.value.as_bool);
252+
break;
253+
}
254+
seed = utils::hash_combine(seed, new_seed);
255+
}
180256

181257
return seed;
182258
}

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,53 @@ TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) {
5050
ASSERT_TRUE(kernel.kernel_name == "test_shader");
5151
}
5252

53+
TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
54+
// Check equality operator
55+
ASSERT_TRUE(SV(1.5f) == SV(1.5f));
56+
ASSERT_FALSE(SV(15.0f) == SV(15));
57+
ASSERT_FALSE(SV(1u) == SV(true));
58+
59+
size_t sv_size = sizeof(api::SpecVar);
60+
61+
api::SpecVarList spec_vars = {};
62+
ASSERT_TRUE(spec_vars.size() == 0);
63+
spec_vars = {SV(1.1f), SV(32), SV(45)};
64+
ASSERT_TRUE(spec_vars.size() == 3);
65+
api::SpecVarList spec_vars_other = {SV(2.6f), SV(true), SV(78u), SV(5.5f)};
66+
spec_vars.append(spec_vars_other);
67+
ASSERT_TRUE(spec_vars.size() == 7);
68+
69+
// Check validity of the data
70+
const api::SpecVar* data = spec_vars.data();
71+
ASSERT_TRUE(*(reinterpret_cast<const float*>(data + 3)) == 2.6f);
72+
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 1)) == 32);
73+
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 5)) == 78u);
74+
75+
// Check validity of the map entries
76+
std::vector<VkSpecializationMapEntry> entries =
77+
spec_vars.generate_map_entries();
78+
79+
for (size_t i = 0; i < spec_vars.size(); ++i) {
80+
ASSERT_TRUE(entries[i].constantID == i);
81+
ASSERT_TRUE(entries[i].offset == sv_size * i);
82+
if (i != 4) {
83+
ASSERT_TRUE(entries[i].size == 4);
84+
} else {
85+
ASSERT_TRUE(entries[i].size == 1);
86+
}
87+
}
88+
89+
// Check copy
90+
api::SpecVarList spec_vars_copy(spec_vars);
91+
ASSERT_TRUE(spec_vars_copy.size() == 7);
92+
93+
// Check validity of the copied data
94+
const api::SpecVar* copy_data = spec_vars_copy.data();
95+
ASSERT_TRUE(*(reinterpret_cast<const bool*>(copy_data + 4)) == true);
96+
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(copy_data + 2)) == 45);
97+
ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);
98+
}
99+
53100
TEST_F(VulkanComputeAPITest, update_params_between_submit) {
54101
api::context()->set_cmd(/*reusable = */ true);
55102
std::vector<int64_t> sizes = {4, 4, 2};

0 commit comments

Comments
 (0)