Skip to content

Commit c6d80e1

Browse files
Abhi-hppfacebook-github-bot
authored andcommitted
Implement Graph node (#6037)
Summary: Introduce Graph node which will be the parent class for nodes. This allows us to have different nodes that correspond to different vk command buffer functions including shader dispatch. Differential Revision: D64080291
1 parent cb3a546 commit c6d80e1

File tree

13 files changed

+138
-85
lines changed

13 files changed

+138
-85
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ void ComputeGraph::encode_execute() {
573573
shared_object.bind_users(this);
574574
}
575575

576-
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
576+
for (std::unique_ptr<GraphNode>& node : execute_nodes_) {
577577
node->encode(this);
578578
}
579579
}
@@ -592,7 +592,7 @@ void ComputeGraph::resize_input(
592592
}
593593

594594
void ComputeGraph::propagate_resize() {
595-
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
595+
for (std::unique_ptr<GraphNode>& node : execute_nodes_) {
596596
node->trigger_resize(this);
597597
}
598598
}

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include <executorch/backends/vulkan/runtime/graph/containers/SharedObject.h>
2121
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
2222

23-
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
23+
#include <executorch/backends/vulkan/runtime/graph/ops/GraphNode.h>
2424
#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>
2525

2626
namespace vkcompute {
@@ -178,7 +178,7 @@ class ComputeGraph final {
178178
std::vector<api::ParamsBuffer> param_ubos_;
179179

180180
std::vector<std::unique_ptr<PrepackNode>> prepack_nodes_;
181-
std::vector<std::unique_ptr<ExecuteNode>> execute_nodes_;
181+
std::vector<std::unique_ptr<GraphNode>> execute_nodes_;
182182

183183
std::vector<IOValueRef> inputs_;
184184
std::vector<IOValueRef> outputs_;
@@ -207,7 +207,7 @@ class ComputeGraph final {
207207
return prepack_nodes_;
208208
}
209209

210-
inline std::vector<std::unique_ptr<ExecuteNode>>& execute_nodes() {
210+
inline std::vector<std::unique_ptr<GraphNode>>& execute_nodes() {
211211
return execute_nodes_;
212212
}
213213

backends/vulkan/runtime/graph/Logging.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,9 @@ void ComputeGraph::print_readable() {
154154
<< std::endl;
155155

156156
size_t node_idx = 0;
157-
for (const std::unique_ptr<ExecuteNode>& node : execute_nodes()) {
157+
for (const std::unique_ptr<GraphNode>& node : execute_nodes()) {
158158
std::cout << std::setw(6) << node_idx;
159-
std::cout << std::setw(32) << node->shader_.kernel_name;
159+
std::cout << std::setw(32) << node->name();
160160

161161
std::stringstream read_s;
162162
for (const ArgGroup& arg_group : node->args_) {

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,15 @@ ExecuteNode::ExecuteNode(
2424
const vkapi::SpecVarList& spec_vars,
2525
const ResizeFunction& resize_fn,
2626
const std::vector<ValueRef>& resize_args)
27-
: shader_(shader),
27+
: GraphNode(resize_fn, resize_args, args, shader.kernel_name),
28+
shader_(shader),
2829
global_workgroup_size_(global_workgroup_size),
2930
local_workgroup_size_(local_workgroup_size),
30-
args_(args),
3131
params_(params),
32-
spec_vars_(spec_vars),
33-
resize_fn_(resize_fn),
34-
resize_args_(resize_args) {
32+
spec_vars_(spec_vars) {
3533
graph.update_descriptor_counts(shader, /*execute = */ true);
3634
}
3735

38-
ExecuteNode::ExecuteNode(
39-
const ResizeFunction& resize_fn,
40-
const std::vector<ValueRef>& resize_args)
41-
: shader_(),
42-
global_workgroup_size_({0u, 0u, 0u}),
43-
local_workgroup_size_({0u, 0u, 0u}),
44-
args_(),
45-
params_(),
46-
spec_vars_(),
47-
resize_fn_(resize_fn),
48-
resize_args_(resize_args) {}
49-
5036
void ExecuteNode::encode(ComputeGraph* graph) {
5137
if (!shader_) {
5238
return;

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,19 @@
1212

1313
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
1414

15+
#include <executorch/backends/vulkan/runtime/graph/ops/GraphNode.h>
16+
1517
namespace vkcompute {
1618

1719
class ComputeGraph;
1820

1921
/*
20-
* Represents a group of shader arguments (images and/or buffers), with a common
21-
* access permission.
22-
*/
23-
struct ArgGroup {
24-
ArgGroup(const ValueRef ref, const vkapi::MemoryAccessFlags access)
25-
: refs{ref}, access(access) {}
26-
27-
ArgGroup(
28-
const std::vector<ValueRef>& refs,
29-
const vkapi::MemoryAccessFlags access)
30-
: refs(refs), access(access) {}
31-
32-
const std::vector<ValueRef> refs;
33-
const vkapi::MemoryAccessFlags access;
34-
};
35-
36-
/*
37-
* Represents a single execution op in a ML model. In graph mode, ops will be
38-
* implemented in a derived class that implements encode, which will implement
39-
* encoding of the shader corresponding to the op into the command buffer of a
40-
* ComputeGraph.
22+
* Represents a single shader execution op in a ML model.
4123
*/
42-
class ExecuteNode final {
24+
class ExecuteNode final : public GraphNode {
4325
friend class ComputeGraph;
4426

4527
public:
46-
using ResizeFunction = const std::function<void(
47-
ComputeGraph*,
48-
const std::vector<ArgGroup>&,
49-
const std::vector<ValueRef>&)>;
50-
5128
explicit ExecuteNode(
5229
ComputeGraph& graph,
5330
const vkapi::ShaderInfo& shader,
@@ -59,39 +36,16 @@ class ExecuteNode final {
5936
const ResizeFunction& resize_fn = nullptr,
6037
const std::vector<ValueRef>& resize_args = {});
6138

62-
/*
63-
* This overload of the ExecuteNode constructor is used to register ops which
64-
* update a tensor view. No shader is dispatched, but the node still needs to
65-
* update the view's sizes and strides after a resize.
66-
*/
67-
explicit ExecuteNode(
68-
const ResizeFunction& resize_fn = nullptr,
69-
const std::vector<ValueRef>& resize_args = {});
70-
7139
~ExecuteNode() = default;
7240

73-
void encode(ComputeGraph* graph);
74-
75-
inline void trigger_resize(ComputeGraph* graph) {
76-
if (resize_fn_ != nullptr) {
77-
resize_fn_(graph, args_, resize_args_);
78-
}
79-
}
80-
81-
inline void set_node_id(uint32_t node_id) {
82-
node_id_ = node_id;
83-
}
41+
void encode(ComputeGraph* graph) override;
8442

8543
protected:
86-
uint32_t node_id_;
8744
const vkapi::ShaderInfo shader_;
8845
const utils::uvec3 global_workgroup_size_;
8946
const utils::uvec3 local_workgroup_size_;
90-
const std::vector<ArgGroup> args_;
9147
const vkapi::ParamsBindList params_;
9248
const vkapi::SpecVarList spec_vars_;
93-
const ResizeFunction resize_fn_;
94-
const std::vector<ValueRef> resize_args_;
9549

9650
public:
9751
operator bool() const {
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/GraphNode.h>
10+
11+
namespace vkcompute {
12+
13+
GraphNode::GraphNode(
14+
const ResizeFunction& resize_fn,
15+
const std::vector<ValueRef>& resize_args,
16+
const std::vector<ArgGroup>& args,
17+
const std::string& name)
18+
: resize_fn_(resize_fn),
19+
resize_args_(resize_args),
20+
args_(args),
21+
name_(name) {}
22+
23+
} // namespace vkcompute
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
14+
15+
namespace vkcompute {
16+
17+
class ComputeGraph;
18+
19+
/*
20+
* Represents a group of shader arguments (images and/or buffers), with a common
21+
* access permission.
22+
*/
23+
struct ArgGroup {
24+
ArgGroup(const ValueRef ref, const vkapi::MemoryAccessFlags access)
25+
: refs{ref}, access(access) {}
26+
27+
ArgGroup(
28+
const std::vector<ValueRef>& refs,
29+
const vkapi::MemoryAccessFlags access)
30+
: refs(refs), access(access) {}
31+
32+
const std::vector<ValueRef> refs;
33+
const vkapi::MemoryAccessFlags access;
34+
};
35+
36+
/*
37+
* Represents a single execution op in a ML model. In graph mode, ops will be
38+
* implemented in a derived class that implements encode, which will implement
39+
* encoding of the shader corresponding to the op into the command buffer of a
40+
* ComputeGraph.
41+
*/
42+
class GraphNode {
43+
friend class ComputeGraph;
44+
45+
public:
46+
using ResizeFunction = const std::function<void(
47+
ComputeGraph*,
48+
const std::vector<ArgGroup>&,
49+
const std::vector<ValueRef>&)>;
50+
51+
/*
52+
* This overload of the GraphNode constructor is used to register ops which
53+
* update a tensor view. No shader is dispatched, but the node still needs to
54+
* update the view's sizes and strides after a resize.
55+
*/
56+
explicit GraphNode(
57+
const ResizeFunction& resize_fn = nullptr,
58+
const std::vector<ValueRef>& resize_args = {},
59+
const std::vector<ArgGroup>& args = {},
60+
const std::string& name = "Graph Node");
61+
62+
virtual ~GraphNode() = default;
63+
64+
virtual void encode(ComputeGraph* graph) {}
65+
66+
inline void trigger_resize(ComputeGraph* graph) {
67+
if (resize_fn_ != nullptr) {
68+
resize_fn_(graph, args_, resize_args_);
69+
}
70+
}
71+
72+
inline void set_node_id(uint32_t node_id) {
73+
node_id_ = node_id;
74+
}
75+
76+
inline const std::string& name() const {
77+
return name_;
78+
}
79+
80+
protected:
81+
uint32_t node_id_;
82+
const ResizeFunction resize_fn_;
83+
const std::vector<ValueRef> resize_args_;
84+
const std::vector<ArgGroup> args_;
85+
const std::string name_;
86+
};
87+
88+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/OperatorRegistry.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1213

1314
#include <functional>
1415
#include <unordered_map>

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ void add_cache_slice_view_node(
159159

160160
graph.get_tensor(cache_sliced)->virtual_resize(slice_sizes);
161161

162-
graph.execute_nodes().emplace_back(new ExecuteNode(
162+
graph.execute_nodes().emplace_back(new GraphNode(
163163
resize_cache_slice_view_node,
164164
{cache, input_pos_symint, q_projected, cache_sliced}));
165165
}
@@ -328,7 +328,7 @@ void sdpa_with_kv_cache_impl(
328328
mat2_is_transposed);
329329

330330
graph.execute_nodes().emplace_back(
331-
new ExecuteNode(resize_sdpa_out, {q_projected, out}));
331+
new GraphNode(resize_sdpa_out, {q_projected, out}));
332332
}
333333

334334
REGISTER_OPERATORS {

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ void add_slice_view_node(
264264

265265
graph.get_tensor(out_ref)->virtual_resize(new_out_sizes);
266266

267-
graph.execute_nodes().emplace_back(new ExecuteNode(
267+
graph.execute_nodes().emplace_back(new GraphNode(
268268
resize_slice_view_node,
269269
{out_ref, in_ref, dim_ref, opt_start_ref, opt_end_ref, opt_step_ref}));
270270
}

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
77
*/
8-
8+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
99
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>

backends/vulkan/runtime/graph/ops/impl/Transpose.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void add_transpose_view_node(
6666
graph.get_tensor(out_ref)->virtual_clone(*in);
6767
graph.get_tensor(out_ref)->virtual_transpose(dim0, dim1);
6868

69-
graph.execute_nodes().emplace_back(new ExecuteNode(
69+
graph.execute_nodes().emplace_back(new GraphNode(
7070
resize_transpose_view_node, {out_ref, input_ref, dim0_ref, dim1_ref}));
7171
}
7272

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
#include <executorch/backends/vulkan/test/utils/test_utils.h>
2727

28+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
29+
2830
using namespace vkcompute;
2931
using namespace vkcompute::api;
3032

@@ -1148,16 +1150,15 @@ TEST(VulkanComputeGraphTest, test_values_string) {
11481150
EXPECT_TRUE(stored == "hello, world");
11491151
}
11501152

1151-
TEST(VulkanComputeGraphTest, empty_init_executenode_test) {
1152-
ExecuteNode node(nullptr, {});
1153-
EXPECT_FALSE(node);
1153+
TEST(VulkanComputeGraphTest, empty_init_graphnode_test) {
1154+
GraphNode node(nullptr, {});
11541155

11551156
GraphConfig config;
11561157
ComputeGraph graph(config);
11571158

1158-
// Encode an empty ExecuteNode and check that command buffer encoding does not
1159+
// Encode an empty GraphNode and check that command buffer encoding does not
11591160
// crash.
1160-
graph.execute_nodes().emplace_back(new ExecuteNode(nullptr, {}));
1161+
graph.execute_nodes().emplace_back(new GraphNode(nullptr, {}));
11611162
EXPECT_NO_FATAL_FAILURE(graph.encode_execute());
11621163
}
11631164

0 commit comments

Comments
 (0)