Skip to content

Commit e34d724

Browse files
[ET-VK] Adding function to set push constants in Command buffer.
Pull Request resolved: #7221 This diff adds a function to set push constants in the Command buffer for ET-VK. The changes include adding a new `set_push_constants` function to the CommandBuffer class and modifying the code in the CommandBuffer class to call this new function. ghstack-source-id: 257227241 @exported-using-ghexport Differential Revision: [D66714317](https://our.internmc.facebook.com/intern/diff/D66714317/) Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 28ad3f2 commit e34d724

File tree

4 files changed

+31
-2
lines changed

4 files changed

+31
-2
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ void Context::register_shader_dispatch(
119119
const vkapi::DescriptorSet& descriptors,
120120
vkapi::PipelineBarrier& pipeline_barrier,
121121
const vkapi::ShaderInfo& shader_descriptor,
122-
const utils::uvec3& global_workgroup_size) {
122+
const utils::uvec3& global_workgroup_size,
123+
const void* push_constants_data,
124+
const uint32_t push_constants_size) {
123125
// Adjust the global workgroup size based on the output tile size
124126
uint32_t global_wg_w = utils::div_up(
125127
global_workgroup_size[0u], shader_descriptor.out_tile_size[0u]);
@@ -145,6 +147,15 @@ void Context::register_shader_dispatch(
145147
cmd_.bind_descriptors(descriptors.get_bind_handle());
146148
cmd_.insert_barrier(pipeline_barrier);
147149

150+
if (push_constants_size > 0 && push_constants_data != nullptr) {
151+
const VkDescriptorSetLayout shader_layout =
152+
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
153+
const VkPipelineLayout pipeline_layout =
154+
pipeline_layout_cache().retrieve(shader_layout);
155+
cmd_.set_push_constants(
156+
pipeline_layout, push_constants_data, push_constants_size);
157+
}
158+
148159
cmd_.dispatch(effective_global_wg);
149160
}
150161

backends/vulkan/runtime/api/Context.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ class Context final {
200200
const vkapi::DescriptorSet&,
201201
vkapi::PipelineBarrier&,
202202
const vkapi::ShaderInfo&,
203-
const utils::uvec3&);
203+
const utils::uvec3&,
204+
const void* = nullptr,
205+
const uint32_t = 0);
204206

205207
void register_blit(
206208
vkapi::PipelineBarrier&,

backends/vulkan/runtime/vk_api/Command.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,21 @@ void CommandBuffer::bind_descriptors(VkDescriptorSet descriptors) {
122122
state_ = CommandBuffer::State::DESCRIPTORS_BOUND;
123123
}
124124

125+
void CommandBuffer::set_push_constants(
126+
VkPipelineLayout pipeline_layout,
127+
const void* push_constants_data,
128+
uint32_t push_constants_size) {
129+
if (push_constants_data != nullptr && push_constants_size > 0) {
130+
vkCmdPushConstants(
131+
handle_,
132+
pipeline_layout,
133+
VK_SHADER_STAGE_COMPUTE_BIT,
134+
0,
135+
push_constants_size,
136+
push_constants_data);
137+
}
138+
}
139+
125140
void CommandBuffer::insert_barrier(PipelineBarrier& pipeline_barrier) {
126141
VK_CHECK_COND(
127142
state_ == CommandBuffer::State::DESCRIPTORS_BOUND ||

backends/vulkan/runtime/vk_api/Command.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class CommandBuffer final {
8989

9090
void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3);
9191
void bind_descriptors(VkDescriptorSet);
92+
void set_push_constants(VkPipelineLayout, const void*, uint32_t);
9293

9394
void insert_barrier(PipelineBarrier& pipeline_barrier);
9495
void dispatch(const utils::uvec3&);

0 commit comments

Comments
 (0)