Skip to content

Commit cdde4b5

Browse files
committed
Update on "[ET-VK] Introduce SpecVarList to represent specialization constants"
## Context Specialization constants are a useful tool to compile compute shaders with constants defined at runtime. The primary application of specialization constants is to define variables which may have an impact on how the code is compiled, for example: * the number of elements of an array * the range of a loop Compared to the shader codegen system, which produces a complete copy of the shader and for which variants must be defined at build time, specialization constants can be defined at runtime when the compute pipeline is built. Specialization constants are currently used to define local work group sizes in Vulkan, but the Compute API hard-codes the number of specialization constants accepted by the shader to 3. This changeset introduces the `SpecVar` and `SpecVarList` classes to manage specialization constants and enable additional specialization constants to be specified. Differential Revision: [D56225041](https://our.internmc.facebook.com/intern/diff/D56225041/) [ghstack-poisoned]
1 parent 837a4bc commit cdde4b5

File tree

3 files changed

+43
-46
lines changed

3 files changed

+43
-46
lines changed

backends/vulkan/runtime/api/Pipeline.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,34 +157,38 @@ bool operator==(const SpecVar& lhs, const SpecVar& rhs) {
157157
return false;
158158
}
159159

160+
SpecVarList::SpecVarList() {
161+
vars.reserve(8);
162+
}
163+
160164
SpecVarList::SpecVarList(std::initializer_list<SpecVar> init_list) {
161-
VK_CHECK_COND(init_list.size() <= SPECVAR_LIST_LIMIT);
162-
arr_size = init_list.size();
163-
std::copy(init_list.begin(), init_list.end(), arr);
164-
uint32_t cur_offset = 0u;
165-
for (uint32_t i = 0; i < arr_size; ++i) {
166-
map_entries[i] = {i, cur_offset + arr[i].val_offset(), arr[i].val_size()};
167-
cur_offset += sizeof(SpecVar);
168-
}
165+
vars.resize(init_list.size());
166+
std::copy(init_list.begin(), init_list.end(), vars.begin());
169167
}
170168

171169
void SpecVarList::append(const SpecVarList& other) {
172-
VK_CHECK_COND(arr_size + other.size() <= SPECVAR_LIST_LIMIT);
173-
std::copy(other.arr, other.arr + other.size(), arr + arr_size);
174-
uint32_t cur_offset = arr_size * sizeof(SpecVar);
175-
for (uint32_t i = arr_size; i < arr_size + other.size(); ++i) {
176-
map_entries[i] = {i, cur_offset + arr[i].val_offset(), arr[i].val_size()};
170+
vars.insert(vars.end(), other.vars.begin(), other.vars.end());
171+
}
172+
173+
std::vector<VkSpecializationMapEntry> SpecVarList::generate_map_entries()
174+
const {
175+
std::vector<VkSpecializationMapEntry> map_entries;
176+
map_entries.resize(vars.size());
177+
uint32_t cur_offset = 0u;
178+
for (uint32_t i = 0; i < vars.size(); ++i) {
179+
map_entries[i] = {
180+
i, cur_offset + vars.at(i).val_offset(), vars.at(i).val_size()};
177181
cur_offset += sizeof(SpecVar);
178182
}
179-
arr_size += other.size();
183+
return map_entries;
180184
}
181185

182186
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
183187
if (lhs.size() != rhs.size()) {
184188
return false;
185189
}
186190
for (uint32_t i = 0; i < lhs.size(); ++i) {
187-
if (lhs.var_data()[i] != rhs.var_data()[i]) {
191+
if (lhs.vars.at(i) != rhs.vars.at(i)) {
188192
return false;
189193
}
190194
}
@@ -247,10 +251,13 @@ ComputePipeline::ComputePipeline(
247251
const ComputePipeline::Descriptor& descriptor,
248252
VkPipelineCache pipeline_cache)
249253
: device_(device), handle_{VK_NULL_HANDLE} {
254+
std::vector<VkSpecializationMapEntry> map_entries =
255+
descriptor.specialization_constants.generate_map_entries();
256+
250257
const VkSpecializationInfo specialization_info{
251258
descriptor.specialization_constants.size(), // mapEntryCount
252-
descriptor.specialization_constants.map_entries_data(), // pMapEntries
253-
descriptor.specialization_constants.map_entries_data_size(), // dataSize
259+
map_entries.data(), // pMapEntries
260+
descriptor.specialization_constants.data_nbytes(), // dataSize
254261
descriptor.specialization_constants.data(), // pData
255262
};
256263

backends/vulkan/runtime/api/Pipeline.h

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
#include <mutex>
1919
#include <unordered_map>
2020

21-
#define SPECVAR_LIST_LIMIT 8
22-
2321
#define SV(x) ::vkcompute::api::SpecVar(x)
2422

2523
namespace vkcompute {
@@ -55,38 +53,29 @@ struct SpecVar final {
5553

5654
bool operator==(const SpecVar& lhs, const SpecVar& rhs);
5755

58-
// using SpecVarList = std::vector<SpecVar>;
59-
60-
class SpecVarList final {
61-
SpecVar arr[SPECVAR_LIST_LIMIT];
62-
VkSpecializationMapEntry map_entries[SPECVAR_LIST_LIMIT];
63-
uint32_t arr_size;
56+
struct SpecVarList final {
57+
std::vector<SpecVar> vars;
6458

65-
public:
66-
SpecVarList() : arr_size(0) {}
59+
SpecVarList();
6760
SpecVarList(std::initializer_list<SpecVar> init_list);
6861

69-
inline const SpecVar* var_data() const {
70-
return &(arr[0]);
71-
}
72-
73-
inline const void* data() const {
74-
return &(arr[0]);
62+
inline const SpecVar* data() const {
63+
return vars.data();
7564
}
7665

7766
inline uint32_t size() const {
78-
return arr_size;
67+
return api::utils::safe_downcast<uint32_t>(vars.size());
7968
}
8069

81-
inline const VkSpecializationMapEntry* map_entries_data() const {
82-
return &(map_entries[0]);
83-
}
84-
85-
inline size_t map_entries_data_size() const {
86-
return arr_size * sizeof(VkSpecializationMapEntry);
70+
inline uint32_t data_nbytes() const {
71+
return vars.size() * sizeof(SpecVar);
8772
}
8873

8974
void append(const SpecVarList& other);
75+
76+
std::vector<VkSpecializationMapEntry> generate_map_entries() const;
77+
78+
friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
9079
};
9180

9281
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
@@ -246,7 +235,7 @@ class ComputePipelineCache final {
246235
seed = utils::hash_combine(seed, std::hash<uint32_t>()(spec_vars.size()));
247236

248237
for (int i = 0; i < spec_vars.size(); ++i) {
249-
const SpecVar& spec_var = spec_vars.var_data()[i];
238+
const SpecVar& spec_var = spec_vars.vars.at(i);
250239
size_t new_seed = 0;
251240
switch (spec_var.type) {
252241
case SpecVar::Type::FLOAT:

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,14 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
6767
ASSERT_TRUE(spec_vars.size() == 7);
6868

6969
// Check validity of the data
70-
const api::SpecVar* data = spec_vars.var_data();
70+
const api::SpecVar* data = spec_vars.data();
7171
ASSERT_TRUE(*(reinterpret_cast<const float*>(data + 3)) == 2.6f);
7272
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 1)) == 32);
7373
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 5)) == 78u);
7474

7575
// Check validity of the map entries
76-
const VkSpecializationMapEntry* entries = spec_vars.map_entries_data();
76+
std::vector<VkSpecializationMapEntry> entries =
77+
spec_vars.generate_map_entries();
7778

7879
for (size_t i = 0; i < spec_vars.size(); ++i) {
7980
ASSERT_TRUE(entries[i].constantID == i);
@@ -86,11 +87,11 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
8687
}
8788

8889
// Check copy
89-
api::SpecVarList spec_vars_2(spec_vars);
90-
ASSERT_TRUE(spec_vars.size() == 7);
90+
api::SpecVarList spec_vars_copy(spec_vars);
91+
ASSERT_TRUE(spec_vars_copy.size() == 7);
9192

9293
// Check validity of the copied data
93-
const api::SpecVar* copy_data = spec_vars.var_data();
94+
const api::SpecVar* copy_data = spec_vars_copy.data();
9495
ASSERT_TRUE(*(reinterpret_cast<const bool*>(copy_data + 4)) == true);
9596
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(copy_data + 2)) == 45);
9697
ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);

0 commit comments

Comments
 (0)