Skip to content

Commit dc5a9af

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Separate OpNode to PrepackNode/ExecuteNode (#2102)
Summary: Pull Request resolved: #2102 Derived classes of `OpNode` are currently used only for prepack or execute, never both. This means they need not have both API. Inspired by [Stephen's comment](https://www.internalfb.com/diff/D53982441?dst_version_fbid=370105355800543&transaction_fbid=940226924354801), we will build on `ExecuteNode` to be initialized with member functions for `create_params_block()`, `get_shader()`, etc. `PrepackNode` doesn't need these members. Hence, it makes sense to separate the classes. Feel free to suggest better names. I don't really like mine. ghstack-source-id: 216481872 exported-using-ghexport Reviewed By: SS-JIA Differential Revision: D54042646 fbshipit-source-id: a8965d69c94cdb3fe9837e9b83b6db8a877949f0
1 parent 3a90aa6 commit dc5a9af

File tree

8 files changed

+58
-37
lines changed

8 files changed

+58
-37
lines changed

backends/vulkan/runtime/graph/Graph.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ void ComputeGraph::copy_from_staging(
193193
}
194194

195195
void ComputeGraph::encode_prepack() {
196-
for (std::unique_ptr<OpNode>& node : prepack_nodes_) {
197-
node->encode_prepack(this);
196+
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
197+
node->encode(this);
198198
}
199199
}
200200

@@ -216,8 +216,8 @@ void ComputeGraph::encode_execute() {
216216
shared_object.bind_users(this);
217217
}
218218

219-
for (std::unique_ptr<OpNode>& node : execute_nodes_) {
220-
node->encode_execute(this);
219+
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
220+
node->encode(this);
221221
}
222222
}
223223

backends/vulkan/runtime/graph/Graph.h

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,32 +33,52 @@ struct IOValueRef {
3333
class ComputeGraph;
3434

3535
/*
36-
* Represents a single op in a ML model. In graph mode, ops will be implemented
37-
* introducing a derived class that implements encode_execute, which will
38-
* implement encoding of the shader corresponding to the op into the command
39-
* buffer of a ComputeGraph, as well as encode_prepack, which will implement
36+
* Represents a single prepacking op in a ML model. In graph mode, ops will be
37+
* implemented in a derived class that implements encode, which will implement
4038
* encoding of shaders transferring necessary data (such as weights and biases)
41-
* to the GPU, wherever prepacking is necessary.
39+
* to the GPU.
4240
*/
43-
class OpNode {
41+
class PrepackNode {
4442
friend class ComputeGraph;
4543

4644
public:
47-
OpNode(ValueRef input, ValueRef output) : inputs_{input}, outputs_{output} {}
48-
OpNode(
45+
PrepackNode(ValueRef tref, ValueRef packed) : tref_{tref}, packed_{packed} {}
46+
47+
virtual ~PrepackNode() = default;
48+
49+
protected:
50+
ValueRef tref_;
51+
ValueRef packed_;
52+
53+
public:
54+
virtual void encode(ComputeGraph* graph) const = 0;
55+
};
56+
57+
/*
58+
* Represents a single execution op in a ML model. In graph mode, ops will be
59+
* implemented in a derived class that implements encode, which will implement
60+
* encoding of the shader corresponding to the op into the command buffer of a
61+
* ComputeGraph.
62+
*/
63+
class ExecuteNode {
64+
friend class ComputeGraph;
65+
66+
public:
67+
ExecuteNode(ValueRef input, ValueRef output)
68+
: inputs_{input}, outputs_{output} {}
69+
ExecuteNode(
4970
const std::vector<ValueRef>& inputs,
5071
const std::vector<ValueRef>& outputs)
5172
: inputs_(inputs), outputs_(outputs) {}
5273

53-
virtual ~OpNode() = default;
74+
virtual ~ExecuteNode() = default;
5475

5576
protected:
5677
std::vector<ValueRef> inputs_;
5778
std::vector<ValueRef> outputs_;
5879

5980
public:
60-
virtual void encode_prepack(ComputeGraph* graph) const {}
61-
virtual void encode_execute(ComputeGraph* graph) const {}
81+
virtual void encode(ComputeGraph* graph) const = 0;
6282
};
6383

6484
struct SharedObject {
@@ -99,8 +119,8 @@ class ComputeGraph final {
99119
std::vector<SharedObject> shared_objects_;
100120
std::vector<Value> values_;
101121

102-
std::vector<std::unique_ptr<OpNode>> prepack_nodes_;
103-
std::vector<std::unique_ptr<OpNode>> execute_nodes_;
122+
std::vector<std::unique_ptr<PrepackNode>> prepack_nodes_;
123+
std::vector<std::unique_ptr<ExecuteNode>> execute_nodes_;
104124

105125
std::vector<ValueRef> inputs_;
106126
std::vector<ValueRef> outputs_;
@@ -149,11 +169,11 @@ class ComputeGraph final {
149169
VK_THROW("Could not get dtype of value with type ", val.type());
150170
}
151171

152-
inline std::vector<std::unique_ptr<OpNode>>& prepack_nodes() {
172+
inline std::vector<std::unique_ptr<PrepackNode>>& prepack_nodes() {
153173
return prepack_nodes_;
154174
}
155175

156-
inline std::vector<std::unique_ptr<OpNode>>& execute_nodes() {
176+
inline std::vector<std::unique_ptr<ExecuteNode>>& execute_nodes() {
157177
return execute_nodes_;
158178
}
159179

backends/vulkan/runtime/graph/ops/Arithmetic.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ ValueRef add_arithmetic_node(
6161
}
6262

6363
ArithmeticPrepack::ArithmeticPrepack(const ValueRef tref, const ValueRef packed)
64-
: OpNode(tref, packed) {}
64+
: PrepackNode(tref, packed) {}
6565

66-
void ArithmeticPrepack::encode_prepack(ComputeGraph* graph) const {
67-
TensorRef tref = graph->get_val(inputs_[0]).toTensorRef();
68-
vTensor packed = graph->get_val(outputs_[0]).toTensor();
66+
void ArithmeticPrepack::encode(ComputeGraph* graph) const {
67+
TensorRef tref = graph->get_val(tref_).toTensorRef();
68+
vTensor packed = graph->get_val(packed_).toTensor();
6969

7070
api::StorageBuffer staging(
7171
graph->context(), packed.dtype(), packed.gpu_nbytes());
@@ -83,9 +83,9 @@ ArithmeticNode::ArithmeticNode(
8383
const ValueRef out,
8484
const float alpha,
8585
const arithmetic::OpType optype)
86-
: OpNode({t1, t2}, {out}), alpha_(alpha), optype_(optype) {}
86+
: ExecuteNode({t1, t2}, {out}), alpha_(alpha), optype_(optype) {}
8787

88-
void ArithmeticNode::encode_execute(ComputeGraph* graph) const {
88+
void ArithmeticNode::encode(ComputeGraph* graph) const {
8989
vTensor& in1 = graph->get_val(inputs_[0]).toTensor();
9090
vTensor& in2 = graph->get_val(inputs_[1]).toTensor();
9191
vTensor& out = graph->get_val(outputs_[0]).toTensor();

backends/vulkan/runtime/graph/ops/Arithmetic.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ ValueRef add_arithmetic_node(
3434
const arithmetic::OpType optype,
3535
const int64_t shared_object_idx = -1);
3636

37-
class ArithmeticPrepack : public virtual OpNode {
37+
class ArithmeticPrepack : public virtual PrepackNode {
3838
public:
3939
explicit ArithmeticPrepack(const ValueRef tref, const ValueRef packed);
4040

41-
void encode_prepack(ComputeGraph* graph) const override;
41+
void encode(ComputeGraph* graph) const override;
4242
};
4343

44-
class ArithmeticNode : public virtual OpNode {
44+
class ArithmeticNode : public virtual ExecuteNode {
4545
public:
4646
explicit ArithmeticNode(
4747
const ValueRef t1,
@@ -50,7 +50,7 @@ class ArithmeticNode : public virtual OpNode {
5050
const float alpha,
5151
const arithmetic::OpType optype);
5252

53-
void encode_execute(ComputeGraph* graph) const override;
53+
void encode(ComputeGraph* graph) const override;
5454

5555
private:
5656
float alpha_;

backends/vulkan/runtime/graph/ops/Copy.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ ValueRef add_copy_node(ComputeGraph& graph, const ValueRef from) {
2727
return to;
2828
}
2929

30-
CopyNode::CopyNode(const ValueRef from, const ValueRef to) : OpNode(from, to) {}
30+
CopyNode::CopyNode(const ValueRef from, const ValueRef to)
31+
: ExecuteNode(from, to) {}
3132

32-
void CopyNode::encode_execute(ComputeGraph* graph) const {
33+
void CopyNode::encode(ComputeGraph* graph) const {
3334
api::PipelineBarrier pipeline_barrier{};
3435

3536
vTensor& from_tensor = graph->get_val(inputs_[0]).toTensor();

backends/vulkan/runtime/graph/ops/Copy.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ namespace vulkan {
1919
void add_copy_node(ComputeGraph& graph, const ValueRef from, const ValueRef to);
2020
ValueRef add_copy_node(ComputeGraph& graph, const ValueRef from);
2121

22-
class CopyNode : public virtual OpNode {
22+
class CopyNode : public virtual ExecuteNode {
2323
public:
2424
explicit CopyNode(const ValueRef from, const ValueRef to);
2525

26-
void encode_execute(ComputeGraph* graph) const override;
26+
void encode(ComputeGraph* graph) const override;
2727
};
2828

2929
} // namespace vulkan

backends/vulkan/runtime/graph/ops/Staging.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ void encode_copy_from_vtensor(
9898
VK_NULL_HANDLE);
9999
}
100100

101-
StagingNode::StagingNode(ValueRef from, ValueRef to) : OpNode(from, to) {}
101+
StagingNode::StagingNode(ValueRef from, ValueRef to) : ExecuteNode(from, to) {}
102102

103-
void StagingNode::encode_execute(ComputeGraph* graph) const {
103+
void StagingNode::encode(ComputeGraph* graph) const {
104104
Value& in_val = graph->get_val(inputs_[0]);
105105
Value& out_val = graph->get_val(outputs_[0]);
106106

backends/vulkan/runtime/graph/ops/Staging.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ void encode_copy_from_vtensor(
8484
/*
8585
* OpNode that allows copying data into and out of a staging buffer.
8686
*/
87-
class StagingNode : public virtual OpNode {
87+
class StagingNode : public virtual ExecuteNode {
8888
public:
8989
explicit StagingNode(ValueRef from, ValueRef to);
9090

91-
void encode_execute(ComputeGraph* graph) const override;
91+
void encode(ComputeGraph* graph) const override;
9292
};
9393

9494
} // namespace vulkan

0 commit comments

Comments
 (0)