Skip to content

Commit 0ecc438

Browse files
Abhi-hppfacebook-github-bot
authored andcommitted
Implement Graph node (#6037)
Summary: Pull Request resolved: #6037 Introduce Graph node which will be the parent class for nodes. This allows us to have different nodes that correspond to different vk command buffer functions including shader dispatch. Differential Revision: D64080291
1 parent cb3a546 commit 0ecc438

39 files changed

+164
-84
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ void ComputeGraph::encode_execute() {
573573
shared_object.bind_users(this);
574574
}
575575

576-
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
576+
for (std::unique_ptr<GraphNode>& node : execute_nodes_) {
577577
node->encode(this);
578578
}
579579
}
@@ -592,7 +592,7 @@ void ComputeGraph::resize_input(
592592
}
593593

594594
void ComputeGraph::propagate_resize() {
595-
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
595+
for (std::unique_ptr<GraphNode>& node : execute_nodes_) {
596596
node->trigger_resize(this);
597597
}
598598
}

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include <executorch/backends/vulkan/runtime/graph/containers/SharedObject.h>
2121
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
2222

23-
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
23+
#include <executorch/backends/vulkan/runtime/graph/ops/GraphNode.h>
2424
#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>
2525

2626
namespace vkcompute {
@@ -178,7 +178,7 @@ class ComputeGraph final {
178178
std::vector<api::ParamsBuffer> param_ubos_;
179179

180180
std::vector<std::unique_ptr<PrepackNode>> prepack_nodes_;
181-
std::vector<std::unique_ptr<ExecuteNode>> execute_nodes_;
181+
std::vector<std::unique_ptr<GraphNode>> execute_nodes_;
182182

183183
std::vector<IOValueRef> inputs_;
184184
std::vector<IOValueRef> outputs_;
@@ -207,7 +207,7 @@ class ComputeGraph final {
207207
return prepack_nodes_;
208208
}
209209

210-
inline std::vector<std::unique_ptr<ExecuteNode>>& execute_nodes() {
210+
inline std::vector<std::unique_ptr<GraphNode>>& execute_nodes() {
211211
return execute_nodes_;
212212
}
213213

backends/vulkan/runtime/graph/Logging.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,9 @@ void ComputeGraph::print_readable() {
154154
<< std::endl;
155155

156156
size_t node_idx = 0;
157-
for (const std::unique_ptr<ExecuteNode>& node : execute_nodes()) {
157+
for (const std::unique_ptr<GraphNode>& node : execute_nodes()) {
158158
std::cout << std::setw(6) << node_idx;
159-
std::cout << std::setw(32) << node->shader_.kernel_name;
159+
std::cout << std::setw(32) << node->name();
160160

161161
std::stringstream read_s;
162162
for (const ArgGroup& arg_group : node->args_) {

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,15 @@ ExecuteNode::ExecuteNode(
2424
const vkapi::SpecVarList& spec_vars,
2525
const ResizeFunction& resize_fn,
2626
const std::vector<ValueRef>& resize_args)
27-
: shader_(shader),
27+
: GraphNode(resize_fn, resize_args, args, shader.kernel_name),
28+
shader_(shader),
2829
global_workgroup_size_(global_workgroup_size),
2930
local_workgroup_size_(local_workgroup_size),
30-
args_(args),
3131
params_(params),
32-
spec_vars_(spec_vars),
33-
resize_fn_(resize_fn),
34-
resize_args_(resize_args) {
32+
spec_vars_(spec_vars) {
3533
graph.update_descriptor_counts(shader, /*execute = */ true);
3634
}
3735

38-
ExecuteNode::ExecuteNode(
39-
const ResizeFunction& resize_fn,
40-
const std::vector<ValueRef>& resize_args)
41-
: shader_(),
42-
global_workgroup_size_({0u, 0u, 0u}),
43-
local_workgroup_size_({0u, 0u, 0u}),
44-
args_(),
45-
params_(),
46-
spec_vars_(),
47-
resize_fn_(resize_fn),
48-
resize_args_(resize_args) {}
49-
5036
void ExecuteNode::encode(ComputeGraph* graph) {
5137
if (!shader_) {
5238
return;

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,19 @@
1212

1313
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
1414

15+
#include <executorch/backends/vulkan/runtime/graph/ops/GraphNode.h>
16+
1517
namespace vkcompute {
1618

1719
class ComputeGraph;
1820

1921
/*
20-
* Represents a group of shader arguments (images and/or buffers), with a common
21-
* access permission.
22-
*/
23-
struct ArgGroup {
24-
ArgGroup(const ValueRef ref, const vkapi::MemoryAccessFlags access)
25-
: refs{ref}, access(access) {}
26-
27-
ArgGroup(
28-
const std::vector<ValueRef>& refs,
29-
const vkapi::MemoryAccessFlags access)
30-
: refs(refs), access(access) {}
31-
32-
const std::vector<ValueRef> refs;
33-
const vkapi::MemoryAccessFlags access;
34-
};
35-
36-
/*
37-
* Represents a single execution op in a ML model. In graph mode, ops will be
38-
* implemented in a derived class that implements encode, which will implement
39-
* encoding of the shader corresponding to the op into the command buffer of a
40-
* ComputeGraph.
22+
* Represents a single shader execution op in a ML model.
4123
*/
42-
class ExecuteNode final {
24+
class ExecuteNode final : public GraphNode {
4325
friend class ComputeGraph;
4426

4527
public:
46-
using ResizeFunction = const std::function<void(
47-
ComputeGraph*,
48-
const std::vector<ArgGroup>&,
49-
const std::vector<ValueRef>&)>;
50-
5128
explicit ExecuteNode(
5229
ComputeGraph& graph,
5330
const vkapi::ShaderInfo& shader,
@@ -59,39 +36,16 @@ class ExecuteNode final {
5936
const ResizeFunction& resize_fn = nullptr,
6037
const std::vector<ValueRef>& resize_args = {});
6138

62-
/*
63-
* This overload of the ExecuteNode constructor is used to register ops which
64-
* update a tensor view. No shader is dispatched, but the node still needs to
65-
* update the view's sizes and strides after a resize.
66-
*/
67-
explicit ExecuteNode(
68-
const ResizeFunction& resize_fn = nullptr,
69-
const std::vector<ValueRef>& resize_args = {});
70-
7139
~ExecuteNode() = default;
7240

73-
void encode(ComputeGraph* graph);
74-
75-
inline void trigger_resize(ComputeGraph* graph) {
76-
if (resize_fn_ != nullptr) {
77-
resize_fn_(graph, args_, resize_args_);
78-
}
79-
}
80-
81-
inline void set_node_id(uint32_t node_id) {
82-
node_id_ = node_id;
83-
}
41+
void encode(ComputeGraph* graph) override;
8442

8543
protected:
86-
uint32_t node_id_;
8744
const vkapi::ShaderInfo shader_;
8845
const utils::uvec3 global_workgroup_size_;
8946
const utils::uvec3 local_workgroup_size_;
90-
const std::vector<ArgGroup> args_;
9147
const vkapi::ParamsBindList params_;
9248
const vkapi::SpecVarList spec_vars_;
93-
const ResizeFunction resize_fn_;
94-
const std::vector<ValueRef> resize_args_;
9549

9650
public:
9751
operator bool() const {
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/GraphNode.h>
10+
11+
namespace vkcompute {
12+
GraphNode::GraphNode(
13+
const ResizeFunction& resize_fn,
14+
const std::vector<ValueRef>& resize_args,
15+
const std::vector<ArgGroup>& args,
16+
const std::string name)
17+
: resize_fn_(resize_fn),
18+
resize_args_(resize_args),
19+
args_(args),
20+
name_(name) {}
21+
} // namespace vkcompute
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/backends/vulkan/runtime/api/api.h>
12+
13+
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
14+
15+
namespace vkcompute {
16+
17+
class ComputeGraph;
18+
19+
/*
20+
* Represents a group of shader arguments (images and/or buffers), with a common
21+
* access permission.
22+
*/
23+
struct ArgGroup {
24+
ArgGroup(const ValueRef ref, const vkapi::MemoryAccessFlags access)
25+
: refs{ref}, access(access) {}
26+
27+
ArgGroup(
28+
const std::vector<ValueRef>& refs,
29+
const vkapi::MemoryAccessFlags access)
30+
: refs(refs), access(access) {}
31+
32+
const std::vector<ValueRef> refs;
33+
const vkapi::MemoryAccessFlags access;
34+
};
35+
36+
/*
37+
* Represents a single execution op in a ML model. In graph mode, ops will be
38+
* implemented in a derived class that implements encode, which will implement
39+
* encoding of the shader corresponding to the op into the command buffer of a
40+
* ComputeGraph.
41+
*/
42+
class GraphNode {
43+
friend class ComputeGraph;
44+
45+
public:
46+
using ResizeFunction = const std::function<void(
47+
ComputeGraph*,
48+
const std::vector<ArgGroup>&,
49+
const std::vector<ValueRef>&)>;
50+
51+
/*
52+
* This overload of the GraphNode constructor is used to register ops which
53+
* update a tensor view. No shader is dispatched, but the node still needs to
54+
* update the view's sizes and strides after a resize.
55+
*/
56+
explicit GraphNode(
57+
const ResizeFunction& resize_fn = nullptr,
58+
const std::vector<ValueRef>& resize_args = {},
59+
const std::vector<ArgGroup>& args = {},
60+
const std::string name = "Graph Node");
61+
62+
virtual ~GraphNode() = default;
63+
64+
virtual void encode(ComputeGraph* graph) {}
65+
66+
inline void trigger_resize(ComputeGraph* graph) {
67+
if (resize_fn_ != nullptr) {
68+
resize_fn_(graph, args_, resize_args_);
69+
}
70+
}
71+
72+
inline void set_node_id(uint32_t node_id) {
73+
node_id_ = node_id;
74+
}
75+
76+
inline const std::string& name() const {
77+
return name_;
78+
}
79+
80+
protected:
81+
uint32_t node_id_;
82+
const ResizeFunction resize_fn_;
83+
const std::vector<ValueRef> resize_args_;
84+
const std::vector<ArgGroup> args_;
85+
const std::string name_;
86+
};
87+
88+
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Arange.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/backends/vulkan/runtime/api/api.h>
10+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1011

1112
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1213

backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1515
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1616

17+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1718
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1819

1920
namespace vkcompute {

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1515

16+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1617
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1718

1819
namespace vkcompute {

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1415
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1516
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1616
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1717

18+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1819
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1920

2021
namespace vkcompute {

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1112
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1213
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1515

16+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1617
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1718

1819
namespace vkcompute {

backends/vulkan/runtime/graph/ops/impl/Flip.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1414
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1516

1617
namespace vkcompute {
1718

backends/vulkan/runtime/graph/ops/impl/Full.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1212
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1313

14+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1415
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1516

1617
namespace vkcompute {

backends/vulkan/runtime/graph/ops/impl/GridPriors.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1314
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1415

1516
namespace vkcompute {

backends/vulkan/runtime/graph/ops/impl/IndexSelect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1515

16+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1617
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1718

1819
namespace vkcompute {

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
1515
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1616

17+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1718
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1819

1920
namespace vkcompute {

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
1515
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1616

17+
#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
1718
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1819

1920
namespace vkcompute {

0 commit comments

Comments
 (0)