Skip to content

Commit df5b2ab

Browse files
Abhi-hppfacebook-github-bot
authored andcommitted
Implement Graph node (#6037)
Summary: Pull Request resolved: #6037 Introduce Graph node which will be the parent class for nodes. This allows us to have different nodes that correspond to different vk command buffer functions including shader dispatch. Reviewed By: SS-JIA Differential Revision: D64080291 fbshipit-source-id: b49540b99c8901f18645c86d3f5a9d274e5191b5
1 parent d6aea3d commit df5b2ab

38 files changed

+193
-139
lines changed

backends/vulkan/runtime/graph/Logging.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ void ComputeGraph::print_readable() {
156156
size_t node_idx = 0;
157157
for (const std::unique_ptr<ExecuteNode>& node : execute_nodes()) {
158158
std::cout << std::setw(6) << node_idx;
159-
std::cout << std::setw(32) << node->shader_.kernel_name;
159+
std::cout << std::setw(32) << node->name();
160160

161161
std::stringstream read_s;
162162
for (const ArgGroup& arg_group : node->args_) {
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>
14+
15+
namespace vkcompute {
16+
17+
DispatchNode::DispatchNode(
18+
ComputeGraph& graph,
19+
const vkapi::ShaderInfo& shader,
20+
const utils::uvec3& global_workgroup_size,
21+
const utils::uvec3& local_workgroup_size,
22+
const std::vector<ArgGroup>& args,
23+
const vkapi::ParamsBindList& params,
24+
const vkapi::SpecVarList& spec_vars,
25+
const ResizeFunction& resize_fn,
26+
const std::vector<ValueRef>& resize_args)
27+
: ExecuteNode(resize_fn, resize_args, args, shader.kernel_name),
28+
shader_(shader),
29+
global_workgroup_size_(global_workgroup_size),
30+
local_workgroup_size_(local_workgroup_size),
31+
params_(params),
32+
spec_vars_(spec_vars) {
33+
graph.update_descriptor_counts(shader, /*execute = */ true);
34+
}
35+
36+
void DispatchNode::encode(ComputeGraph* graph) {
37+
if (!shader_) {
38+
return;
39+
}
40+
api::Context* const context = graph->context();
41+
vkapi::PipelineBarrier pipeline_barrier{};
42+
43+
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
44+
45+
context->report_shader_dispatch_start(
46+
shader_.kernel_name,
47+
global_workgroup_size_,
48+
local_workgroup_size_,
49+
node_id_);
50+
51+
vkapi::DescriptorSet descriptor_set =
52+
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
53+
54+
uint32_t idx = 0;
55+
idx = bind_values_to_descriptor_set(
56+
graph, args_, pipeline_barrier, descriptor_set, idx);
57+
58+
bind_params_to_descriptor_set(params_, descriptor_set, idx);
59+
60+
context->register_shader_dispatch(
61+
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
62+
63+
context->report_shader_dispatch_end();
64+
}
65+
66+
} // namespace vkcompute
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
14+
15+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
16+
17+
namespace vkcompute {
18+
19+
class ComputeGraph;
20+
21+
/*
22+
* Represents a single shader execution op in a ML model.
23+
*/
24+
class DispatchNode final : public ExecuteNode {
25+
friend class ComputeGraph;
26+
27+
public:
28+
explicit DispatchNode(
29+
ComputeGraph& graph,
30+
const vkapi::ShaderInfo& shader,
31+
const utils::uvec3& global_workgroup_size,
32+
const utils::uvec3& local_workgroup_size,
33+
const std::vector<ArgGroup>& args,
34+
const vkapi::ParamsBindList& params,
35+
const vkapi::SpecVarList& spec_vars = {},
36+
const ResizeFunction& resize_fn = nullptr,
37+
const std::vector<ValueRef>& resize_args = {});
38+
39+
~DispatchNode() override = default;
40+
41+
void encode(ComputeGraph* graph) override;
42+
43+
protected:
44+
const vkapi::ShaderInfo shader_;
45+
const utils::uvec3 global_workgroup_size_;
46+
const utils::uvec3 local_workgroup_size_;
47+
const vkapi::ParamsBindList params_;
48+
const vkapi::SpecVarList spec_vars_;
49+
50+
public:
51+
operator bool() const {
52+
return shader_;
53+
}
54+
};
55+
56+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 6 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -8,73 +8,14 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1010

11-
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12-
13-
#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>
14-
1511
namespace vkcompute {
16-
1712
ExecuteNode::ExecuteNode(
18-
ComputeGraph& graph,
19-
const vkapi::ShaderInfo& shader,
20-
const utils::uvec3& global_workgroup_size,
21-
const utils::uvec3& local_workgroup_size,
22-
const std::vector<ArgGroup>& args,
23-
const vkapi::ParamsBindList& params,
24-
const vkapi::SpecVarList& spec_vars,
2513
const ResizeFunction& resize_fn,
26-
const std::vector<ValueRef>& resize_args)
27-
: shader_(shader),
28-
global_workgroup_size_(global_workgroup_size),
29-
local_workgroup_size_(local_workgroup_size),
14+
const std::vector<ValueRef>& resize_args,
15+
const std::vector<ArgGroup>& args,
16+
const std::string& name)
17+
: resize_fn_(resize_fn),
18+
resize_args_(resize_args),
3019
args_(args),
31-
params_(params),
32-
spec_vars_(spec_vars),
33-
resize_fn_(resize_fn),
34-
resize_args_(resize_args) {
35-
graph.update_descriptor_counts(shader, /*execute = */ true);
36-
}
37-
38-
ExecuteNode::ExecuteNode(
39-
const ResizeFunction& resize_fn,
40-
const std::vector<ValueRef>& resize_args)
41-
: shader_(),
42-
global_workgroup_size_({0u, 0u, 0u}),
43-
local_workgroup_size_({0u, 0u, 0u}),
44-
args_(),
45-
params_(),
46-
spec_vars_(),
47-
resize_fn_(resize_fn),
48-
resize_args_(resize_args) {}
49-
50-
void ExecuteNode::encode(ComputeGraph* graph) {
51-
if (!shader_) {
52-
return;
53-
}
54-
api::Context* const context = graph->context();
55-
vkapi::PipelineBarrier pipeline_barrier{};
56-
57-
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
58-
59-
context->report_shader_dispatch_start(
60-
shader_.kernel_name,
61-
global_workgroup_size_,
62-
local_workgroup_size_,
63-
node_id_);
64-
65-
vkapi::DescriptorSet descriptor_set =
66-
context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
67-
68-
uint32_t idx = 0;
69-
idx = bind_values_to_descriptor_set(
70-
graph, args_, pipeline_barrier, descriptor_set, idx);
71-
72-
bind_params_to_descriptor_set(params_, descriptor_set, idx);
73-
74-
context->register_shader_dispatch(
75-
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
76-
77-
context->report_shader_dispatch_end();
78-
}
79-
20+
name_(name) {}
8021
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct ArgGroup {
3939
* encoding of the shader corresponding to the op into the command buffer of a
4040
* ComputeGraph.
4141
*/
42-
class ExecuteNode final {
42+
class ExecuteNode {
4343
friend class ComputeGraph;
4444

4545
public:
@@ -48,29 +48,22 @@ class ExecuteNode final {
4848
const std::vector<ArgGroup>&,
4949
const std::vector<ValueRef>&)>;
5050

51-
explicit ExecuteNode(
52-
ComputeGraph& graph,
53-
const vkapi::ShaderInfo& shader,
54-
const utils::uvec3& global_workgroup_size,
55-
const utils::uvec3& local_workgroup_size,
56-
const std::vector<ArgGroup>& args,
57-
const vkapi::ParamsBindList& params,
58-
const vkapi::SpecVarList& spec_vars = {},
59-
const ResizeFunction& resize_fn = nullptr,
60-
const std::vector<ValueRef>& resize_args = {});
61-
6251
/*
63-
* This overload of the ExecuteNode constructor is used to register ops which
52+
* This overload of the DispatchNode constructor is used to register ops which
6453
* update a tensor view. No shader is dispatched, but the node still needs to
6554
* update the view's sizes and strides after a resize.
6655
*/
6756
explicit ExecuteNode(
6857
const ResizeFunction& resize_fn = nullptr,
69-
const std::vector<ValueRef>& resize_args = {});
58+
const std::vector<ValueRef>& resize_args = {},
59+
const std::vector<ArgGroup>& args = {},
60+
const std::string& name = "Graph Node");
7061

71-
~ExecuteNode() = default;
62+
virtual ~ExecuteNode() = default;
7263

73-
void encode(ComputeGraph* graph);
64+
virtual void encode(ComputeGraph* graph) {
65+
(void)graph;
66+
}
7467

7568
inline void trigger_resize(ComputeGraph* graph) {
7669
if (resize_fn_ != nullptr) {
@@ -82,21 +75,16 @@ class ExecuteNode final {
8275
node_id_ = node_id;
8376
}
8477

78+
inline const std::string& name() const {
79+
return name_;
80+
}
81+
8582
protected:
8683
uint32_t node_id_;
87-
const vkapi::ShaderInfo shader_;
88-
const utils::uvec3 global_workgroup_size_;
89-
const utils::uvec3 local_workgroup_size_;
90-
const std::vector<ArgGroup> args_;
91-
const vkapi::ParamsBindList params_;
92-
const vkapi::SpecVarList spec_vars_;
9384
const ResizeFunction resize_fn_;
9485
const std::vector<ValueRef> resize_args_;
95-
96-
public:
97-
operator bool() const {
98-
return shader_;
99-
}
86+
const std::vector<ArgGroup> args_;
87+
const std::string name_;
10088
};
10189

10290
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/OperatorRegistry.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
1213

1314
#include <functional>
1415
#include <unordered_map>

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
9+
#include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
1212

backends/vulkan/runtime/graph/ops/impl/Arange.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ void add_arange_node(
8888
kernel_name.reserve(kShaderNameReserve);
8989
add_dtype_suffix(kernel_name, *t_out);
9090

91-
graph.execute_nodes().emplace_back(new ExecuteNode(
91+
graph.execute_nodes().emplace_back(new DispatchNode(
9292
graph,
9393
VK_KERNEL_FROM_STR(kernel_name),
9494
graph.create_global_wg_size(out),

backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void add_native_batch_norm_node(
8080
int32_t num_texel_per_batch =
8181
utils::div_up_4((dim_at<kChannel4D>(t_in->sizes())));
8282

83-
graph.execute_nodes().emplace_back(new ExecuteNode(
83+
graph.execute_nodes().emplace_back(new DispatchNode(
8484
graph,
8585
VK_KERNEL_FROM_STR(kernel_name),
8686
graph.create_global_wg_size(out_ref),

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ void add_binary_op_node(
7575
kernel_name += op_name;
7676
add_dtype_suffix(kernel_name, *t_out);
7777

78-
graph.execute_nodes().emplace_back(new ExecuteNode(
78+
graph.execute_nodes().emplace_back(new DispatchNode(
7979
graph,
8080
VK_KERNEL_FROM_STR(kernel_name),
8181
graph.create_global_wg_size(out),

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ void add_clone_node(
2525
std::string kernel_name = "clone";
2626
add_dtype_suffix(kernel_name, *t_out);
2727

28-
graph.execute_nodes().emplace_back(new ExecuteNode(
28+
graph.execute_nodes().emplace_back(new DispatchNode(
2929
graph,
3030
VK_KERNEL_FROM_STR(kernel_name),
3131
graph.create_global_wg_size(out),

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ void add_conv2d_node(
366366
vkapi::ShaderInfo shader = get_conv2d_shader(
367367
graph, *t_out, /*prepack_weights = */ false, method, weight, clamp_out);
368368

369-
graph.execute_nodes().emplace_back(new ExecuteNode(
369+
graph.execute_nodes().emplace_back(new DispatchNode(
370370
graph,
371371
shader,
372372
create_conv2d_global_wg_size(graph, method, out),
@@ -464,7 +464,7 @@ void add_conv1d_node(
464464

465465
add_dtype_suffix(kernel_name, *t_out);
466466

467-
graph.execute_nodes().emplace_back(new ExecuteNode(
467+
graph.execute_nodes().emplace_back(new DispatchNode(
468468
graph,
469469
VK_KERNEL_FROM_STR(kernel_name),
470470
global_size,

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void add_copy_offset_node(
4545

4646
auto shader = VK_KERNEL_FROM_STR(kernel_name);
4747

48-
graph.execute_nodes().emplace_back(new ExecuteNode(
48+
graph.execute_nodes().emplace_back(new DispatchNode(
4949
graph,
5050
VK_KERNEL_FROM_STR(kernel_name),
5151
graph.create_global_wg_size(out),
@@ -155,7 +155,7 @@ void add_copy_channel_offset_node(
155155

156156
auto shader = VK_KERNEL_FROM_STR(kernel_name);
157157

158-
graph.execute_nodes().emplace_back(new ExecuteNode(
158+
graph.execute_nodes().emplace_back(new DispatchNode(
159159
graph,
160160
VK_KERNEL_FROM_STR(kernel_name),
161161
global_size,

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void add_embedding_node(
4141
kernel_name.reserve(kShaderNameReserve);
4242
add_dtype_suffix(kernel_name, *t_out);
4343

44-
graph.execute_nodes().emplace_back(new ExecuteNode(
44+
graph.execute_nodes().emplace_back(new DispatchNode(
4545
graph,
4646
VK_KERNEL_FROM_STR(kernel_name),
4747
graph.create_global_wg_size(out),

backends/vulkan/runtime/graph/ops/impl/Flip.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ void add_flip_node(
5858
kernel_name.reserve(kShaderNameReserve);
5959
add_dtype_suffix(kernel_name, *t_out);
6060

61-
graph.execute_nodes().emplace_back(new ExecuteNode(
61+
graph.execute_nodes().emplace_back(new DispatchNode(
6262
graph,
6363
VK_KERNEL_FROM_STR(kernel_name),
6464
graph.create_global_wg_size(out),

0 commit comments

Comments
 (0)