
Commit 7156f1b

SS-JIA authored and facebook-github-bot committed
Dynamic shape support in Vulkan Backend (#2367)
Summary:
Pull Request resolved: #2367

## Context

This changeset exposes API functions to the `ComputeGraph` class that allow inputs to be resized, and for the resizing to propagate through the graph via re-calculation of output shapes.

ghstack-source-id: 218429131
exported-using-ghexport

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: jorgep31415

Differential Revision: D54754546

fbshipit-source-id: c312eb04849f8b9c1e3dea22ac427c76e34b9dd5
1 parent 835279e commit 7156f1b

File tree

11 files changed: +269 -42 lines

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 73 additions & 2 deletions
@@ -16,6 +16,7 @@
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
 #include <executorch/runtime/platform/compiler.h>
 #include <executorch/runtime/platform/profiler.h>

@@ -195,6 +196,68 @@ class GraphBuilder {
   }
 };

+//
+// Execution tools
+//
+
+bool maybe_resize_input(
+    ComputeGraph* graph,
+    const size_t input_i,
+    exec_aten::Tensor& et_tensor) {
+  ValueRef in_tensor_ref = graph->inputs()[input_i].value;
+  vTensor& in_tensor = graph->get_val(in_tensor_ref).toTensor();
+
+  ET_CHECK_MSG(
+      et_tensor.dim() == in_tensor.sizes().size(),
+      "Cannot resize input tensor: old ndim %zu does not match new ndim %zu",
+      static_cast<size_t>(in_tensor.sizes().size()),
+      static_cast<size_t>(et_tensor.dim()));
+
+  bool should_resize = false;
+  std::vector<int64_t> new_sizes(et_tensor.dim());
+  for (size_t i = 0; i < et_tensor.dim(); i++) {
+    if (in_tensor.sizes()[i] != et_tensor.sizes()[i]) {
+      should_resize = true;
+    }
+    new_sizes.at(i) = et_tensor.sizes()[i];
+  }
+
+  if (should_resize) {
+    graph->resize_input(input_i, new_sizes);
+  }
+
+  ET_CHECK_MSG(
+      in_tensor.numel() == et_tensor.numel(),
+      "Vulkan tensor numel %zu does not match ET tensor numel %zu",
+      static_cast<size_t>(in_tensor.numel()),
+      static_cast<size_t>(et_tensor.numel()));
+
+  return should_resize;
+}
+
+void maybe_resize_output(
+    ComputeGraph* graph,
+    const size_t output_i,
+    exec_aten::Tensor& et_tensor) {
+  ValueRef out_tensor_ref = graph->outputs()[output_i].value;
+  vTensor& out_tensor = graph->get_val(out_tensor_ref).toTensor();
+
+  exec_aten::SizesType new_output_size[kTensorDimensionLimit];
+  size_t ndim = out_tensor.sizes().size();
+  for (int i = 0; i < ndim; ++i) {
+    new_output_size[i] = out_tensor.sizes()[i];
+  }
+
+  exec_aten::ArrayRef<exec_aten::SizesType> output_size{new_output_size, ndim};
+  Error err = resize_tensor(et_tensor, output_size);
+
+  ET_CHECK_MSG(err == Error::Ok, "Failed to resize output tensor.");
+}
+
+//
+// VulkanBackend class
+//
+
 class VulkanBackend final : public PyTorchBackendInterface {
  public:
   ~VulkanBackend() override = default;

@@ -273,20 +336,28 @@ class VulkanBackend final : public PyTorchBackendInterface {
     ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);

     const size_t num_inputs = compute_graph->inputs().size();
+    bool should_propagate_resize = false;
     for (size_t i = 0; i < num_inputs; i++) {
+      bool was_resized =
+          maybe_resize_input(compute_graph, i, args[i]->toTensor());
+      should_propagate_resize = should_propagate_resize || was_resized;
       compute_graph->copy_into_staging(
-          compute_graph->inputs()[i],
+          compute_graph->inputs()[i].staging,
           args[i]->toTensor().const_data_ptr(),
           args[i]->toTensor().numel());
     }

+    if (should_propagate_resize) {
+      compute_graph->propagate_resize();
+    }
     compute_graph->execute();

     for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
+      maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
       // args holds inputs directly followed by outputs, so the i'th output
       // for compute_graph corresponds to the (i + num_inputs)'th arg
       compute_graph->copy_from_staging(
-          compute_graph->outputs()[i],
+          compute_graph->outputs()[i].staging,
           args[num_inputs + i]->toTensor().mutable_data_ptr(),
           args[num_inputs + i]->toTensor().numel());
     }

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 17 additions & 4 deletions
@@ -135,10 +135,10 @@ ValueRef ComputeGraph::set_input_tensor(
     vTensor& tensor = get_val(idx).toTensor();
     ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
     add_staging_to_tensor_node(*this, staging_idx, idx);
-    inputs_.push_back(staging_idx);
+    inputs_.push_back({idx, staging_idx});
     return staging_idx;
   }
-  inputs_.push_back(idx);
+  inputs_.push_back({idx, kDummyValueRef});
   return idx;
 }

@@ -149,10 +149,10 @@ ValueRef ComputeGraph::set_output_tensor(
     vTensor& tensor = get_val(idx).toTensor();
     ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
     add_tensor_to_staging_node(*this, idx, staging_idx);
-    outputs_.push_back(staging_idx);
+    outputs_.push_back({idx, staging_idx});
     return staging_idx;
   }
-  outputs_.push_back(idx);
+  outputs_.push_back({idx, kDummyValueRef});
   return idx;
 }

@@ -241,6 +241,19 @@ void ComputeGraph::execute() const {
   fence.wait();
 }

+void ComputeGraph::resize_input(
+    const int64_t idx,
+    const std::vector<int64_t>& new_sizes) {
+  IOValueRef io_val = inputs_.at(idx);
+  get_val(io_val.value).toTensor().virtual_resize(new_sizes);
+}
+
+void ComputeGraph::propagate_resize() {
+  for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
+    node->trigger_resize(this);
+  }
+}
+
 } // namespace vulkan
 } // namespace native
 } // namespace at
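
The `IOValueRef` type that `inputs_` and `outputs_` now hold is not shown in this diff. Inferred from its uses above (brace-initialization with a value/staging pair, and the `.value` and `.staging` accesses in VulkanBackend.cpp), a plausible sketch is:

    // Hypothetical sketch of IOValueRef, inferred from usage in this commit;
    // the actual definition lives elsewhere in the Vulkan graph runtime.
    struct IOValueRef {
      ValueRef value;   // ref to the graph's tensor value
      ValueRef staging; // ref to its staging buffer, or kDummyValueRef if none
    };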

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 11 additions & 4 deletions
@@ -68,8 +68,8 @@ class ComputeGraph final {
   std::vector<std::unique_ptr<PrepackNode>> prepack_nodes_;
   std::vector<std::unique_ptr<ExecuteNode>> execute_nodes_;

-  std::vector<ValueRef> inputs_;
-  std::vector<ValueRef> outputs_;
+  std::vector<IOValueRef> inputs_;
+  std::vector<IOValueRef> outputs_;

  public:
   //

@@ -80,11 +80,11 @@ class ComputeGraph final {
     return context_.get();
   }

-  inline std::vector<ValueRef>& inputs() {
+  inline std::vector<IOValueRef>& inputs() {
     return inputs_;
   }

-  inline std::vector<ValueRef>& outputs() {
+  inline std::vector<IOValueRef>& outputs() {
     return outputs_;
   }

@@ -201,6 +201,13 @@ class ComputeGraph final {

   void encode_execute();
   void execute() const;
+
+  //
+  // Dynamic Shape support
+  //
+
+  void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
+  void propagate_resize();
 };

 template <typename T>
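
With these declarations in place, the resize flow that `VulkanBackend::execute` drives above can also be exercised by hand. A minimal sketch, assuming a built `ComputeGraph` named `graph` whose first input has taken on a new shape (the shape values are illustrative):

    // Minimal sketch of the dynamic-shape flow added in this commit.
    std::vector<int64_t> new_sizes = {1, 3, 128, 128};
    graph.resize_input(0, new_sizes); // virtual_resize the input's vTensor
    graph.propagate_resize();         // each ExecuteNode re-computes output sizes
    graph.execute();                  // run with the updated shapes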

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 6 additions & 2 deletions
@@ -22,12 +22,16 @@ ExecuteNode::ExecuteNode(
     const api::utils::uvec3& global_workgroup_size,
     const api::utils::uvec3& local_workgroup_size,
     const std::vector<ArgGroup>& args,
-    const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params)
+    const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
+    const ResizeFunction& resize_fn,
+    const std::vector<ValueRef>& resize_args)
     : shader_(shader),
       global_workgroup_size_(global_workgroup_size),
       local_workgroup_size_(local_workgroup_size),
       args_(args),
-      params_(params) {
+      params_(params),
+      resize_fn_(resize_fn),
+      resize_args_(resize_args) {
   graph.update_descriptor_counts(shader, /*execute = */ true);
 }

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 16 additions & 1 deletion
@@ -47,25 +47,40 @@ class ExecuteNode final {
   friend class ComputeGraph;

  public:
+  using ResizeFunction = const std::function<void(
+      ComputeGraph*,
+      const std::vector<ArgGroup>&,
+      const std::vector<ValueRef>&)>;
+
   ExecuteNode(
       ComputeGraph& graph,
       const api::ShaderInfo& shader,
       const api::utils::uvec3& global_workgroup_size,
       const api::utils::uvec3& local_workgroup_size,
       const std::vector<ArgGroup>& args,
-      const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params);
+      const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
+      const ResizeFunction& resize_fn = nullptr,
+      const std::vector<ValueRef>& resize_args = {});

   ~ExecuteNode() = default;

   void encode(ComputeGraph* graph);

+  inline void trigger_resize(ComputeGraph* graph) {
+    if (resize_fn_ != nullptr) {
+      resize_fn_(graph, args_, resize_args_);
+    }
+  }
+
  protected:
   const api::ShaderInfo shader_;
   const api::utils::uvec3 global_workgroup_size_;
   const api::utils::uvec3 local_workgroup_size_;
   const std::vector<ArgGroup> args_;
   // TODO(T180906457): allow re-computing param buffers.
   std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
+  const ResizeFunction resize_fn_;
+  const std::vector<ValueRef> resize_args_;
 };

 } // namespace vulkan
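
Any op can opt into dynamic shapes by passing a `ResizeFunction` to this constructor. As a sketch, a shape-preserving (unary) op's resize function could simply mirror its input; the name `resize_unary_op_node` is an assumption, not part of this commit, and the ArgGroup layout is modeled on `resize_binary_op_node` in BinaryOp.cpp below:

    // Illustrative resize function for a shape-preserving op (not in this
    // commit). Assumes args[0] holds the output ref and args[1] the input
    // ref, matching the convention used by resize_binary_op_node.
    void resize_unary_op_node(
        ComputeGraph* graph,
        const std::vector<ArgGroup>& args,
        const std::vector<ValueRef>& extra_args) {
      (void)extra_args;
      vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
      vTensor& in = graph->get_val(args[1].refs[0]).toTensor();
      out.virtual_resize(in.sizes()); // output takes the input's (new) sizes
    }

It would then be passed as the `resize_fn` constructor argument, exactly as `resize_binary_op_node` is registered below.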

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 27 additions & 1 deletion
@@ -19,6 +19,28 @@ namespace at {
 namespace native {
 namespace vulkan {

+void resize_binary_op_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)extra_args;
+  vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
+  vTensor& self = graph->get_val(args[1].refs[0]).toTensor();
+  vTensor& other = graph->get_val(args[1].refs[1]).toTensor();
+
+  std::vector<int64_t> new_out_sizes(
+      std::max(self.sizes().size(), other.sizes().size()));
+
+  // Match the sizes in reverse because sizes are in NCHW order
+  for (int i = -1; i >= -new_out_sizes.size(); --i) {
+    new_out_sizes.at(new_out_sizes.size() + i) = std::max(
+        api::utils::val_at(i, self.sizes()),
+        api::utils::val_at(i, other.sizes()));
+  }
+
+  out.virtual_resize(new_out_sizes);
+}
+
 void add_binary_op_node(
     ComputeGraph& graph,
     const ValueRef in1,

@@ -52,12 +74,16 @@ void add_binary_op_node(
       VK_KERNEL_FROM_STR(kernel_name.str()),
       global_size,
       local_size,
+      // Inputs and Outputs
       {{out, api::MemoryAccessType::WRITE},
        {{arg1, arg2}, api::MemoryAccessType::READ}},
+      // Shader params buffers
       {t_out.gpu_sizes_ubo(),
        t_in1.gpu_sizes_ubo(),
        t_in2.gpu_sizes_ubo(),
-       graph.create_params_buffer(alpha_val)}));
+       graph.create_params_buffer(alpha_val)},
+      // Resizing
+      resize_binary_op_node));
 }

 #define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \
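
The reverse-indexed loop in `resize_binary_op_node` applies the usual broadcasting rule from the trailing dimension, since sizes are stored in NCHW order. A self-contained sketch of the same computation on plain `std::vector`s (assuming, as the max-based rule does, that missing leading dims act as 1 and that the shapes are broadcast-compatible):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Broadcast two size lists by matching trailing (rightmost) dimensions,
    // mirroring the val_at/std::max logic in resize_binary_op_node above.
    std::vector<int64_t> broadcast_sizes(
        const std::vector<int64_t>& a,
        const std::vector<int64_t>& b) {
      std::vector<int64_t> out(std::max(a.size(), b.size()));
      for (size_t i = 0; i < out.size(); ++i) {
        const int64_t av = i < a.size() ? a[a.size() - 1 - i] : 1;
        const int64_t bv = i < b.size() ? b[b.size() - 1 - i] : 1;
        out[out.size() - 1 - i] = std::max(av, bv);
      }
      return out;
    }

    // e.g. broadcast_sizes({3, 1, 5}, {4, 5}) yields {3, 4, 5}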

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 8 additions & 7 deletions
@@ -218,13 +218,14 @@ def process_getattr_node(self, node: Node) -> None:
         self.create_tensor_values(node)

     def process_output_node(self, node: Node) -> None:
-        if node.all_input_nodes[0] not in self.node_to_value_ids:
-            raise AssertionError(
-                "Cannot find input to output node in node_to_value_ids. This means the "
-                "output node is being serialized before its corresponding internal node "
-                "which is not allowed."
-            )
-        self.output_ids.append(self.node_to_value_ids[node.all_input_nodes[0]])
+        for out_node in node.all_input_nodes:
+            if out_node not in self.node_to_value_ids:
+                raise AssertionError(
+                    "Cannot find input to output node in node_to_value_ids. This means "
+                    "the output node is being serialized before its corresponding "
+                    "internal node which is not allowed."
+                )
+            self.output_ids.append(self.node_to_value_ids[out_node])

     def process_node(self, node: Node) -> None:
         if node.op == "placeholder":

backends/vulkan/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -146,6 +146,7 @@ def define_common_targets(is_fbcode = False):
             ":vk_delegate_schema",
             ":vulkan_graph_runtime",
             "//executorch/runtime/backend:interface",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
         ],
         define_static_target = False,
         # VulkanBackend.cpp needs to compile with executor as whole
