Skip to content

[ET-VK] Introduce generalized shaders for transfer ops and use it for select and slice #11304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
compute_graph->encode_prepack();
compute_graph->prepack();

// TODO(ssjia): remove this once we can batch compile compute pipelines
// during prepare().
compute_graph->encode_execute();

return Error::Ok;
Expand Down Expand Up @@ -567,9 +569,14 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
}
}

// propagate_resize() will re-encode the command buffer so that push
// constants are updated and DynamicDispatchNode can update the compute
// shader, global workgroup size, and local workgroup size to perform the
// model inference.
if (should_propagate_resize) {
compute_graph->propagate_resize();
}

compute_graph->execute();

for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
Expand Down
22 changes: 20 additions & 2 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,14 +492,24 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
const ValueRef idx) {
if (values_.at(idx).isInt()) {
const int32_t val = extract_scalar<int32_t>(idx);
create_params_buffer(val);
return create_params_buffer(val);
} else if (values_.at(idx).isSymInt()) {
SymIntPtr symint = get_symint(idx);
return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
}
VK_THROW("Cannot create a int param buffer for the given value");
}

// Overload of get_or_create_int_param_buffer() that tolerates a None value.
// When the value at `idx` is None, the result of
// create_params_buffer(default_val) is returned; for any other value the
// single-argument overload is used.
vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
    const ValueRef idx,
    const int32_t default_val) {
  const bool value_is_none = values_.at(idx).isNone();
  if (value_is_none) {
    return create_params_buffer(default_val);
  }
  return get_or_create_int_param_buffer(idx);
}

void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
get_symint(idx)->set(val);
}
Expand Down Expand Up @@ -678,11 +688,12 @@ void ComputeGraph::encode_execute() {
}
}

void ComputeGraph::execute() const {
void ComputeGraph::execute() {
vkapi::VulkanFence fence = context_->fences().get_fence();
context_->submit_cmd_to_gpu(fence.get_submit_handle());
fence.wait();
context_->fences().return_fence(fence);
execute_count_++;
}

void ComputeGraph::resize_input(
Expand All @@ -692,10 +703,17 @@ void ComputeGraph::resize_input(
get_tensor(io_val.value)->virtual_resize(new_sizes);
}

// Calls virtual_resize(new_sizes) on the tensor that get_tensor(idx) returns.
// `idx` is a ValueRef passed to get_tensor() as-is.
void ComputeGraph::virtual_resize(
const ValueRef idx,
const std::vector<int64_t>& new_sizes) {
get_tensor(idx)->virtual_resize(new_sizes);
}

void ComputeGraph::propagate_resize() {
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
node->trigger_resize(this);
}
encode_execute();
}

} // namespace vkcompute
27 changes: 26 additions & 1 deletion backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class ComputeGraph final {

protected:
size_t values_in_use_ = 0;
size_t execute_count_ = 0;

public:
//
Expand Down Expand Up @@ -397,6 +398,19 @@ class ComputeGraph final {
std::optional<T> extract_optional_scalar(const ValueRef idx) {
if (val_is_none(idx)) {
return ::std::nullopt;
} else if (val_is_symint(idx)) {
return utils::safe_downcast<T>(read_symint(idx));
} else {
return extract_scalar<T>(idx);
}
}

template <typename T>
T extract_optional_scalar(const ValueRef idx, const T default_val) {
if (val_is_none(idx)) {
return default_val;
} else if (val_is_symint(idx)) {
return utils::safe_downcast<T>(read_symint(idx));
} else {
return extract_scalar<T>(idx);
}
Expand Down Expand Up @@ -608,6 +622,10 @@ class ComputeGraph final {
*/
vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx);

/*
 * Overload of get_or_create_int_param_buffer() for values that may be None.
 * If the value at `idx` is None, the returned buffer is created from
 * `default_value`; otherwise this behaves like the single-argument overload.
 */
vkapi::BufferBindInfo get_or_create_int_param_buffer(
const ValueRef idx,
const int32_t default_value);

void set_symint(const ValueRef idx, const int32_t val);

int32_t read_symint(const ValueRef idx);
Expand Down Expand Up @@ -745,13 +763,16 @@ class ComputeGraph final {
//

void encode_execute();
void execute() const;
void execute();

//
// Dynamic Shape support
//

void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
// Calls virtual_resize(new_sizes) on the tensor referenced by `idx`.
void virtual_resize(
const ValueRef idx,
const std::vector<int64_t>& new_sizes);
void propagate_resize();

//
Expand All @@ -762,6 +783,10 @@ class ComputeGraph final {
return context_->adapter_ptr()->supports_int16_shader_types();
}

/*
 * Returns the current value of execute_count_, the counter that execute()
 * increments once per call.
 */
inline size_t execute_count() const {
return execute_count_;
}

/*
* Check whether the GPU supports 8 bit buffers.
*/
Expand Down
26 changes: 14 additions & 12 deletions backends/vulkan/runtime/graph/ops/DispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,7 @@ void DispatchNode::encode(ComputeGraph* graph) {

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
uint32_t push_constants_offset = 0;

for (const auto& push_constant : push_constants_) {
push_constants_offset += push_constant.write(
push_constants_data.data(),
push_constants_offset,
kMaxPushConstantSize);
}
write_push_constant_data();

context->report_shader_dispatch_start(
shader_.kernel_name,
Expand All @@ -63,7 +55,7 @@ void DispatchNode::encode(ComputeGraph* graph) {
node_id_);

vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
shader_, local_workgroup_size_, spec_vars_, push_constants_offset_);

uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
Expand All @@ -76,10 +68,20 @@ void DispatchNode::encode(ComputeGraph* graph) {
pipeline_barrier,
shader_,
global_workgroup_size_,
push_constants_data.data(),
push_constants_offset);
push_constants_data_.data(),
push_constants_offset_);

context->report_shader_dispatch_end();
}

void DispatchNode::write_push_constant_data() {
push_constants_offset_ = 0;
for (const auto& push_constant : push_constants_) {
push_constants_offset_ += push_constant.write(
push_constants_data_.data(),
push_constants_offset_,
kMaxPushConstantSize);
}
}

} // namespace vkcompute
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ops/DispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class DispatchNode : public ExecuteNode {
const vkapi::SpecVarList spec_vars_;
const std::vector<PushConstantDataInfo> push_constants_;

// For push constants
std::array<uint8_t, kMaxPushConstantSize> push_constants_data_{};
uint32_t push_constants_offset_ = 0;

void write_push_constant_data();

public:
operator bool() const {
return shader_;
Expand Down
58 changes: 51 additions & 7 deletions backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode(
const ResizeFunction& resize_fn)
: DispatchNode(
graph,
pick_shader_fn(&graph, args, resize_args),
pick_global_wg_fn(&graph, args, resize_args),
pick_local_wg_fn(&graph, args, resize_args),
vkapi::ShaderInfo(),
{1u, 1u, 1u},
{1u, 1u, 1u},
args,
params,
push_constants,
Expand All @@ -36,13 +36,57 @@ DynamicDispatchNode::DynamicDispatchNode(
resize_fn),
pick_shader_fn_(pick_shader_fn),
pick_global_wg_fn_(pick_global_wg_fn),
pick_local_wg_fn_(pick_local_wg_fn) {
shader_ = pick_shader_fn(&graph, args, resize_args);
global_workgroup_size_ =
pick_global_wg_fn(&graph, shader_, args, resize_args);
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
&graph, shader_, global_workgroup_size_, args, resize_args));
}

// Constructs a dynamic dispatch node for a shader that is supplied directly
// instead of being chosen by a shader-picking function. The initial global
// workgroup size passed to DispatchNode is pick_global_wg_fn(...), and the
// initial local workgroup size is pick_local_wg_fn(...) evaluated with that
// global size. pick_shader_fn_ is initialized to nullptr; the two workgroup
// pickers are stored in pick_global_wg_fn_ / pick_local_wg_fn_.
//
// NOTE(review): pick_global_wg_fn is invoked twice below (once for the global
// size argument, once as input to pick_local_wg_fn). This presumably relies
// on the picker being deterministic and inexpensive -- confirm.
DynamicDispatchNode::DynamicDispatchNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
const PickGlobalFn& pick_global_wg_fn,
const PickLocalFn& pick_local_wg_fn,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const std::vector<PushConstantDataInfo>& push_constants,
const vkapi::SpecVarList& spec_vars,
const std::vector<ValueRef>& resize_args,
const ResizeFunction& resize_fn)
: DispatchNode(
graph,
shader,
pick_global_wg_fn(&graph, shader, args, resize_args),
pick_local_wg_fn(
&graph,
shader,
pick_global_wg_fn(&graph, shader, args, resize_args),
args,
resize_args),
args,
params,
push_constants,
spec_vars,
resize_args,
resize_fn),
pick_shader_fn_{nullptr},
pick_global_wg_fn_(pick_global_wg_fn),
pick_local_wg_fn_(pick_local_wg_fn) {}

void DynamicDispatchNode::encode(ComputeGraph* graph) {
shader_ = pick_shader_fn_(graph, args_, resize_args_);
global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_);
local_workgroup_size_ =
utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_));
if (pick_shader_fn_) {
shader_ = pick_shader_fn_(graph, args_, resize_args_);
}
if (pick_global_wg_fn_) {
global_workgroup_size_ =
pick_global_wg_fn_(graph, shader_, args_, resize_args_);
}
if (pick_local_wg_fn_) {
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_(
graph, shader_, global_workgroup_size_, args_, resize_args_));
}
DispatchNode::encode(graph);
}

Expand Down
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ class DynamicDispatchNode final : public DispatchNode {
const std::vector<ValueRef>&)>;
using PickGlobalFn = const std::function<utils::uvec3(
ComputeGraph*,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;
using PickLocalFn = const std::function<utils::uvec3(
ComputeGraph*,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;

Expand All @@ -51,6 +54,18 @@ class DynamicDispatchNode final : public DispatchNode {
const std::vector<ValueRef>& resize_args,
const ResizeFunction& resize_fn = nullptr);

explicit DynamicDispatchNode(
ComputeGraph& graph,
const vkapi::ShaderInfo& shader,
const PickGlobalFn& pick_global_wg_fn,
const PickLocalFn& pick_local_wg_fn,
const std::vector<ArgGroup>& args,
const vkapi::ParamsBindList& params,
const std::vector<PushConstantDataInfo>& push_constants,
const vkapi::SpecVarList& spec_vars,
const std::vector<ValueRef>& resize_args,
const ResizeFunction& resize_fn = nullptr);

~DynamicDispatchNode() override = default;

void encode(ComputeGraph* graph) override;
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ExecuteNode {
(void)graph;
}

inline void trigger_resize(ComputeGraph* graph) {
virtual inline void trigger_resize(ComputeGraph* graph) {
if (resize_fn_ != nullptr) {
resize_fn_(graph, args_, resize_args_);
}
Expand Down
Loading
Loading