Skip to content

Commit 70532b0

Browse files
authored
[ET-VK][ez] Enable dynamic shape support when using push constants (#11302)
## Changes * Call `encode_execute()` upon resize in `VulkanBackend.cpp` * Minor update to `DispatchNode` to store push constant data array as a persistent member of the class ## Motivation Passing in tensor metadata (i.e. sizes, strides) via push constants is typically more performant than passing it via a UBO (uniform buffer object). However, dynamic shapes currently do not work when push constants are used, because the tensor metadata contained in the push constants does not get updated. It appears that `vkCmdPushConstants` sets the push constants when encoding the command buffer; however, the push constants will not be updated if the command buffer is submitted for execution multiple times. Therefore, to update push constant values **the command buffer needs to be re-encoded**. ## Performance Impact This may add a small performance overhead (i.e. re-encoding the command buffer) when executing models with dynamic shapes. Models that do not trigger tensor resizing will not be impacted. However, I measured the impact on a llama 3.2 1B model and the impact of re-encoding a command buffer appears to be negligible. In any case, re-encoding the command buffer is a "necessary evil" when working with dynamic shapes; otherwise, the tensor metadata seen by shaders may never get updated. Furthermore, re-encoding the command buffer can allow an opportunity to adjust global work group sizing to match current tensor sizes, which may have a huge performance impact when maximum tensor sizes far exceed what tensor sizes will realistically be during inference (one instance of this is for transformer models when the max sequence length is very long). Differential Revision: [D75686051](https://our.internmc.facebook.com/intern/diff/D75686051/)
1 parent 498f249 commit 70532b0

File tree

7 files changed

+38
-17
lines changed

7 files changed

+38
-17
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
499499
compute_graph->encode_prepack();
500500
compute_graph->prepack();
501501

502+
// TODO(ssjia): remove this once we can batch compile compute pipelines
503+
// during prepare().
502504
compute_graph->encode_execute();
503505

504506
return Error::Ok;
@@ -567,9 +569,14 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
567569
}
568570
}
569571

572+
// propagate_resize() will re-encode the command buffer so that push
573+
// constants are updated and DynamicDispatchNode can update the compute
574+
// shader, global workgroup size, and local workgroup size to perform the
575+
// model inference.
570576
if (should_propagate_resize) {
571577
compute_graph->propagate_resize();
572578
}
579+
573580
compute_graph->execute();
574581

575582
for (size_t i = 0; i < compute_graph->outputs().size(); i++) {

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,11 +678,12 @@ void ComputeGraph::encode_execute() {
678678
}
679679
}
680680

681-
void ComputeGraph::execute() const {
681+
void ComputeGraph::execute() {
682682
vkapi::VulkanFence fence = context_->fences().get_fence();
683683
context_->submit_cmd_to_gpu(fence.get_submit_handle());
684684
fence.wait();
685685
context_->fences().return_fence(fence);
686+
execute_count_++;
686687
}
687688

688689
void ComputeGraph::resize_input(
@@ -696,6 +697,7 @@ void ComputeGraph::propagate_resize() {
696697
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
697698
node->trigger_resize(this);
698699
}
700+
encode_execute();
699701
}
700702

701703
} // namespace vkcompute

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ class ComputeGraph final {
187187

188188
protected:
189189
size_t values_in_use_ = 0;
190+
size_t execute_count_ = 0;
190191

191192
public:
192193
//
@@ -745,7 +746,7 @@ class ComputeGraph final {
745746
//
746747

747748
void encode_execute();
748-
void execute() const;
749+
void execute();
749750

750751
//
751752
// Dynamic Shape support
@@ -762,6 +763,10 @@ class ComputeGraph final {
762763
return context_->adapter_ptr()->supports_int16_shader_types();
763764
}
764765

766+
inline size_t execute_count() const {
767+
return execute_count_;
768+
}
769+
765770
/*
766771
* Check whether the GPU supports 8 bit buffers.
767772
*/

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,7 @@ void DispatchNode::encode(ComputeGraph* graph) {
4646

4747
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
4848

49-
std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
50-
uint32_t push_constants_offset = 0;
51-
52-
for (const auto& push_constant : push_constants_) {
53-
push_constants_offset += push_constant.write(
54-
push_constants_data.data(),
55-
push_constants_offset,
56-
kMaxPushConstantSize);
57-
}
49+
write_push_constant_data();
5850

5951
context->report_shader_dispatch_start(
6052
shader_.kernel_name,
@@ -63,7 +55,7 @@ void DispatchNode::encode(ComputeGraph* graph) {
6355
node_id_);
6456

6557
vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
66-
shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
58+
shader_, local_workgroup_size_, spec_vars_, push_constants_offset_);
6759

6860
uint32_t idx = 0;
6961
idx = bind_values_to_descriptor_set(
@@ -76,10 +68,20 @@ void DispatchNode::encode(ComputeGraph* graph) {
7668
pipeline_barrier,
7769
shader_,
7870
global_workgroup_size_,
79-
push_constants_data.data(),
80-
push_constants_offset);
71+
push_constants_data_.data(),
72+
push_constants_offset_);
8173

8274
context->report_shader_dispatch_end();
8375
}
8476

77+
void DispatchNode::write_push_constant_data() {
78+
push_constants_offset_ = 0;
79+
for (const auto& push_constant : push_constants_) {
80+
push_constants_offset_ += push_constant.write(
81+
push_constants_data_.data(),
82+
push_constants_offset_,
83+
kMaxPushConstantSize);
84+
}
85+
}
86+
8587
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ class DispatchNode : public ExecuteNode {
5050
const vkapi::SpecVarList spec_vars_;
5151
const std::vector<PushConstantDataInfo> push_constants_;
5252

53+
// For push constants
54+
std::array<uint8_t, kMaxPushConstantSize> push_constants_data_{};
55+
uint32_t push_constants_offset_ = 0;
56+
57+
void write_push_constant_data();
58+
5359
public:
5460
operator bool() const {
5561
return shader_;

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class ExecuteNode {
6565
(void)graph;
6666
}
6767

68-
inline void trigger_resize(ComputeGraph* graph) {
68+
virtual inline void trigger_resize(ComputeGraph* graph) {
6969
if (resize_fn_ != nullptr) {
7070
resize_fn_(graph, args_, resize_args_);
7171
}

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,9 +1660,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
16601660
for (auto& new_sizes : new_sizes_list) {
16611661
graph.get_tensor(a.value)->virtual_resize(new_sizes);
16621662
graph.get_tensor(b.value)->virtual_resize(new_sizes);
1663-
graph.get_tensor(c)->virtual_resize(new_sizes);
16641663
graph.get_tensor(d.value)->virtual_resize(new_sizes);
1665-
graph.get_tensor(e)->virtual_resize(new_sizes);
1664+
graph.propagate_resize();
16661665

16671666
float val_a = new_sizes[1] + 4.0f;
16681667
float val_b = new_sizes[2] + 1.5f;

0 commit comments

Comments
 (0)