Skip to content

[ET-VK] Consolidate shader compilation into one vkCreateComputePipelines call #11345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
builder.build_graph();

compute_graph->prepare();
compute_graph->prepare_pipelines();

compute_graph->encode_prepack();
compute_graph->prepack();
Expand Down
50 changes: 50 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,42 @@ void ComputeGraph::update_descriptor_counts(
}
}

void ComputeGraph::register_pipeline_to_create(
const vkapi::ShaderInfo& shader_info,
const utils::WorkgroupSize& local_workgroup_size,
const vkapi::SpecVarList& spec_vars,
const std::vector<PushConstantDataInfo>& push_constants) {
VkDescriptorSetLayout shader_layout =
context()->shader_layout_cache().retrieve(shader_info.kernel_layout);

uint32_t pc_offset = 0;
std::array<uint8_t, kMaxPushConstantSize> pc_data;
for (const auto& pc : push_constants) {
pc_offset += pc.write(pc_data.data(), pc_offset, kMaxPushConstantSize);
}

vkapi::SpecVarList spec_constants = {
SV(local_workgroup_size[0u]),
SV(local_workgroup_size[1u]),
SV(local_workgroup_size[2u])};

spec_constants.append(spec_vars);

const vkapi::ComputePipelineCache::Key desc = {
context()->pipeline_layout_cache().retrieve(shader_layout, pc_offset),
context()->shader_cache().retrieve(shader_info),
spec_constants};

if (context_->pipeline_cache().contains(desc)) {
return;
}
auto it = pipeline_descriptors_.find(desc);
if (it != pipeline_descriptors_.cend()) {
return;
}
pipeline_descriptors_.insert(desc);
}

utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) {
if (is_buffer_storage(idx)) {
return {uint32_t(numel_of(idx)), 1u, 1u};
Expand Down Expand Up @@ -670,6 +706,20 @@ void ComputeGraph::prepare() {
}
}

void ComputeGraph::prepare_pipelines() {
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
node->prepare_pipelines(this);
}
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
node->prepare_pipelines(this);
}
context_->pipeline_cache().create_pipelines(pipeline_descriptors_);

pipeline_descriptors_ = std::unordered_set<
vkapi::ComputePipelineCache::Key,
vkapi::ComputePipelineCache::Hasher>();
}

void ComputeGraph::encode_prepack() {
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
node->encode(this);
Expand Down
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ class ComputeGraph final {
std::vector<IOValueRef> inputs_;
std::vector<IOValueRef> outputs_;

std::unordered_set<
vkapi::ComputePipelineCache::Key,
vkapi::ComputePipelineCache::Hasher>
pipeline_descriptors_;

protected:
size_t values_in_use_ = 0;
size_t execute_count_ = 0;
Expand Down Expand Up @@ -711,8 +716,16 @@ class ComputeGraph final {
const vkapi::ShaderInfo& shader_info,
bool execute);

void register_pipeline_to_create(
const vkapi::ShaderInfo& shader_info,
const utils::WorkgroupSize& local_workgroup_size,
const vkapi::SpecVarList& spec_vars,
const std::vector<PushConstantDataInfo>& push_constants);

void prepare();

void prepare_pipelines();

//
// Dispatch Utilities
//
Expand Down
5 changes: 5 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ DispatchNode::DispatchNode(
graph.update_descriptor_counts(shader, /*execute = */ true);
}

void DispatchNode::prepare_pipelines(ComputeGraph* graph) {
graph->register_pipeline_to_create(
shader_, local_workgroup_size_, spec_vars_, push_constants_);
}

void DispatchNode::encode(ComputeGraph* graph) {
if (!shader_) {
return;
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class DispatchNode : public ExecuteNode {

~DispatchNode() override = default;

void prepare_pipelines(ComputeGraph* graph) override;

void encode(ComputeGraph* graph) override;

protected:
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class ExecuteNode {

virtual ~ExecuteNode() = default;

virtual void prepare_pipelines(ComputeGraph* graph) {
(void)graph;
}

virtual void encode(ComputeGraph* graph) {
(void)graph;
}
Expand Down
7 changes: 7 additions & 0 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
return staging;
}

void PrepackNode::prepare_pipelines(ComputeGraph* graph) {
graph->register_pipeline_to_create(
shader_, local_workgroup_size_, spec_vars_, push_constants_);
graph->register_pipeline_to_create(
noop_shader_, utils::WorkgroupSize(1, 1, 1), {}, {});
}

void PrepackNode::encode(ComputeGraph* graph) {
api::Context* const context = graph->context();

Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/PrepackNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class PrepackNode final {

~PrepackNode() = default;

void prepare_pipelines(ComputeGraph* graph);

void encode(ComputeGraph* graph);

inline void set_node_id(uint32_t node_id) {
Expand Down
82 changes: 80 additions & 2 deletions backends/vulkan/runtime/vk_api/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
// ComputePipeline
//

ComputePipeline::ComputePipeline(VkDevice device, VkPipeline handle)
: device_{device}, handle_{handle} {}

ComputePipeline::ComputePipeline(
VkDevice device,
const ComputePipeline::Descriptor& descriptor,
Expand Down Expand Up @@ -444,19 +447,94 @@ ComputePipelineCache::~ComputePipelineCache() {
pipeline_cache_ = VK_NULL_HANDLE;
}

bool ComputePipelineCache::contains(const ComputePipelineCache::Key& key) {
std::lock_guard<std::mutex> lock(cache_mutex_);

auto it = cache_.find(key);
return it != cache_.cend();
}

void ComputePipelineCache::create_pipelines(
const std::unordered_set<Key, Hasher>& descriptors) {
std::lock_guard<std::mutex> lock(cache_mutex_);

const auto num_pipelines = descriptors.size();
std::vector<VkPipeline> pipelines(num_pipelines);

std::vector<std::vector<VkSpecializationMapEntry>> map_entries;
map_entries.reserve(num_pipelines);

std::vector<VkSpecializationInfo> specialization_infos;
specialization_infos.reserve(num_pipelines);

std::vector<VkPipelineShaderStageCreateInfo> shader_stage_create_infos;
shader_stage_create_infos.reserve(num_pipelines);

std::vector<VkComputePipelineCreateInfo> create_infos;
create_infos.reserve(num_pipelines);

for (auto& key : descriptors) {
map_entries.push_back(key.specialization_constants.generate_map_entries());

specialization_infos.push_back(VkSpecializationInfo{
key.specialization_constants.size(), // mapEntryCount
map_entries.back().data(), // pMapEntries
key.specialization_constants.data_nbytes(), // dataSize
key.specialization_constants.data(), // pData
});

shader_stage_create_infos.push_back(VkPipelineShaderStageCreateInfo{
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
nullptr, // pNext
0u, // flags
VK_SHADER_STAGE_COMPUTE_BIT, // stage
key.shader_module, // module
"main", // pName
&specialization_infos.back(), // pSpecializationInfo
});

create_infos.push_back(VkComputePipelineCreateInfo{
VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
nullptr, // pNext
0u, // flags
shader_stage_create_infos.back(), // stage
key.pipeline_layout, // layout
VK_NULL_HANDLE, // basePipelineHandle
0u, // basePipelineIndex
});
}

VK_CHECK(vkCreateComputePipelines(
device_,
pipeline_cache_,
create_infos.size(),
create_infos.data(),
nullptr,
pipelines.data()));

uint32_t i = 0;
for (auto& key : descriptors) {
auto it = cache_.find(key);
if (it != cache_.cend()) {
continue;
}
cache_.insert({key, ComputePipelineCache::Value(device_, pipelines[i])});
++i;
}
}

VkPipeline ComputePipelineCache::retrieve(
const ComputePipelineCache::Key& key) {
std::lock_guard<std::mutex> lock(cache_mutex_);

auto it = cache_.find(key);
if (cache_.cend() == it) {
if (it == cache_.cend()) {
it = cache_
.insert(
{key,
ComputePipelineCache::Value(device_, key, pipeline_cache_)})
.first;
}

return it->second.handle();
}

Expand Down
9 changes: 9 additions & 0 deletions backends/vulkan/runtime/vk_api/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <mutex>
#include <unordered_map>
#include <unordered_set>

#define SV(x) ::vkcompute::vkapi::SpecVar(x)

Expand Down Expand Up @@ -158,6 +159,8 @@ class ComputePipeline final {
SpecVarList specialization_constants;
};

explicit ComputePipeline(VkDevice device, VkPipeline handle);

explicit ComputePipeline(
VkDevice device,
const Descriptor& descriptor,
Expand Down Expand Up @@ -185,6 +188,10 @@ class ComputePipeline final {
// does not allow for move assignment. The swap function will
// be used in the hash map.
friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept;

friend bool operator==(
const ComputePipeline::Descriptor& _1,
const ComputePipeline::Descriptor& _2);
};

class PipelineLayoutCache final {
Expand Down Expand Up @@ -293,6 +300,8 @@ class ComputePipelineCache final {
const std::string cache_data_path_;

public:
bool contains(const Key&);
void create_pipelines(const std::unordered_set<Key, Hasher>&);
VkPipeline retrieve(const Key&);
void purge();
};
Expand Down
Loading