Skip to content

Commit 2106ab9

Browse files
authored
[ET-VK] Migrate ops to use DynamicDispatchNode
Differential Revision: D75878398 Pull Request resolved: #11312
1 parent c81cc04 commit 2106ab9

20 files changed

+521
-264
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,15 @@ ValueRef ComputeGraph::add_symint(const int32_t val) {
449449
return idx;
450450
}
451451

452+
ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) {
453+
for (int i = 0; i < values_.size(); ++i) {
454+
if (values_.at(i).isInt() && values_.at(i).toInt() == val) {
455+
return i;
456+
}
457+
}
458+
return add_scalar(val);
459+
}
460+
452461
ValueRef ComputeGraph::set_input_tensor(
453462
const ValueRef idx,
454463
const bool use_staging) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,13 @@ class ComputeGraph final {
604604

605605
ValueRef add_symint(const int32_t val);
606606

607+
/*
608+
* Searches the graph's value list for a Int value with the specified value.
609+
* If one is found, returns the index of the value. Otherwise, add a new value
610+
* and return the index of the new value.
611+
*/
612+
ValueRef get_or_add_value_for_int(const int64_t val);
613+
607614
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
608615
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
609616

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode(
2525
const ResizeFunction& resize_fn)
2626
: DispatchNode(
2727
graph,
28-
vkapi::ShaderInfo(),
29-
{1u, 1u, 1u},
28+
pick_shader_fn(&graph, args, resize_args),
3029
{1u, 1u, 1u},
30+
{8u, 8u, 1u},
3131
args,
3232
params,
3333
push_constants,
@@ -37,7 +37,6 @@ DynamicDispatchNode::DynamicDispatchNode(
3737
pick_shader_fn_(pick_shader_fn),
3838
pick_global_wg_fn_(pick_global_wg_fn),
3939
pick_local_wg_fn_(pick_local_wg_fn) {
40-
shader_ = pick_shader_fn(&graph, args, resize_args);
4140
global_workgroup_size_ =
4241
pick_global_wg_fn(&graph, shader_, args, resize_args);
4342
local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn(

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1213

1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
@@ -30,8 +31,8 @@ void check_binary_op_args(
3031
void resize_binary_op_node(
3132
ComputeGraph* graph,
3233
const std::vector<ArgGroup>& args,
33-
const std::vector<ValueRef>& extra_args) {
34-
(void)extra_args;
34+
const std::vector<ValueRef>& resize_args) {
35+
(void)resize_args;
3536
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
3637

3738
// TODO(T183442143): Verify tensors are broadcastable.
@@ -78,11 +79,11 @@ void add_binary_op_texture_node(
7879
add_storage_type_suffix(kernel_name, *t_out);
7980
add_dtype_suffix(kernel_name, *t_out);
8081

81-
graph.execute_nodes().emplace_back(new DispatchNode(
82+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
8283
graph,
8384
VK_KERNEL_FROM_STR(kernel_name),
84-
graph.create_global_wg_size(out),
85-
graph.create_local_wg_size(out),
85+
default_pick_global_wg_size,
86+
default_pick_local_wg_size,
8687
// Inputs and Outputs
8788
{{out, vkapi::kWrite}, {{arg1, arg2}, vkapi::kRead}},
8889
// Shader params buffers
@@ -122,11 +123,11 @@ void add_binary_op_buffer_node(
122123
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
123124
add_dtype_suffix(kernel_name, graph.dtype_of(out));
124125

125-
graph.execute_nodes().emplace_back(new DispatchNode(
126+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
126127
graph,
127128
VK_KERNEL_FROM_STR(kernel_name),
128-
graph.create_global_wg_size(out),
129-
graph.create_local_wg_size(out),
129+
default_pick_global_wg_size,
130+
default_pick_local_wg_size,
130131
// Inputs and Outputs
131132
{{out, vkapi::kWrite}, {{in1, in2}, vkapi::kRead}},
132133
// Shader params buffers

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
1415

1516
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
@@ -21,8 +22,8 @@ namespace vkcompute {
2122
void resize_clone_node(
2223
ComputeGraph* graph,
2324
const std::vector<ArgGroup>& args,
24-
const std::vector<ValueRef>& extra_args) {
25-
(void)extra_args;
25+
const std::vector<ValueRef>& resize_args) {
26+
(void)resize_args;
2627
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
2728
vTensorPtr in = graph->get_tensor(args[1].refs[0]);
2829
// TODO: support for when dimensionality doesn't match, i.e. clone is used to
@@ -41,11 +42,11 @@ void add_clone_node(
4142
std::string kernel_name = "clone";
4243
add_dtype_suffix(kernel_name, *t_out);
4344

44-
graph.execute_nodes().emplace_back(new DispatchNode(
45+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
4546
graph,
4647
VK_KERNEL_FROM_STR(kernel_name),
47-
graph.create_global_wg_size(out),
48-
graph.create_local_wg_size(out),
48+
default_pick_global_wg_size,
49+
default_pick_local_wg_size,
4950
// Inputs and Outputs
5051
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
5152
// Parameter Buffers
@@ -60,6 +61,17 @@ void add_clone_node(
6061
resize_clone_node));
6162
}
6263

64+
utils::uvec3 clone_image_to_buffer_global_wg_size(
65+
ComputeGraph* graph,
66+
const vkapi::ShaderInfo& shader,
67+
const std::vector<ArgGroup>& args,
68+
const std::vector<ValueRef>& resize_args) {
69+
(void)shader;
70+
(void)resize_args;
71+
const ValueRef image = args.at(1).refs.at(0);
72+
return graph->create_global_wg_size(image);
73+
}
74+
6375
void add_image_to_buffer_node(
6476
ComputeGraph& graph,
6577
const ValueRef image,
@@ -68,12 +80,11 @@ void add_image_to_buffer_node(
6880
add_dtype_suffix(kernel_name, graph.dtype_of(image));
6981
vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
7082

71-
utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
72-
graph.execute_nodes().emplace_back(new DispatchNode(
83+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
7384
graph,
7485
shader,
75-
global_wg_size,
76-
graph.create_local_wg_size(global_wg_size),
86+
clone_image_to_buffer_global_wg_size,
87+
default_pick_local_wg_size,
7788
// Input and Outputs
7889
{{buffer, vkapi::kWrite}, {image, vkapi::kRead}},
7990
// Parameter Buffers
@@ -96,12 +107,11 @@ void add_buffer_to_image_node(
96107
add_dtype_suffix(kernel_name, graph.dtype_of(image));
97108
vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
98109

99-
utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
100-
graph.execute_nodes().emplace_back(new DispatchNode(
110+
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
101111
graph,
102112
shader,
103-
global_wg_size,
104-
graph.create_local_wg_size(global_wg_size),
113+
default_pick_global_wg_size,
114+
default_pick_local_wg_size,
105115
// Input and Outputs
106116
{{image, vkapi::kWrite}, {buffer, vkapi::kRead}},
107117
// Parameter Buffers

backends/vulkan/runtime/graph/ops/impl/Common.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ utils::uvec3 default_pick_global_wg_size(
1414
ComputeGraph* graph,
1515
const vkapi::ShaderInfo& shader,
1616
const std::vector<ArgGroup>& args,
17-
const std::vector<ValueRef>& additional_args) {
17+
const std::vector<ValueRef>& resize_args) {
1818
(void)shader;
19+
(void)resize_args;
1920
const ValueRef out = args.at(0).refs.at(0);
2021
return graph->create_global_wg_size(out);
2122
}
@@ -25,8 +26,10 @@ utils::uvec3 default_pick_local_wg_size(
2526
const vkapi::ShaderInfo& shader,
2627
const utils::uvec3& global_workgroup_size,
2728
const std::vector<ArgGroup>& args,
28-
const std::vector<ValueRef>& additional_args) {
29+
const std::vector<ValueRef>& resize_args) {
2930
(void)shader;
31+
(void)args;
32+
(void)resize_args;
3033
return graph->create_local_wg_size(global_workgroup_size);
3134
}
3235

backends/vulkan/runtime/graph/ops/impl/Common.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,23 @@ namespace vkcompute {
1717
* Creates a global workgroup size based on the first output tensor in the args.
1818
* This is a utility function that extracts the output tensor from
1919
* args.at(0).refs.at(0) and calls graph->create_global_wg_size(out) on it.
20-
*
21-
* @param graph The ComputeGraph instance
22-
* @param args Vector of ArgGroup containing the output tensor reference
23-
* @return utils::uvec3 The global workgroup size
2420
*/
2521
utils::uvec3 default_pick_global_wg_size(
2622
ComputeGraph* graph,
2723
const vkapi::ShaderInfo& shader,
2824
const std::vector<ArgGroup>& args,
29-
const std::vector<ValueRef>& additional_args);
25+
const std::vector<ValueRef>& resize_args);
3026

3127
/**
3228
* Creates a local workgroup size based on the first output tensor in the args.
3329
* This is a utility function that extracts the output tensor from
3430
* args.at(0).refs.at(0) and calls graph->create_local_wg_size(out) on it.
35-
*
36-
* @param graph The ComputeGraph instance
37-
* @param args Vector of ArgGroup containing the output tensor reference
38-
* @return utils::uvec3 The local workgroup size
3931
*/
4032
utils::uvec3 default_pick_local_wg_size(
4133
ComputeGraph* graph,
4234
const vkapi::ShaderInfo& shader,
4335
const utils::uvec3& global_workgroup_size,
4436
const std::vector<ArgGroup>& args,
45-
const std::vector<ValueRef>& additional_args);
37+
const std::vector<ValueRef>& resize_args);
4638

4739
} // namespace vkcompute

0 commit comments

Comments
 (0)