Skip to content

[ET-VK] Migrate ops to use DynamicDispatchNode #11312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,15 @@ ValueRef ComputeGraph::add_symint(const int32_t val) {
return idx;
}

ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) {
for (int i = 0; i < values_.size(); ++i) {
if (values_.at(i).isInt() && values_.at(i).toInt() == val) {
return i;
}
}
return add_scalar(val);
}

ValueRef ComputeGraph::set_input_tensor(
const ValueRef idx,
const bool use_staging) {
Expand Down
7 changes: 7 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,13 @@ class ComputeGraph final {

ValueRef add_symint(const int32_t val);

/*
* Searches the graph's value list for a Int value with the specified value.
* If one is found, returns the index of the value. Otherwise, add a new value
* and return the index of the new value.
*/
ValueRef get_or_add_value_for_int(const int64_t val);

ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);

Expand Down
5 changes: 2 additions & 3 deletions backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode(
const ResizeFunction& resize_fn)
: DispatchNode(
graph,
vkapi::ShaderInfo(),
{1u, 1u, 1u},
pick_shader_fn(&graph, args, resize_args),
{1u, 1u, 1u},
{8u, 8u, 1u},
args,
params,
push_constants,
Expand All @@ -37,7 +37,6 @@ DynamicDispatchNode::DynamicDispatchNode(
pick_shader_fn_(pick_shader_fn),
pick_global_wg_fn_(pick_global_wg_fn),
pick_local_wg_fn_(pick_local_wg_fn) {
shader_ = pick_shader_fn(&graph, args, resize_args);
global_workgroup_size_ =
pick_global_wg_fn(&graph, shader_, args, resize_args);
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(
Expand Down
17 changes: 9 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
Expand All @@ -30,8 +31,8 @@ void check_binary_op_args(
void resize_binary_op_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
(void)extra_args;
const std::vector<ValueRef>& resize_args) {
(void)resize_args;
vTensorPtr out = graph->get_tensor(args[0].refs[0]);

// TODO(T183442143): Verify tensors are broadcastable.
Expand Down Expand Up @@ -78,11 +79,11 @@ void add_binary_op_texture_node(
add_storage_type_suffix(kernel_name, *t_out);
add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{arg1, arg2}, vkapi::kRead}},
// Shader params buffers
Expand Down Expand Up @@ -122,11 +123,11 @@ void add_binary_op_buffer_node(
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {{in1, in2}, vkapi::kRead}},
// Shader params buffers
Expand Down
36 changes: 23 additions & 13 deletions backends/vulkan/runtime/graph/ops/impl/Clone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <executorch/backends/vulkan/runtime/graph/Logging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
Expand All @@ -21,8 +22,8 @@ namespace vkcompute {
void resize_clone_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
(void)extra_args;
const std::vector<ValueRef>& resize_args) {
(void)resize_args;
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
vTensorPtr in = graph->get_tensor(args[1].refs[0]);
// TODO: support for when dimensionality doesn't match, i.e. clone is used to
Expand All @@ -41,11 +42,11 @@ void add_clone_node(
std::string kernel_name = "clone";
add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
// Parameter Buffers
Expand All @@ -60,6 +61,17 @@ void add_clone_node(
resize_clone_node));
}

utils::uvec3 clone_image_to_buffer_global_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)shader;
(void)resize_args;
const ValueRef image = args.at(1).refs.at(0);
return graph->create_global_wg_size(image);
}

void add_image_to_buffer_node(
ComputeGraph& graph,
const ValueRef image,
Expand All @@ -68,12 +80,11 @@ void add_image_to_buffer_node(
add_dtype_suffix(kernel_name, graph.dtype_of(image));
vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);

utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
shader,
global_wg_size,
graph.create_local_wg_size(global_wg_size),
clone_image_to_buffer_global_wg_size,
default_pick_local_wg_size,
// Input and Outputs
{{buffer, vkapi::kWrite}, {image, vkapi::kRead}},
// Parameter Buffers
Expand All @@ -96,12 +107,11 @@ void add_buffer_to_image_node(
add_dtype_suffix(kernel_name, graph.dtype_of(image));
vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);

utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
graph.execute_nodes().emplace_back(new DispatchNode(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
shader,
global_wg_size,
graph.create_local_wg_size(global_wg_size),
default_pick_global_wg_size,
default_pick_local_wg_size,
// Input and Outputs
{{image, vkapi::kWrite}, {buffer, vkapi::kRead}},
// Parameter Buffers
Expand Down
7 changes: 5 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ utils::uvec3 default_pick_global_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args) {
const std::vector<ValueRef>& resize_args) {
(void)shader;
(void)resize_args;
const ValueRef out = args.at(0).refs.at(0);
return graph->create_global_wg_size(out);
}
Expand All @@ -25,8 +26,10 @@ utils::uvec3 default_pick_local_wg_size(
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args) {
const std::vector<ValueRef>& resize_args) {
(void)shader;
(void)args;
(void)resize_args;
return graph->create_local_wg_size(global_workgroup_size);
}

Expand Down
12 changes: 2 additions & 10 deletions backends/vulkan/runtime/graph/ops/impl/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,23 @@ namespace vkcompute {
* Creates a global workgroup size based on the first output tensor in the args.
* This is a utility function that extracts the output tensor from
* args.at(0).refs.at(0) and calls graph->create_global_wg_size(out) on it.
*
* @param graph The ComputeGraph instance
* @param args Vector of ArgGroup containing the output tensor reference
* @return utils::uvec3 The global workgroup size
*/
utils::uvec3 default_pick_global_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args);
const std::vector<ValueRef>& resize_args);

/**
* Creates a local workgroup size based on the first output tensor in the args.
* This is a utility function that extracts the output tensor from
* args.at(0).refs.at(0) and calls graph->create_local_wg_size(out) on it.
*
* @param graph The ComputeGraph instance
* @param args Vector of ArgGroup containing the output tensor reference
* @return utils::uvec3 The local workgroup size
*/
utils::uvec3 default_pick_local_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& additional_args);
const std::vector<ValueRef>& resize_args);

} // namespace vkcompute
Loading
Loading