Skip to content

Commit 478c8f0

Browse files
committed
[ET-VK] Introduce SpecVarList to represent specialization constants
## Context Specialization constants are a useful tool to compile compute shaders with constants defined at runtime. The primary application of specialization constants is to define variables which may have an impact on how the code is compiled, for example: * the number of elements of an array * the range of a loop Compared to the shader codegen system, which produces a complete copy of the shader and for which variants must be defined at build time, specialization constants can be defined at runtime when the compute pipeline is built. Specialization constants are currently used to define local work group sizes in Vulkan, but the Compute API hard-codes the number of specialization constants accepted by the shader to 3. This changeset introduces the `SpecVar` and `SpecVarList` classes to manage specialization constants and enable additional specialization constants to be specified. Differential Revision: [D56225041](https://our.internmc.facebook.com/intern/diff/D56225041/) ghstack-source-id: 222806318 Pull Request resolved: #3078
1 parent 4b6d2c3 commit 478c8f0

File tree

5 files changed

+258
-37
lines changed

5 files changed

+258
-37
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,25 @@ Context::~Context() {
5959

6060
DescriptorSet Context::get_descriptor_set(
6161
const ShaderInfo& shader_descriptor,
62-
const utils::uvec3& local_workgroup_size) {
62+
const utils::uvec3& local_workgroup_size,
63+
const SpecVarList& additional_constants) {
6364
VkDescriptorSetLayout shader_layout =
6465
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
6566

6667
VkPipelineLayout pipeline_layout =
6768
pipeline_layout_cache().retrieve(shader_layout);
6869

70+
SpecVarList spec_constants = {
71+
SV(local_workgroup_size.data[0u]),
72+
SV(local_workgroup_size.data[1u]),
73+
SV(local_workgroup_size.data[2u])};
74+
75+
spec_constants.append(additional_constants);
76+
6977
VkPipeline pipeline = pipeline_cache().retrieve(
7078
{pipeline_layout_cache().retrieve(shader_layout),
7179
shader_cache().retrieve(shader_descriptor),
72-
local_workgroup_size});
80+
spec_constants});
7381

7482
cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);
7583

backends/vulkan/runtime/api/Context.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,16 @@ class Context final {
172172
}
173173
}
174174

175-
DescriptorSet get_descriptor_set(const ShaderInfo&, const utils::uvec3&);
175+
DescriptorSet get_descriptor_set(
176+
const ShaderInfo&,
177+
const utils::uvec3&,
178+
const SpecVarList&);
179+
180+
inline DescriptorSet get_descriptor_set(
181+
const ShaderInfo& shader_descriptor,
182+
const utils::uvec3& local_work_group_size) {
183+
return get_descriptor_set(shader_descriptor, local_work_group_size, {});
184+
}
176185

177186
void register_shader_dispatch(
178187
const DescriptorSet&,

backends/vulkan/runtime/api/Pipeline.cpp

Lines changed: 98 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,99 @@ VkImageLayout vk_layout(
9898
return VK_IMAGE_LAYOUT_UNDEFINED;
9999
}
100100

101+
//
102+
// SpecVar
103+
//
104+
105+
SpecVar::SpecVar() : type(SpecVar::Type::INT) {
106+
value.as_int32 = 0;
107+
}
108+
109+
SpecVar::SpecVar(const float val) : type(SpecVar::Type::FLOAT) {
110+
value.as_float = val;
111+
}
112+
113+
SpecVar::SpecVar(const int32_t val) : type(SpecVar::Type::INT) {
114+
value.as_int32 = val;
115+
}
116+
117+
SpecVar::SpecVar(const uint32_t val) : type(SpecVar::Type::UINT) {
118+
value.as_uint32 = val;
119+
}
120+
121+
SpecVar::SpecVar(const bool val) : type(SpecVar::Type::BOOL) {
122+
value.as_bool = val;
123+
}
124+
125+
uint32_t SpecVar::val_size() const {
126+
switch (type) {
127+
case SpecVar::Type::FLOAT:
128+
return sizeof(float);
129+
case SpecVar::Type::INT:
130+
return sizeof(int32_t);
131+
case SpecVar::Type::UINT:
132+
return sizeof(uint32_t);
133+
case SpecVar::Type::BOOL:
134+
return sizeof(bool);
135+
}
136+
return 4;
137+
}
138+
139+
uint32_t SpecVar::val_offset() const {
140+
return api::utils::safe_downcast<uint32_t>(offsetof(SpecVar, value));
141+
}
142+
143+
bool operator==(const SpecVar& lhs, const SpecVar& rhs) {
144+
if (lhs.type != rhs.type) {
145+
return false;
146+
}
147+
switch (lhs.type) {
148+
case SpecVar::Type::FLOAT:
149+
return lhs.value.as_float == rhs.value.as_float;
150+
case SpecVar::Type::INT:
151+
return lhs.value.as_int32 == rhs.value.as_int32;
152+
case SpecVar::Type::UINT:
153+
return lhs.value.as_uint32 == rhs.value.as_uint32;
154+
case SpecVar::Type::BOOL:
155+
return lhs.value.as_bool == rhs.value.as_bool;
156+
}
157+
return false;
158+
}
159+
160+
SpecVarList::SpecVarList(std::initializer_list<SpecVar> init_list) {
161+
VK_CHECK_COND(init_list.size() <= SPECVAR_LIST_LIMIT);
162+
arr_size = init_list.size();
163+
std::copy(init_list.begin(), init_list.end(), arr);
164+
uint32_t cur_offset = 0u;
165+
for (uint32_t i = 0; i < arr_size; ++i) {
166+
map_entries[i] = {i, cur_offset + arr[i].val_offset(), arr[i].val_size()};
167+
cur_offset += sizeof(SpecVar);
168+
}
169+
}
170+
171+
void SpecVarList::append(const SpecVarList& other) {
172+
VK_CHECK_COND(arr_size + other.size() <= SPECVAR_LIST_LIMIT);
173+
std::copy(other.arr, other.arr + other.size(), arr + arr_size);
174+
uint32_t cur_offset = arr_size * sizeof(SpecVar);
175+
for (uint32_t i = arr_size; i < arr_size + other.size(); ++i) {
176+
map_entries[i] = {i, cur_offset + arr[i].val_offset(), arr[i].val_size()};
177+
cur_offset += sizeof(SpecVar);
178+
}
179+
arr_size += other.size();
180+
}
181+
182+
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
183+
if (lhs.size() != rhs.size()) {
184+
return false;
185+
}
186+
for (uint32_t i = 0; i < lhs.size(); ++i) {
187+
if (lhs.var_data()[i] != rhs.var_data()[i]) {
188+
return false;
189+
}
190+
}
191+
return true;
192+
}
193+
101194
//
102195
// PipelineLayout
103196
//
@@ -154,33 +247,11 @@ ComputePipeline::ComputePipeline(
154247
const ComputePipeline::Descriptor& descriptor,
155248
VkPipelineCache pipeline_cache)
156249
: device_(device), handle_{VK_NULL_HANDLE} {
157-
// NOLINTNEXTLINE
158-
constexpr VkSpecializationMapEntry specialization_map_entries[3]{
159-
// X
160-
{
161-
0u,
162-
offsetof(utils::uvec3, data[0u]),
163-
sizeof(utils::uvec3::data[0u]),
164-
},
165-
// Y
166-
{
167-
1u,
168-
offsetof(utils::uvec3, data[1u]),
169-
sizeof(utils::uvec3::data[1u]),
170-
},
171-
// Z
172-
{
173-
2u,
174-
offsetof(utils::uvec3, data[2u]),
175-
sizeof(utils::uvec3::data[2u]),
176-
},
177-
};
178-
179250
const VkSpecializationInfo specialization_info{
180-
3u, // mapEntryCount
181-
specialization_map_entries, // pMapEntries
182-
sizeof(descriptor.local_work_group), // dataSize
183-
&descriptor.local_work_group, // pData
251+
descriptor.specialization_constants.size(), // mapEntryCount
252+
descriptor.specialization_constants.map_entries_data(), // pMapEntries
253+
descriptor.specialization_constants.map_entries_data_size(), // dataSize
254+
descriptor.specialization_constants.data(), // pData
184255
};
185256

186257
const VkPipelineShaderStageCreateInfo shader_stage_create_info{
@@ -242,7 +313,7 @@ bool operator==(
242313
return (
243314
_1.pipeline_layout == _2.pipeline_layout &&
244315
_1.shader_module == _2.shader_module &&
245-
_1.local_work_group == _2.local_work_group);
316+
_1.specialization_constants == _2.specialization_constants);
246317
}
247318

248319
//

backends/vulkan/runtime/api/Pipeline.h

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,79 @@
1818
#include <mutex>
1919
#include <unordered_map>
2020

21+
#define SPECVAR_LIST_LIMIT 8
22+
23+
#define SV(x) ::vkcompute::api::SpecVar(x)
24+
2125
namespace vkcompute {
2226
namespace api {
2327

28+
struct SpecVar final {
29+
enum class Type : uint8_t {
30+
FLOAT,
31+
INT,
32+
UINT,
33+
BOOL,
34+
};
35+
36+
union Value {
37+
int32_t as_int32;
38+
uint32_t as_uint32;
39+
float as_float;
40+
bool as_bool;
41+
};
42+
43+
Value value;
44+
Type type;
45+
46+
SpecVar();
47+
SpecVar(const float val);
48+
SpecVar(const int32_t val);
49+
SpecVar(const uint32_t val);
50+
SpecVar(const bool val);
51+
52+
uint32_t val_size() const;
53+
uint32_t val_offset() const;
54+
};
55+
56+
bool operator==(const SpecVar& lhs, const SpecVar& rhs);
57+
58+
// using SpecVarList = std::vector<SpecVar>;
59+
60+
class SpecVarList final {
61+
SpecVar arr[SPECVAR_LIST_LIMIT];
62+
VkSpecializationMapEntry map_entries[SPECVAR_LIST_LIMIT];
63+
uint32_t arr_size;
64+
65+
public:
66+
SpecVarList() : arr_size(0) {}
67+
SpecVarList(std::initializer_list<SpecVar> init_list);
68+
69+
inline const SpecVar* var_data() const {
70+
return &(arr[0]);
71+
}
72+
73+
inline const void* data() const {
74+
return &(arr[0]);
75+
}
76+
77+
inline uint32_t size() const {
78+
return arr_size;
79+
}
80+
81+
inline const VkSpecializationMapEntry* map_entries_data() const {
82+
return &(map_entries[0]);
83+
}
84+
85+
inline size_t map_entries_data_size() const {
86+
return arr_size * sizeof(VkSpecializationMapEntry);
87+
}
88+
89+
void append(const SpecVarList& other);
90+
};
91+
92+
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
93+
2494
struct PipelineBarrier final {
2595
struct Stages final {
2696
VkPipelineStageFlags src;
@@ -83,7 +153,7 @@ class ComputePipeline final {
83153
struct Descriptor final {
84154
VkPipelineLayout pipeline_layout;
85155
VkShaderModule shader_module;
86-
utils::uvec3 local_work_group;
156+
SpecVarList specialization_constants;
87157
};
88158

89159
explicit ComputePipeline(
@@ -171,12 +241,29 @@ class ComputePipelineCache final {
171241
seed, std::hash<VkPipelineLayout>()(descriptor.pipeline_layout));
172242
seed = utils::hash_combine(
173243
seed, std::hash<VkShaderModule>()(descriptor.shader_module));
174-
seed = utils::hash_combine(
175-
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[0u]));
176-
seed = utils::hash_combine(
177-
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[1u]));
178-
seed = utils::hash_combine(
179-
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[2u]));
244+
245+
const SpecVarList& spec_vars = descriptor.specialization_constants;
246+
seed = utils::hash_combine(seed, std::hash<uint32_t>()(spec_vars.size()));
247+
248+
for (int i = 0; i < spec_vars.size(); ++i) {
249+
const SpecVar& spec_var = spec_vars.var_data()[i];
250+
size_t new_seed = 0;
251+
switch (spec_var.type) {
252+
case SpecVar::Type::FLOAT:
253+
new_seed = std::hash<float>()(spec_var.value.as_float);
254+
break;
255+
case SpecVar::Type::INT:
256+
new_seed = std::hash<int32_t>()(spec_var.value.as_int32);
257+
break;
258+
case SpecVar::Type::UINT:
259+
new_seed = std::hash<uint32_t>()(spec_var.value.as_uint32);
260+
break;
261+
case SpecVar::Type::BOOL:
262+
new_seed = std::hash<bool>()(spec_var.value.as_bool);
263+
break;
264+
}
265+
seed = utils::hash_combine(seed, new_seed);
266+
}
180267

181268
return seed;
182269
}

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,52 @@ TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) {
5050
ASSERT_TRUE(kernel.kernel_name == "test_shader");
5151
}
5252

53+
TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
54+
// Check equality operator
55+
ASSERT_TRUE(SV(1.5f) == SV(1.5f));
56+
ASSERT_FALSE(SV(15.0f) == SV(15));
57+
ASSERT_FALSE(SV(1u) == SV(true));
58+
59+
size_t sv_size = sizeof(api::SpecVar);
60+
61+
api::SpecVarList spec_vars = {};
62+
ASSERT_TRUE(spec_vars.size() == 0);
63+
spec_vars = {SV(1.1f), SV(32), SV(45)};
64+
ASSERT_TRUE(spec_vars.size() == 3);
65+
api::SpecVarList spec_vars_other = {SV(2.6f), SV(true), SV(78u), SV(5.5f)};
66+
spec_vars.append(spec_vars_other);
67+
ASSERT_TRUE(spec_vars.size() == 7);
68+
69+
// Check validity of the data
70+
const api::SpecVar* data = spec_vars.var_data();
71+
ASSERT_TRUE(*(reinterpret_cast<const float*>(data + 3)) == 2.6f);
72+
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 1)) == 32);
73+
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 5)) == 78u);
74+
75+
// Check validity of the map entries
76+
const VkSpecializationMapEntry* entries = spec_vars.map_entries_data();
77+
78+
for (size_t i = 0; i < spec_vars.size(); ++i) {
79+
ASSERT_TRUE(entries[i].constantID == i);
80+
ASSERT_TRUE(entries[i].offset == sv_size * i);
81+
if (i != 4) {
82+
ASSERT_TRUE(entries[i].size == 4);
83+
} else {
84+
ASSERT_TRUE(entries[i].size == 1);
85+
}
86+
}
87+
88+
// Check copy
89+
api::SpecVarList spec_vars_2(spec_vars);
90+
ASSERT_TRUE(spec_vars.size() == 7);
91+
92+
// Check validity of the copied data
93+
const api::SpecVar* copy_data = spec_vars.var_data();
94+
ASSERT_TRUE(*(reinterpret_cast<const bool*>(copy_data + 4)) == true);
95+
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(copy_data + 2)) == 45);
96+
ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);
97+
}
98+
5399
TEST_F(VulkanComputeAPITest, update_params_between_submit) {
54100
api::context()->set_cmd(/*reusable = */ true);
55101
std::vector<int64_t> sizes = {4, 4, 2};

0 commit comments

Comments
 (0)