Skip to content

Commit bdc7896

Browse files
committed
Update on "[ET-VK] Enable additional specialization constants in compute shaders"
## Context Building on top of the previous changeset in the stack, this changeset modifies shader dispatch APIs to accept additional specialization constants for a shader. Differential Revision: [D56225042](https://our.internmc.facebook.com/intern/diff/D56225042/) [ghstack-poisoned]
2 parents b8817ef + 9237abc commit bdc7896

File tree

3 files changed

+43
-46
lines changed

3 files changed

+43
-46
lines changed

backends/vulkan/runtime/api/Pipeline.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -157,34 +157,38 @@ bool operator==(const SpecVar& lhs, const SpecVar& rhs) {
157157
return false;
158158
}
159159

160+
SpecVarList::SpecVarList() {
161+
vars.reserve(8);
162+
}
163+
160164
SpecVarList::SpecVarList(std::initializer_list<SpecVar> init_list) {
161-
VK_CHECK_COND(init_list.size() <= SPECVAR_LIST_LIMIT);
162-
arr_size = init_list.size();
163-
std::copy(init_list.begin(), init_list.end(), arr);
164-
uint32_t cur_offset = 0u;
165-
for (uint32_t i = 0; i < arr_size; ++i) {
166-
map_entries[i] = {i, cur_offset + arr[i].val_offset(), arr[i].val_size()};
167-
cur_offset += sizeof(SpecVar);
168-
}
165+
vars.resize(init_list.size());
166+
std::copy(init_list.begin(), init_list.end(), vars.begin());
169167
}
170168

171169
void SpecVarList::append(const SpecVarList& other) {
172-
VK_CHECK_COND(arr_size + other.size() <= SPECVAR_LIST_LIMIT);
173-
std::copy(other.arr, other.arr + other.size(), arr + arr_size);
174-
uint32_t cur_offset = arr_size * sizeof(SpecVar);
175-
for (uint32_t i = arr_size; i < arr_size + other.size(); ++i) {
176-
map_entries[i] = {i, cur_offset + arr[i].val_offset(), arr[i].val_size()};
170+
vars.insert(vars.end(), other.vars.begin(), other.vars.end());
171+
}
172+
173+
std::vector<VkSpecializationMapEntry> SpecVarList::generate_map_entries()
174+
const {
175+
std::vector<VkSpecializationMapEntry> map_entries;
176+
map_entries.resize(vars.size());
177+
uint32_t cur_offset = 0u;
178+
for (uint32_t i = 0; i < vars.size(); ++i) {
179+
map_entries[i] = {
180+
i, cur_offset + vars.at(i).val_offset(), vars.at(i).val_size()};
177181
cur_offset += sizeof(SpecVar);
178182
}
179-
arr_size += other.size();
183+
return map_entries;
180184
}
181185

182186
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
183187
if (lhs.size() != rhs.size()) {
184188
return false;
185189
}
186190
for (uint32_t i = 0; i < lhs.size(); ++i) {
187-
if (lhs.var_data()[i] != rhs.var_data()[i]) {
191+
if (lhs.vars.at(i) != rhs.vars.at(i)) {
188192
return false;
189193
}
190194
}
@@ -247,10 +251,13 @@ ComputePipeline::ComputePipeline(
247251
const ComputePipeline::Descriptor& descriptor,
248252
VkPipelineCache pipeline_cache)
249253
: device_(device), handle_{VK_NULL_HANDLE} {
254+
std::vector<VkSpecializationMapEntry> map_entries =
255+
descriptor.specialization_constants.generate_map_entries();
256+
250257
const VkSpecializationInfo specialization_info{
251258
descriptor.specialization_constants.size(), // mapEntryCount
252-
descriptor.specialization_constants.map_entries_data(), // pMapEntries
253-
descriptor.specialization_constants.map_entries_data_size(), // dataSize
259+
map_entries.data(), // pMapEntries
260+
descriptor.specialization_constants.data_nbytes(), // dataSize
254261
descriptor.specialization_constants.data(), // pData
255262
};
256263

backends/vulkan/runtime/api/Pipeline.h

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
#include <mutex>
1919
#include <unordered_map>
2020

21-
#define SPECVAR_LIST_LIMIT 8
22-
2321
#define SV(x) ::vkcompute::api::SpecVar(x)
2422

2523
namespace vkcompute {
@@ -55,38 +53,29 @@ struct SpecVar final {
5553

5654
bool operator==(const SpecVar& lhs, const SpecVar& rhs);
5755

58-
// using SpecVarList = std::vector<SpecVar>;
59-
60-
class SpecVarList final {
61-
SpecVar arr[SPECVAR_LIST_LIMIT];
62-
VkSpecializationMapEntry map_entries[SPECVAR_LIST_LIMIT];
63-
uint32_t arr_size;
56+
struct SpecVarList final {
57+
std::vector<SpecVar> vars;
6458

65-
public:
66-
SpecVarList() : arr_size(0) {}
59+
SpecVarList();
6760
SpecVarList(std::initializer_list<SpecVar> init_list);
6861

69-
inline const SpecVar* var_data() const {
70-
return &(arr[0]);
71-
}
72-
73-
inline const void* data() const {
74-
return &(arr[0]);
62+
inline const SpecVar* data() const {
63+
return vars.data();
7564
}
7665

7766
inline uint32_t size() const {
78-
return arr_size;
67+
return api::utils::safe_downcast<uint32_t>(vars.size());
7968
}
8069

81-
inline const VkSpecializationMapEntry* map_entries_data() const {
82-
return &(map_entries[0]);
83-
}
84-
85-
inline size_t map_entries_data_size() const {
86-
return arr_size * sizeof(VkSpecializationMapEntry);
70+
inline uint32_t data_nbytes() const {
71+
return vars.size() * sizeof(SpecVar);
8772
}
8873

8974
void append(const SpecVarList& other);
75+
76+
std::vector<VkSpecializationMapEntry> generate_map_entries() const;
77+
78+
friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
9079
};
9180

9281
bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
@@ -246,7 +235,7 @@ class ComputePipelineCache final {
246235
seed = utils::hash_combine(seed, std::hash<uint32_t>()(spec_vars.size()));
247236

248237
for (int i = 0; i < spec_vars.size(); ++i) {
249-
const SpecVar& spec_var = spec_vars.var_data()[i];
238+
const SpecVar& spec_var = spec_vars.vars.at(i);
250239
size_t new_seed = 0;
251240
switch (spec_var.type) {
252241
case SpecVar::Type::FLOAT:

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,14 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
6767
ASSERT_TRUE(spec_vars.size() == 7);
6868

6969
// Check validity of the data
70-
const api::SpecVar* data = spec_vars.var_data();
70+
const api::SpecVar* data = spec_vars.data();
7171
ASSERT_TRUE(*(reinterpret_cast<const float*>(data + 3)) == 2.6f);
7272
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 1)) == 32);
7373
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 5)) == 78u);
7474

7575
// Check validity of the map entries
76-
const VkSpecializationMapEntry* entries = spec_vars.map_entries_data();
76+
std::vector<VkSpecializationMapEntry> entries =
77+
spec_vars.generate_map_entries();
7778

7879
for (size_t i = 0; i < spec_vars.size(); ++i) {
7980
ASSERT_TRUE(entries[i].constantID == i);
@@ -86,11 +87,11 @@ TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
8687
}
8788

8889
// Check copy
89-
api::SpecVarList spec_vars_2(spec_vars);
90-
ASSERT_TRUE(spec_vars.size() == 7);
90+
api::SpecVarList spec_vars_copy(spec_vars);
91+
ASSERT_TRUE(spec_vars_copy.size() == 7);
9192

9293
// Check validity of the copied data
93-
const api::SpecVar* copy_data = spec_vars.var_data();
94+
const api::SpecVar* copy_data = spec_vars_copy.data();
9495
ASSERT_TRUE(*(reinterpret_cast<const bool*>(copy_data + 4)) == true);
9596
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(copy_data + 2)) == 45);
9697
ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);

0 commit comments

Comments
 (0)