Skip to content

[ET-VK] Introduce SpecVarList to represent specialization constants #3078

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,25 @@ Context::~Context() {

DescriptorSet Context::get_descriptor_set(
const ShaderInfo& shader_descriptor,
const utils::uvec3& local_workgroup_size) {
const utils::uvec3& local_workgroup_size,
const SpecVarList& additional_constants) {
VkDescriptorSetLayout shader_layout =
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);

VkPipelineLayout pipeline_layout =
pipeline_layout_cache().retrieve(shader_layout);

SpecVarList spec_constants = {
SV(local_workgroup_size.data[0u]),
SV(local_workgroup_size.data[1u]),
SV(local_workgroup_size.data[2u])};

spec_constants.append(additional_constants);

VkPipeline pipeline = pipeline_cache().retrieve(
{pipeline_layout_cache().retrieve(shader_layout),
shader_cache().retrieve(shader_descriptor),
local_workgroup_size});
spec_constants});

cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);

Expand Down
11 changes: 10 additions & 1 deletion backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,16 @@ class Context final {
}
}

DescriptorSet get_descriptor_set(const ShaderInfo&, const utils::uvec3&);
DescriptorSet get_descriptor_set(
const ShaderInfo&,
const utils::uvec3&,
const SpecVarList&);

inline DescriptorSet get_descriptor_set(
const ShaderInfo& shader_descriptor,
const utils::uvec3& local_work_group_size) {
return get_descriptor_set(shader_descriptor, local_work_group_size, {});
}

void register_shader_dispatch(
const DescriptorSet&,
Expand Down
128 changes: 102 additions & 26 deletions backends/vulkan/runtime/api/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,101 @@ VkImageLayout vk_layout(
return VK_IMAGE_LAYOUT_UNDEFINED;
}

//
// SpecVar
//

SpecVar::SpecVar() : type(SpecVar::Type::INT) {
value.as_int32 = 0;
}

SpecVar::SpecVar(const float val) : type(SpecVar::Type::FLOAT) {
value.as_float = val;
}

SpecVar::SpecVar(const int32_t val) : type(SpecVar::Type::INT) {
value.as_int32 = val;
}

SpecVar::SpecVar(const uint32_t val) : type(SpecVar::Type::UINT) {
value.as_uint32 = val;
}

SpecVar::SpecVar(const bool val) : type(SpecVar::Type::BOOL) {
value.as_bool = val;
}

uint32_t SpecVar::val_size() const {
switch (type) {
case SpecVar::Type::FLOAT:
return sizeof(float);
case SpecVar::Type::INT:
return sizeof(int32_t);
case SpecVar::Type::UINT:
return sizeof(uint32_t);
case SpecVar::Type::BOOL:
return sizeof(bool);
}
return 4;
}

uint32_t SpecVar::val_offset() const {
return api::utils::safe_downcast<uint32_t>(offsetof(SpecVar, value));
}

bool operator==(const SpecVar& lhs, const SpecVar& rhs) {
if (lhs.type != rhs.type) {
return false;
}
switch (lhs.type) {
case SpecVar::Type::FLOAT:
return lhs.value.as_float == rhs.value.as_float;
case SpecVar::Type::INT:
return lhs.value.as_int32 == rhs.value.as_int32;
case SpecVar::Type::UINT:
return lhs.value.as_uint32 == rhs.value.as_uint32;
case SpecVar::Type::BOOL:
return lhs.value.as_bool == rhs.value.as_bool;
}
return false;
}

SpecVarList::SpecVarList() {}

SpecVarList::SpecVarList(std::initializer_list<SpecVar> init_list) {
vars.resize(init_list.size());
std::copy(init_list.begin(), init_list.end(), vars.begin());
}

void SpecVarList::append(const SpecVarList& other) {
vars.insert(vars.end(), other.vars.begin(), other.vars.end());
}

std::vector<VkSpecializationMapEntry> SpecVarList::generate_map_entries()
const {
std::vector<VkSpecializationMapEntry> map_entries;
map_entries.resize(vars.size());
uint32_t cur_offset = 0u;
for (uint32_t i = 0; i < vars.size(); ++i) {
map_entries.at(i) = {
i, cur_offset + vars.at(i).val_offset(), vars.at(i).val_size()};
cur_offset += sizeof(SpecVar);
}
return map_entries;
}

bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
for (uint32_t i = 0; i < lhs.size(); ++i) {
if (lhs.vars.at(i) != rhs.vars.at(i)) {
return false;
}
}
return true;
}

//
// PipelineLayout
//
Expand Down Expand Up @@ -154,33 +249,14 @@ ComputePipeline::ComputePipeline(
const ComputePipeline::Descriptor& descriptor,
VkPipelineCache pipeline_cache)
: device_(device), handle_{VK_NULL_HANDLE} {
// NOLINTNEXTLINE
constexpr VkSpecializationMapEntry specialization_map_entries[3]{
// X
{
0u,
offsetof(utils::uvec3, data[0u]),
sizeof(utils::uvec3::data[0u]),
},
// Y
{
1u,
offsetof(utils::uvec3, data[1u]),
sizeof(utils::uvec3::data[1u]),
},
// Z
{
2u,
offsetof(utils::uvec3, data[2u]),
sizeof(utils::uvec3::data[2u]),
},
};
std::vector<VkSpecializationMapEntry> map_entries =
descriptor.specialization_constants.generate_map_entries();

const VkSpecializationInfo specialization_info{
3u, // mapEntryCount
specialization_map_entries, // pMapEntries
sizeof(descriptor.local_work_group), // dataSize
&descriptor.local_work_group, // pData
descriptor.specialization_constants.size(), // mapEntryCount
map_entries.data(), // pMapEntries
descriptor.specialization_constants.data_nbytes(), // dataSize
descriptor.specialization_constants.data(), // pData
};

const VkPipelineShaderStageCreateInfo shader_stage_create_info{
Expand Down Expand Up @@ -242,7 +318,7 @@ bool operator==(
return (
_1.pipeline_layout == _2.pipeline_layout &&
_1.shader_module == _2.shader_module &&
_1.local_work_group == _2.local_work_group);
_1.specialization_constants == _2.specialization_constants);
}

//
Expand Down
95 changes: 88 additions & 7 deletions backends/vulkan/runtime/api/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,73 @@
#include <mutex>
#include <unordered_map>

#define SV(x) ::vkcompute::api::SpecVar(x)

namespace vkcompute {
namespace api {

struct SpecVar final {
enum class Type : uint8_t {
FLOAT,
INT,
UINT,
BOOL,
};

union Value {
int32_t as_int32;
uint32_t as_uint32;
float as_float;
bool as_bool;
};

Value value;
Type type;

SpecVar();
SpecVar(const float val);
SpecVar(const int32_t val);
SpecVar(const uint32_t val);
SpecVar(const bool val);

uint32_t val_size() const;
uint32_t val_offset() const;
};

bool operator==(const SpecVar& lhs, const SpecVar& rhs);

class SpecVarList final {
std::vector<SpecVar> vars;

public:
SpecVarList();
SpecVarList(std::initializer_list<SpecVar> init_list);

inline const SpecVar& at(const size_t index) const {
return vars.at(index);
}

inline const SpecVar* data() const {
return vars.data();
}

inline uint32_t size() const {
return api::utils::safe_downcast<uint32_t>(vars.size());
}

inline uint32_t data_nbytes() const {
return vars.size() * sizeof(SpecVar);
}

void append(const SpecVarList& other);

std::vector<VkSpecializationMapEntry> generate_map_entries() const;

friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
};

bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);

struct PipelineBarrier final {
struct Stages final {
VkPipelineStageFlags src;
Expand Down Expand Up @@ -83,7 +147,7 @@ class ComputePipeline final {
struct Descriptor final {
VkPipelineLayout pipeline_layout;
VkShaderModule shader_module;
utils::uvec3 local_work_group;
SpecVarList specialization_constants;
};

explicit ComputePipeline(
Expand Down Expand Up @@ -171,12 +235,29 @@ class ComputePipelineCache final {
seed, std::hash<VkPipelineLayout>()(descriptor.pipeline_layout));
seed = utils::hash_combine(
seed, std::hash<VkShaderModule>()(descriptor.shader_module));
seed = utils::hash_combine(
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[0u]));
seed = utils::hash_combine(
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[1u]));
seed = utils::hash_combine(
seed, std::hash<uint32_t>()(descriptor.local_work_group.data[2u]));

const SpecVarList& spec_vars = descriptor.specialization_constants;
seed = utils::hash_combine(seed, std::hash<uint32_t>()(spec_vars.size()));

for (int i = 0; i < spec_vars.size(); ++i) {
const SpecVar& spec_var = spec_vars.at(i);
size_t new_seed = 0;
switch (spec_var.type) {
case SpecVar::Type::FLOAT:
new_seed = std::hash<float>()(spec_var.value.as_float);
break;
case SpecVar::Type::INT:
new_seed = std::hash<int32_t>()(spec_var.value.as_int32);
break;
case SpecVar::Type::UINT:
new_seed = std::hash<uint32_t>()(spec_var.value.as_uint32);
break;
case SpecVar::Type::BOOL:
new_seed = std::hash<bool>()(spec_var.value.as_bool);
break;
}
seed = utils::hash_combine(seed, new_seed);
}

return seed;
}
Expand Down
47 changes: 47 additions & 0 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,53 @@ TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) {
ASSERT_TRUE(kernel.kernel_name == "test_shader");
}

TEST_F(VulkanComputeAPITest, spec_var_classes_test) {
// Check equality operator
ASSERT_TRUE(SV(1.5f) == SV(1.5f));
ASSERT_FALSE(SV(15.0f) == SV(15));
ASSERT_FALSE(SV(1u) == SV(true));

size_t sv_size = sizeof(api::SpecVar);

api::SpecVarList spec_vars = {};
ASSERT_TRUE(spec_vars.size() == 0);
spec_vars = {SV(1.1f), SV(32), SV(45)};
ASSERT_TRUE(spec_vars.size() == 3);
api::SpecVarList spec_vars_other = {SV(2.6f), SV(true), SV(78u), SV(5.5f)};
spec_vars.append(spec_vars_other);
ASSERT_TRUE(spec_vars.size() == 7);

// Check validity of the data
const api::SpecVar* data = spec_vars.data();
ASSERT_TRUE(*(reinterpret_cast<const float*>(data + 3)) == 2.6f);
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 1)) == 32);
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(data + 5)) == 78u);

// Check validity of the map entries
std::vector<VkSpecializationMapEntry> entries =
spec_vars.generate_map_entries();

for (size_t i = 0; i < spec_vars.size(); ++i) {
ASSERT_TRUE(entries[i].constantID == i);
ASSERT_TRUE(entries[i].offset == sv_size * i);
if (i != 4) {
ASSERT_TRUE(entries[i].size == 4);
} else {
ASSERT_TRUE(entries[i].size == 1);
}
}

// Check copy
api::SpecVarList spec_vars_copy(spec_vars);
ASSERT_TRUE(spec_vars_copy.size() == 7);

// Check validity of the copied data
const api::SpecVar* copy_data = spec_vars_copy.data();
ASSERT_TRUE(*(reinterpret_cast<const bool*>(copy_data + 4)) == true);
ASSERT_TRUE(*(reinterpret_cast<const int32_t*>(copy_data + 2)) == 45);
ASSERT_TRUE(*(reinterpret_cast<const float*>(copy_data + 6)) == 5.5f);
}

TEST_F(VulkanComputeAPITest, update_params_between_submit) {
api::context()->set_cmd(/*reusable = */ true);
std::vector<int64_t> sizes = {4, 4, 2};
Expand Down