Skip to content

Commit 3550824

Browse files
[ET-VK] Consolidate shader compilation into one vkCreateComputePipelines call (#11381)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #11345 by @jorgep31415 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/jorgep31415/135/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/jorgep31415/135/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/jorgep31415/135/orig @diff-train-skip-merge Co-authored-by: jorgep31415 <[email protected]>
1 parent a9178f1 commit 3550824

File tree

10 files changed

+173
-2
lines changed

10 files changed

+173
-2
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
495495
builder.build_graph();
496496

497497
compute_graph->prepare();
498+
compute_graph->prepare_pipelines();
498499

499500
compute_graph->encode_prepack();
500501
compute_graph->prepack();

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,42 @@ void ComputeGraph::update_descriptor_counts(
561561
}
562562
}
563563

564+
void ComputeGraph::register_pipeline_to_create(
565+
const vkapi::ShaderInfo& shader_info,
566+
const utils::WorkgroupSize& local_workgroup_size,
567+
const vkapi::SpecVarList& spec_vars,
568+
const std::vector<PushConstantDataInfo>& push_constants) {
569+
VkDescriptorSetLayout shader_layout =
570+
context()->shader_layout_cache().retrieve(shader_info.kernel_layout);
571+
572+
uint32_t pc_offset = 0;
573+
std::array<uint8_t, kMaxPushConstantSize> pc_data;
574+
for (const auto& pc : push_constants) {
575+
pc_offset += pc.write(pc_data.data(), pc_offset, kMaxPushConstantSize);
576+
}
577+
578+
vkapi::SpecVarList spec_constants = {
579+
SV(local_workgroup_size[0u]),
580+
SV(local_workgroup_size[1u]),
581+
SV(local_workgroup_size[2u])};
582+
583+
spec_constants.append(spec_vars);
584+
585+
const vkapi::ComputePipelineCache::Key desc = {
586+
context()->pipeline_layout_cache().retrieve(shader_layout, pc_offset),
587+
context()->shader_cache().retrieve(shader_info),
588+
spec_constants};
589+
590+
if (context_->pipeline_cache().contains(desc)) {
591+
return;
592+
}
593+
auto it = pipeline_descriptors_.find(desc);
594+
if (it != pipeline_descriptors_.cend()) {
595+
return;
596+
}
597+
pipeline_descriptors_.insert(desc);
598+
}
599+
564600
utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) {
565601
if (is_buffer_storage(idx)) {
566602
return {uint32_t(numel_of(idx)), 1u, 1u};
@@ -670,6 +706,20 @@ void ComputeGraph::prepare() {
670706
}
671707
}
672708

709+
void ComputeGraph::prepare_pipelines() {
710+
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
711+
node->prepare_pipelines(this);
712+
}
713+
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
714+
node->prepare_pipelines(this);
715+
}
716+
context_->pipeline_cache().create_pipelines(pipeline_descriptors_);
717+
718+
pipeline_descriptors_ = std::unordered_set<
719+
vkapi::ComputePipelineCache::Key,
720+
vkapi::ComputePipelineCache::Hasher>();
721+
}
722+
673723
void ComputeGraph::encode_prepack() {
674724
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
675725
node->encode(this);

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ class ComputeGraph final {
185185
std::vector<IOValueRef> inputs_;
186186
std::vector<IOValueRef> outputs_;
187187

188+
std::unordered_set<
189+
vkapi::ComputePipelineCache::Key,
190+
vkapi::ComputePipelineCache::Hasher>
191+
pipeline_descriptors_;
192+
188193
protected:
189194
size_t values_in_use_ = 0;
190195
size_t execute_count_ = 0;
@@ -711,8 +716,16 @@ class ComputeGraph final {
711716
const vkapi::ShaderInfo& shader_info,
712717
bool execute);
713718

719+
void register_pipeline_to_create(
720+
const vkapi::ShaderInfo& shader_info,
721+
const utils::WorkgroupSize& local_workgroup_size,
722+
const vkapi::SpecVarList& spec_vars,
723+
const std::vector<PushConstantDataInfo>& push_constants);
724+
714725
void prepare();
715726

727+
void prepare_pipelines();
728+
716729
//
717730
// Dispatch Utilities
718731
//

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ DispatchNode::DispatchNode(
3535
graph.update_descriptor_counts(shader, /*execute = */ true);
3636
}
3737

38+
void DispatchNode::prepare_pipelines(ComputeGraph* graph) {
39+
graph->register_pipeline_to_create(
40+
shader_, local_workgroup_size_, spec_vars_, push_constants_);
41+
}
42+
3843
void DispatchNode::encode(ComputeGraph* graph) {
3944
if (!shader_) {
4045
return;

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class DispatchNode : public ExecuteNode {
4040

4141
~DispatchNode() override = default;
4242

43+
void prepare_pipelines(ComputeGraph* graph) override;
44+
4345
void encode(ComputeGraph* graph) override;
4446

4547
protected:

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ class ExecuteNode {
6161

6262
virtual ~ExecuteNode() = default;
6363

64+
virtual void prepare_pipelines(ComputeGraph* graph) {
65+
(void)graph;
66+
}
67+
6468
virtual void encode(ComputeGraph* graph) {
6569
(void)graph;
6670
}

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
6767
return staging;
6868
}
6969

70+
void PrepackNode::prepare_pipelines(ComputeGraph* graph) {
71+
graph->register_pipeline_to_create(
72+
shader_, local_workgroup_size_, spec_vars_, push_constants_);
73+
graph->register_pipeline_to_create(
74+
noop_shader_, utils::WorkgroupSize(1, 1, 1), {}, {});
75+
}
76+
7077
void PrepackNode::encode(ComputeGraph* graph) {
7178
api::Context* const context = graph->context();
7279

backends/vulkan/runtime/graph/ops/PrepackNode.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class PrepackNode final {
4040

4141
~PrepackNode() = default;
4242

43+
void prepare_pipelines(ComputeGraph* graph);
44+
4345
void encode(ComputeGraph* graph);
4446

4547
inline void set_node_id(uint32_t node_id) {

backends/vulkan/runtime/vk_api/Pipeline.cpp

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
270270
// ComputePipeline
271271
//
272272

273+
ComputePipeline::ComputePipeline(VkDevice device, VkPipeline handle)
274+
: device_{device}, handle_{handle} {}
275+
273276
ComputePipeline::ComputePipeline(
274277
VkDevice device,
275278
const ComputePipeline::Descriptor& descriptor,
@@ -444,19 +447,94 @@ ComputePipelineCache::~ComputePipelineCache() {
444447
pipeline_cache_ = VK_NULL_HANDLE;
445448
}
446449

450+
bool ComputePipelineCache::contains(const ComputePipelineCache::Key& key) {
451+
std::lock_guard<std::mutex> lock(cache_mutex_);
452+
453+
auto it = cache_.find(key);
454+
return it != cache_.cend();
455+
}
456+
457+
void ComputePipelineCache::create_pipelines(
458+
const std::unordered_set<Key, Hasher>& descriptors) {
459+
std::lock_guard<std::mutex> lock(cache_mutex_);
460+
461+
const auto num_pipelines = descriptors.size();
462+
std::vector<VkPipeline> pipelines(num_pipelines);
463+
464+
std::vector<std::vector<VkSpecializationMapEntry>> map_entries;
465+
map_entries.reserve(num_pipelines);
466+
467+
std::vector<VkSpecializationInfo> specialization_infos;
468+
specialization_infos.reserve(num_pipelines);
469+
470+
std::vector<VkPipelineShaderStageCreateInfo> shader_stage_create_infos;
471+
shader_stage_create_infos.reserve(num_pipelines);
472+
473+
std::vector<VkComputePipelineCreateInfo> create_infos;
474+
create_infos.reserve(num_pipelines);
475+
476+
for (auto& key : descriptors) {
477+
map_entries.push_back(key.specialization_constants.generate_map_entries());
478+
479+
specialization_infos.push_back(VkSpecializationInfo{
480+
key.specialization_constants.size(), // mapEntryCount
481+
map_entries.back().data(), // pMapEntries
482+
key.specialization_constants.data_nbytes(), // dataSize
483+
key.specialization_constants.data(), // pData
484+
});
485+
486+
shader_stage_create_infos.push_back(VkPipelineShaderStageCreateInfo{
487+
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
488+
nullptr, // pNext
489+
0u, // flags
490+
VK_SHADER_STAGE_COMPUTE_BIT, // stage
491+
key.shader_module, // module
492+
"main", // pName
493+
&specialization_infos.back(), // pSpecializationInfo
494+
});
495+
496+
create_infos.push_back(VkComputePipelineCreateInfo{
497+
VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
498+
nullptr, // pNext
499+
0u, // flags
500+
shader_stage_create_infos.back(), // stage
501+
key.pipeline_layout, // layout
502+
VK_NULL_HANDLE, // basePipelineHandle
503+
0u, // basePipelineIndex
504+
});
505+
}
506+
507+
VK_CHECK(vkCreateComputePipelines(
508+
device_,
509+
pipeline_cache_,
510+
create_infos.size(),
511+
create_infos.data(),
512+
nullptr,
513+
pipelines.data()));
514+
515+
uint32_t i = 0;
516+
for (auto& key : descriptors) {
517+
auto it = cache_.find(key);
518+
if (it != cache_.cend()) {
519+
continue;
520+
}
521+
cache_.insert({key, ComputePipelineCache::Value(device_, pipelines[i])});
522+
++i;
523+
}
524+
}
525+
447526
VkPipeline ComputePipelineCache::retrieve(
448527
const ComputePipelineCache::Key& key) {
449528
std::lock_guard<std::mutex> lock(cache_mutex_);
450529

451530
auto it = cache_.find(key);
452-
if (cache_.cend() == it) {
531+
if (it == cache_.cend()) {
453532
it = cache_
454533
.insert(
455534
{key,
456535
ComputePipelineCache::Value(device_, key, pipeline_cache_)})
457536
.first;
458537
}
459-
460538
return it->second.handle();
461539
}
462540

backends/vulkan/runtime/vk_api/Pipeline.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <mutex>
2121
#include <unordered_map>
22+
#include <unordered_set>
2223

2324
#define SV(x) ::vkcompute::vkapi::SpecVar(x)
2425

@@ -158,6 +159,8 @@ class ComputePipeline final {
158159
SpecVarList specialization_constants;
159160
};
160161

162+
explicit ComputePipeline(VkDevice device, VkPipeline handle);
163+
161164
explicit ComputePipeline(
162165
VkDevice device,
163166
const Descriptor& descriptor,
@@ -185,6 +188,10 @@ class ComputePipeline final {
185188
// does not allow for move assignment. The swap function will
186189
// be used in the hash map.
187190
friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept;
191+
192+
friend bool operator==(
193+
const ComputePipeline::Descriptor& _1,
194+
const ComputePipeline::Descriptor& _2);
188195
};
189196

190197
class PipelineLayoutCache final {
@@ -293,6 +300,8 @@ class ComputePipelineCache final {
293300
const std::string cache_data_path_;
294301

295302
public:
303+
bool contains(const Key&);
304+
void create_pipelines(const std::unordered_set<Key, Hasher>&);
296305
VkPipeline retrieve(const Key&);
297306
void purge();
298307
};

0 commit comments

Comments
 (0)