
Commit 9714120

[ET-VK] Dynamic shape support in Vulkan Backend
## Context

This changeset exposes API functions on the `ComputeGraph` class that allow inputs to be resized and the resulting shape changes to propagate through the graph via re-calculation of output shapes.

Differential Revision: [D54754546](https://our.internmc.facebook.com/intern/diff/D54754546/)

[ghstack-poisoned]
1 parent b24a594 commit 9714120
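At execute time, the backend now checks each input's actual sizes against the graph's tensors, resizes where needed, re-propagates shapes once, and finally resizes the ExecuTorch output tensors to match. A condensed sketch of that flow, paraphrasing the `VulkanBackend.cpp` hunk below (`num_outputs` is shorthand for `compute_graph->outputs().size()`; error checks and the surrounding method are omitted):

```cpp
// Condensed from the VulkanBackend.cpp diff below; not verbatim.
bool should_propagate_resize = false;
for (size_t i = 0; i < num_inputs; i++) {
  // Resizes the graph's input tensor if the ET tensor's sizes differ.
  bool was_resized =
      maybe_resize_input(compute_graph, i, args[i]->toTensor());
  should_propagate_resize = should_propagate_resize || was_resized;
  compute_graph->copy_into_staging(
      compute_graph->inputs()[i].staging,
      args[i]->toTensor().const_data_ptr(),
      args[i]->toTensor().numel());
}

if (should_propagate_resize) {
  // Walks execute_nodes_ and invokes each node's resize function.
  compute_graph->propagate_resize();
}
compute_graph->execute();

for (size_t i = 0; i < num_outputs; i++) {
  // Pushes the recomputed Vulkan output shape back into the ET tensor.
  resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
  compute_graph->copy_from_staging(
      compute_graph->outputs()[i].staging,
      args[num_inputs + i]->toTensor().mutable_data_ptr(),
      args[num_inputs + i]->toTensor().numel());
}
```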

File tree: 10 files changed (+287, -16 lines)

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 73 additions & 2 deletions

```diff
@@ -16,6 +16,7 @@
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
 #include <executorch/runtime/platform/compiler.h>
 #include <executorch/runtime/platform/profiler.h>
@@ -195,6 +196,68 @@ class GraphBuilder {
   }
 };
 
+//
+// Execution tools
+//
+
+bool maybe_resize_input(
+    ComputeGraph* graph,
+    const size_t input_i,
+    exec_aten::Tensor& et_tensor) {
+  ValueRef in_tensor_ref = graph->inputs()[input_i].value;
+  vTensor& in_tensor = graph->get_val(in_tensor_ref).toTensor();
+
+  ET_CHECK_MSG(
+      et_tensor.dim() == in_tensor.sizes().size(),
+      "Cannot resize input tensor: old ndim %zu does not match new ndim %zu",
+      static_cast<size_t>(in_tensor.sizes().size()),
+      static_cast<size_t>(et_tensor.dim()));
+
+  bool should_resize = false;
+  std::vector<int64_t> new_sizes(et_tensor.dim());
+  for (size_t i = 0; i < et_tensor.dim(); i++) {
+    if (in_tensor.sizes()[i] != et_tensor.sizes()[i]) {
+      should_resize = true;
+    }
+    new_sizes[i] = et_tensor.sizes()[i];
+  }
+
+  if (should_resize) {
+    graph->resize_input(input_i, new_sizes);
+  }
+
+  ET_CHECK_MSG(
+      in_tensor.numel() == et_tensor.numel(),
+      "Vulkan tensor numel %zu does not match ET tensor numel %zu",
+      static_cast<size_t>(in_tensor.numel()),
+      static_cast<size_t>(et_tensor.numel()));
+
+  return should_resize;
+}
+
+void resize_output(
+    ComputeGraph* graph,
+    const size_t output_i,
+    exec_aten::Tensor& et_tensor) {
+  ValueRef out_tensor_ref = graph->outputs()[output_i].value;
+  vTensor& out_tensor = graph->get_val(out_tensor_ref).toTensor();
+
+  exec_aten::SizesType new_output_size[kTensorDimensionLimit];
+  size_t ndim = out_tensor.sizes().size();
+  for (int i = 0; i < ndim; ++i) {
+    new_output_size[i] = out_tensor.sizes()[i];
+  }
+
+  exec_aten::ArrayRef<exec_aten::SizesType> output_size{new_output_size, ndim};
+  Error err = resize_tensor(et_tensor, output_size);
+
+  ET_CHECK_MSG(err == Error::Ok, "Failed to resize output tensor.");
+}
+
+//
+// VulkanBackend class
+//
+
 class VulkanBackend final : public PyTorchBackendInterface {
  public:
   ~VulkanBackend() override = default;
@@ -273,20 +336,28 @@ class VulkanBackend final : public PyTorchBackendInterface {
     ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);
 
     const size_t num_inputs = compute_graph->inputs().size();
+    bool should_propagate_resize = false;
     for (size_t i = 0; i < num_inputs; i++) {
+      bool was_resized =
+          maybe_resize_input(compute_graph, i, args[i]->toTensor());
+      should_propagate_resize = should_propagate_resize || was_resized;
       compute_graph->copy_into_staging(
-          compute_graph->inputs()[i],
+          compute_graph->inputs()[i].staging,
           args[i]->toTensor().const_data_ptr(),
           args[i]->toTensor().numel());
     }
 
+    if (should_propagate_resize) {
+      compute_graph->propagate_resize();
+    }
     compute_graph->execute();
 
     for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
+      resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
       // args holds inputs directly followed by outputs, so the i'th output
       // for compute_graph corresponds to the (i + num_inputs)'th arg
       compute_graph->copy_from_staging(
-          compute_graph->outputs()[i],
+          compute_graph->outputs()[i].staging,
          args[num_inputs + i]->toTensor().mutable_data_ptr(),
          args[num_inputs + i]->toTensor().numel());
     }
```

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 17 additions & 4 deletions

```diff
@@ -135,10 +135,10 @@ ValueRef ComputeGraph::set_input_tensor(
     vTensor& tensor = get_val(idx).toTensor();
     ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
     add_staging_to_tensor_node(*this, staging_idx, idx);
-    inputs_.push_back(staging_idx);
+    inputs_.push_back({idx, staging_idx});
     return staging_idx;
   }
-  inputs_.push_back(idx);
+  inputs_.push_back({idx, -1});
   return idx;
 }
 
@@ -149,10 +149,10 @@ ValueRef ComputeGraph::set_output_tensor(
     vTensor& tensor = get_val(idx).toTensor();
     ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
     add_tensor_to_staging_node(*this, idx, staging_idx);
-    outputs_.push_back(staging_idx);
+    outputs_.push_back({idx, staging_idx});
     return staging_idx;
   }
-  outputs_.push_back(idx);
+  outputs_.push_back({idx, -1});
   return idx;
 }
 
@@ -241,6 +241,19 @@ void ComputeGraph::execute() const {
   fence.wait();
 }
 
+void ComputeGraph::resize_input(
+    const int64_t idx,
+    const std::vector<int64_t>& new_sizes) {
+  IOValueRef io_val = inputs_.at(idx);
+  get_val(io_val.value).toTensor().virtual_resize(new_sizes);
+}
+
+void ComputeGraph::propagate_resize() {
+  for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
+    node->trigger_resize(this);
+  }
+}
+
 } // namespace vulkan
 } // namespace native
 } // namespace at
```

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 11 additions & 4 deletions

```diff
@@ -68,8 +68,8 @@ class ComputeGraph final {
   std::vector<std::unique_ptr<PrepackNode>> prepack_nodes_;
   std::vector<std::unique_ptr<ExecuteNode>> execute_nodes_;
 
-  std::vector<ValueRef> inputs_;
-  std::vector<ValueRef> outputs_;
+  std::vector<IOValueRef> inputs_;
+  std::vector<IOValueRef> outputs_;
 
  public:
   //
@@ -80,11 +80,11 @@
     return context_.get();
   }
 
-  inline std::vector<ValueRef>& inputs() {
+  inline std::vector<IOValueRef>& inputs() {
     return inputs_;
   }
 
-  inline std::vector<ValueRef>& outputs() {
+  inline std::vector<IOValueRef>& outputs() {
     return outputs_;
   }
 
@@ -201,6 +201,13 @@ class ComputeGraph final {
 
   void encode_execute();
   void execute() const;
+
+  //
+  // Dynamic Shape support
+  //
+
+  void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
+  void propagate_resize();
 };
 
 template <typename T>
```

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 6 additions & 2 deletions

```diff
@@ -22,12 +22,16 @@ ExecuteNode::ExecuteNode(
     const api::utils::uvec3& global_workgroup_size,
     const api::utils::uvec3& local_workgroup_size,
     const std::vector<ArgGroup>& args,
-    const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params)
+    std::vector<std::shared_ptr<api::UniformParamsBuffer>> params,
+    const std::vector<ValueRef>& extra_args,
+    const ResizeFunction& resize_fn)
     : shader_(shader),
       global_workgroup_size_(global_workgroup_size),
       local_workgroup_size_(local_workgroup_size),
       args_(args),
-      params_(params) {
+      params_(params),
+      extra_args_(extra_args),
+      resize_fn_(resize_fn) {
   graph.update_descriptor_counts(shader, /*execute = */ true);
 }
 
```
backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 16 additions & 1 deletion

```diff
@@ -47,25 +47,40 @@ class ExecuteNode final {
   friend class ComputeGraph;
 
  public:
+  using ResizeFunction = const std::function<void(
+      ComputeGraph*,
+      const std::vector<ArgGroup>&,
+      const std::vector<ValueRef>&)>;
+
   ExecuteNode(
       ComputeGraph& graph,
       const api::ShaderInfo& shader,
       const api::utils::uvec3& global_workgroup_size,
       const api::utils::uvec3& local_workgroup_size,
       const std::vector<ArgGroup>& args,
-      const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params);
+      std::vector<std::shared_ptr<api::UniformParamsBuffer>> params,
+      const std::vector<ValueRef>& extra_args = {},
+      const ResizeFunction& resize_fn = nullptr);
 
   ~ExecuteNode() = default;
 
   void encode(ComputeGraph* graph);
 
+  inline void trigger_resize(ComputeGraph* graph) {
+    if (resize_fn_ != nullptr) {
+      resize_fn_(graph, args_, extra_args_);
+    }
+  }
+
 protected:
   const api::ShaderInfo shader_;
   const api::utils::uvec3 global_workgroup_size_;
   const api::utils::uvec3 local_workgroup_size_;
   const std::vector<ArgGroup> args_;
   // TODO(T180906457): allow re-computing param buffers.
   std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
+  const std::vector<ValueRef> extra_args_;
+  const ResizeFunction resize_fn_;
 };
 
 } // namespace vulkan
```
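Because `extra_args` and `resize_fn` default to empty, existing ops compile unchanged; an op opts into dynamic shapes by supplying a callback. A minimal sketch for a hypothetical elementwise unary op, where `resize_unary_node` and the names `shader`, `global_size`, `local_size`, `out`, `in`, `t_out` stand in for values an op builder would already have (they are illustrative, not from this diff):

```cpp
// Illustrative resize callback for a hypothetical elementwise unary op:
// the output's shape simply mirrors the input's shape.
void resize_unary_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args; // this op needs no extra values to recompute its shape
  vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
  vTensor& in = graph->get_val(args[1].refs[0]).toTensor();
  out.virtual_resize(in.sizes());
}

// Construction (appended to the graph the same way BinaryOp.cpp does below);
// propagate_resize() later calls trigger_resize(), which forwards
// (graph, args_, extra_args_) to the callback.
std::unique_ptr<ExecuteNode> node(new ExecuteNode(
    graph,
    shader,
    global_size,
    local_size,
    {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
    {t_out.gpu_sizes_ubo()},
    /*extra_args = */ {},
    /*resize_fn = */ resize_unary_node));
```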

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 26 additions & 1 deletion

```diff
@@ -23,6 +23,26 @@ std::string get_arithmetic_shader_name(const std::string& op_name) {
   return "arithmetic_" + op_name;
 }
 
+void resize_arithmetic_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
+  vTensor& self = graph->get_val(args[1].refs[0]).toTensor();
+  vTensor& other = graph->get_val(args[1].refs[1]).toTensor();
+
+  std::vector<int64_t> new_out_sizes(
+      std::max(self.sizes().size(), other.sizes().size()));
+
+  for (int i = -1; i >= -new_out_sizes.size(); --i) {
+    new_out_sizes[new_out_sizes.size() + i] = std::max(
+        api::utils::val_at(i, self.sizes()),
+        api::utils::val_at(i, other.sizes()));
+  }
+
+  out.virtual_resize(new_out_sizes);
+}
+
 void add_arithmetic_node(
     ComputeGraph& graph,
     const ValueRef in1,
@@ -56,12 +76,17 @@ void add_arithmetic_node(
       VK_KERNEL_FROM_STR(kernel_name.str()),
       global_size,
       local_size,
+      // Inputs and Outputs
       {{out, api::MemoryAccessType::WRITE},
        {{arg1, arg2}, api::MemoryAccessType::READ}},
+      // Shader params buffers
       {t_out.gpu_sizes_ubo(),
        t_in1.gpu_sizes_ubo(),
        t_in2.gpu_sizes_ubo(),
-       graph.create_params_buffer(alpha_val)}));
+       graph.create_params_buffer(alpha_val)},
+      // Resizing
+      {alpha},
+      resize_arithmetic_node));
 }
 
 #define DEFINE_ARITHMETIC_WITH_ALPHA_FN(function, shader) \
```
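The resize callback recomputes the broadcast output shape by walking dimensions from the innermost (index -1) outward, assuming `api::utils::val_at(i, sizes)` returns the size at negative dimension index `i` with implicit 1-padding for missing leading dimensions (its implementation is not shown here). A self-contained illustration of the same rule, as a hypothetical standalone helper:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Standalone illustration of the broadcast-shape rule used by
// resize_arithmetic_node: align shapes at the innermost dimension, treat
// missing leading dimensions as 1, and take the max at each position.
std::vector<int64_t> broadcast_sizes(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  const int64_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> out(ndim);
  for (int64_t i = 1; i <= ndim; ++i) {
    // Dimension i counted from the innermost; 1 if out of range.
    int64_t da = i <= static_cast<int64_t>(a.size()) ? a[a.size() - i] : 1;
    int64_t db = i <= static_cast<int64_t>(b.size()) ? b[b.size() - i] : 1;
    out[ndim - i] = std::max(da, db);
  }
  return out;
}

// Example: broadcast_sizes({3, 1, 5}, {4, 5}) == {3, 4, 5}.
```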

backends/vulkan/targets.bzl

Lines changed: 1 addition & 0 deletions

```diff
@@ -143,6 +143,7 @@ def define_common_targets():
             ":vk_delegate_schema",
             ":vulkan_graph_runtime",
             "//executorch/runtime/backend:interface",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
         ],
         define_static_target = False,
         # VulkanBackend.cpp needs to compile with executor as whole
```

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 54 additions & 2 deletions

```diff
@@ -14,7 +14,7 @@
 from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
 
 from executorch.exir import EdgeProgramManager, to_edge
-from torch.export import export, ExportedProgram
+from torch.export import Dim, export, ExportedProgram
 
 ctypes.CDLL("libvulkan.so.1")
 
@@ -54,13 +54,17 @@ def lower_module_and_test_output(
         sample_inputs: Tuple[torch.Tensor],
         atol=1e-03,
         rtol=1e-01,
+        dynamic_shapes=None,
+        test_inputs=None,
     ):
         """
         Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
         the given sample inputs. It then runs the lowered module and compares its
         outputs with the outputs of the eager module.
         """
-        program: ExportedProgram = export(model, sample_inputs)
+        program: ExportedProgram = export(
+            model, sample_inputs, dynamic_shapes=dynamic_shapes
+        )
         edge_program: EdgeProgramManager = to_edge(program)
         edge_program = edge_program.to_backend(VulkanPartitioner())
 
@@ -80,6 +84,19 @@ def lower_module_and_test_output(
 
         self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)
 
+        if test_inputs is not None:
+            for test_input in test_inputs:
+                # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
+                test_inputs_flattened, _ = tree_flatten(test_input)
+                model_output = executorch_module.run_method(
+                    "forward", tuple(test_inputs_flattened)
+                )
+                ref_output = model(*test_input)
+
+                self.assert_outputs_equal(
+                    model_output, ref_output, atol=atol, rtol=rtol
+                )
+
     def test_vulkan_backend_add(self):
         # This test is the simplest test; by manually lowering some submodules,
         # we can use the partitioner for auto-detecting lowerable parts
         class AddModule(torch.nn.Module):
@@ -251,3 +268,38 @@ def forward(self, x):
         model_inputs = (torch.rand(size=(2, 10), dtype=torch.float32),)
 
         self.lower_module_and_test_output(model, model_inputs)
+
+    def test_vulkan_backend_partial_dynamic_shapes(self):
+        class SimpleModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.branch1 = torch.nn.Sequential(
+                    torch.nn.Linear(64, 64), torch.nn.ReLU()
+                )
+                self.branch2 = torch.nn.Sequential(
+                    torch.nn.Linear(128, 64), torch.nn.ReLU()
+                )
+                self.buffer_1 = torch.ones((1, 64)) * 0.5
+                self.buffer_2 = torch.ones((1, 64)) * 1.4
+
+            def forward(self, x1, x2):
+                out1 = self.branch1(x1)
+                out2 = self.branch2(x2)
+                return (out1 + self.buffer_1 + out2) * self.buffer_2
+
+        model = SimpleModel()
+        model_inputs = (torch.randn(32, 64), torch.randn(32, 128))
+        batch = Dim("batch", max=124)
+        dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+
+        test_inputs = [
+            (torch.randn(15, 64), torch.randn(15, 128)),
+            (torch.randn(6, 64), torch.randn(6, 128)),
+            (torch.randn(30, 64), torch.randn(30, 128)),
+            (torch.randn(20, 64), torch.randn(20, 128)),
+            (torch.randn(19, 64), torch.randn(19, 128)),
+        ]
+
+        self.lower_module_and_test_output(
+            model, model_inputs, dynamic_shapes=dynamic_shapes, test_inputs=test_inputs
+        )
```
