Commit 59e0bf7

Update base for Update on "[ET-VK] Enable Partial GPU lowering via Vulkan in stories model export"
## Context

Simple change to add the Vulkan Partitioner as a dependency for the llama exporter and runner, and to provide a command line flag that invokes the Vulkan partitioner during export. Also includes a small change to the Vulkan serializer that was needed for everything to work (i.e. enabling serialization of multiple graph outputs).

Differential Revision: [D54805831](https://our.internmc.facebook.com/intern/diff/D54805831/)

[ghstack-poisoned]
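For reviewers, a minimal sketch of what partial Vulkan lowering looks like from the Python export side. The module, inputs, and file names are illustrative only, and the exact import path of `VulkanPartitioner` and the exporter's new command line flag are assumptions not shown in this diff:

```python
# Hedged sketch: assumes the VulkanPartitioner import path below; the actual
# llama exporter wires this up behind a command line flag added in this stack.
import torch

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge


class TinyModel(torch.nn.Module):
    def forward(self, x, y):
        # Two outputs, exercising the multi-output serialization fix below.
        return x + y, x * y


model = TinyModel().eval()
example_inputs = (torch.randn(1, 4), torch.randn(1, 4))

# Partially lower the exported graph: subgraphs the Vulkan partitioner claims
# run on the GPU delegate, everything else stays on portable CPU operators.
edge = to_edge(torch.export.export(model, example_inputs))
edge = edge.to_backend(VulkanPartitioner())
program = edge.to_executorch()

with open("tiny_vulkan.pte", "wb") as f:
    f.write(program.buffer)
```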
2 parents 9714120 + 35a847e commit 59e0bf7

32 files changed: +693 −1062 lines changed

backends/vulkan/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ load(":targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
-define_common_targets()
+define_common_targets(is_fbcode = True)
 
 runtime.python_library(
     name = "vulkan_preprocess",

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 3 additions & 3 deletions
@@ -219,7 +219,7 @@ bool maybe_resize_input(
     if (in_tensor.sizes()[i] != et_tensor.sizes()[i]) {
       should_resize = true;
     }
-    new_sizes[i] = et_tensor.sizes()[i];
+    new_sizes.at(i) = et_tensor.sizes()[i];
   }
 
   if (should_resize) {
@@ -235,7 +235,7 @@ bool maybe_resize_input(
   return should_resize;
 }
 
-void resize_output(
+void maybe_resize_output(
     ComputeGraph* graph,
     const size_t output_i,
     exec_aten::Tensor& et_tensor) {
@@ -353,7 +353,7 @@ class VulkanBackend final : public PyTorchBackendInterface {
     compute_graph->execute();
 
     for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
-      resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
+      maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
       // args holds inputs directly followed by outputs, so the i'th output
       // for compute_graph corresponds to the (i + num_inputs)'th arg
       compute_graph->copy_from_staging(

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 2 additions & 2 deletions
@@ -138,7 +138,7 @@ ValueRef ComputeGraph::set_input_tensor(
     inputs_.push_back({idx, staging_idx});
     return staging_idx;
   }
-  inputs_.push_back({idx, -1});
+  inputs_.push_back({idx, kDummyValueRef});
   return idx;
 }
 
@@ -152,7 +152,7 @@ ValueRef ComputeGraph::set_output_tensor(
     outputs_.push_back({idx, staging_idx});
     return staging_idx;
   }
-  outputs_.push_back({idx, -1});
+  outputs_.push_back({idx, kDummyValueRef});
   return idx;
 }
 

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 5 additions & 5 deletions
@@ -22,16 +22,16 @@ ExecuteNode::ExecuteNode(
     const api::utils::uvec3& global_workgroup_size,
     const api::utils::uvec3& local_workgroup_size,
     const std::vector<ArgGroup>& args,
-    std::vector<std::shared_ptr<api::UniformParamsBuffer>> params,
-    const std::vector<ValueRef>& extra_args,
-    const ResizeFunction& resize_fn)
+    const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
+    const ResizeFunction& resize_fn,
+    const std::vector<ValueRef>& resize_args)
     : shader_(shader),
       global_workgroup_size_(global_workgroup_size),
       local_workgroup_size_(local_workgroup_size),
      args_(args),
       params_(params),
-      extra_args_(extra_args),
-      resize_fn_(resize_fn) {
+      resize_fn_(resize_fn),
+      resize_args_(resize_args) {
   graph.update_descriptor_counts(shader, /*execute = */ true);
 }
 

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 5 additions & 5 deletions
@@ -58,17 +58,17 @@ class ExecuteNode final {
       const api::utils::uvec3& global_workgroup_size,
       const api::utils::uvec3& local_workgroup_size,
       const std::vector<ArgGroup>& args,
-      std::vector<std::shared_ptr<api::UniformParamsBuffer>> params,
-      const std::vector<ValueRef>& extra_args = {},
-      const ResizeFunction& resize_fn = nullptr);
+      const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
+      const ResizeFunction& resize_fn = nullptr,
+      const std::vector<ValueRef>& resize_args = {});
 
   ~ExecuteNode() = default;
 
   void encode(ComputeGraph* graph);
 
   inline void trigger_resize(ComputeGraph* graph) {
     if (resize_fn_ != nullptr) {
-      resize_fn_(graph, args_, extra_args_);
+      resize_fn_(graph, args_, resize_args_);
     }
   }
 
@@ -79,8 +79,8 @@ class ExecuteNode final {
   const std::vector<ArgGroup> args_;
   // TODO(T180906457): allow re-computing param buffers.
   std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
-  const std::vector<ValueRef> extra_args_;
   const ResizeFunction resize_fn_;
+  const std::vector<ValueRef> resize_args_;
 };
 
 } // namespace vulkan

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 1 addition & 1 deletion
@@ -8,8 +8,8 @@
 
 #version 450 core
 
-#include "indexing_utils.h"
 #include "broadcasting_utils.h"
+#include "indexing_utils.h"
 
 #define PRECISION ${PRECISION}
 

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 21 additions & 24 deletions
@@ -19,31 +19,29 @@ namespace at {
 namespace native {
 namespace vulkan {
 
-std::string get_arithmetic_shader_name(const std::string& op_name) {
-  return "arithmetic_" + op_name;
-}
-
-void resize_arithmetic_node(
+void resize_binary_op_node(
     ComputeGraph* graph,
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& extra_args) {
+  (void)extra_args;
   vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
   vTensor& self = graph->get_val(args[1].refs[0]).toTensor();
   vTensor& other = graph->get_val(args[1].refs[1]).toTensor();
 
   std::vector<int64_t> new_out_sizes(
       std::max(self.sizes().size(), other.sizes().size()));
 
+  // Match the sizes in reverse because sizes are in NCHW order
   for (int i = -1; i >= -new_out_sizes.size(); --i) {
-    new_out_sizes[new_out_sizes.size() + i] = std::max(
+    new_out_sizes.at(new_out_sizes.size() + i) = std::max(
         api::utils::val_at(i, self.sizes()),
         api::utils::val_at(i, other.sizes()));
   }
 
   out.virtual_resize(new_out_sizes);
 }
 
-void add_arithmetic_node(
+void add_binary_op_node(
     ComputeGraph& graph,
     const ValueRef in1,
     const ValueRef in2,
@@ -85,39 +83,38 @@ void add_arithmetic_node(
           t_in2.gpu_sizes_ubo(),
           graph.create_params_buffer(alpha_val)},
       // Resizing
-      {alpha},
-      resize_arithmetic_node));
+      resize_binary_op_node));
 }
 
-#define DEFINE_ARITHMETIC_WITH_ALPHA_FN(function, shader) \
-  void function(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
-    return add_arithmetic_node( \
-        graph, args[0], args[1], args[2], args[3], #shader); \
+#define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
+    return add_binary_op_node( \
+        graph, args[0], args[1], args[2], args[3], #op_name); \
   }
 
-#define DEFINE_ARITHMETIC_FN(function, shader) \
-  void function(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
-    return add_arithmetic_node( \
-        graph, args[0], args[1], kDummyValueRef, args[2], #shader); \
+#define DEFINE_BINARY_OP_FN(op_name) \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
+    return add_binary_op_node( \
+        graph, args[0], args[1], kDummyValueRef, args[2], #op_name); \
  }
 
-DEFINE_ARITHMETIC_WITH_ALPHA_FN(add, add);
-DEFINE_ARITHMETIC_WITH_ALPHA_FN(sub, sub);
+DEFINE_BINARY_OP_WITH_ALPHA_FN(add);
+DEFINE_BINARY_OP_WITH_ALPHA_FN(sub);
 
 // Floor div does not have an alpha, but a string argument (which is unused) is
 // passed in at the same location as the alpha argument in other op.
-DEFINE_ARITHMETIC_WITH_ALPHA_FN(floor_div, floor_divide);
+DEFINE_BINARY_OP_WITH_ALPHA_FN(floor_divide);
 
-DEFINE_ARITHMETIC_FN(mul, mul);
-DEFINE_ARITHMETIC_FN(div, div);
-DEFINE_ARITHMETIC_FN(pow, pow);
+DEFINE_BINARY_OP_FN(mul);
+DEFINE_BINARY_OP_FN(div);
+DEFINE_BINARY_OP_FN(pow);
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.add.Tensor, add);
   VK_REGISTER_OP(aten.sub.Tensor, sub);
   VK_REGISTER_OP(aten.mul.Tensor, mul);
   VK_REGISTER_OP(aten.div.Tensor, div);
-  VK_REGISTER_OP(aten.div.Tensor_mode, floor_div);
+  VK_REGISTER_OP(aten.div.Tensor_mode, floor_divide);
   VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
 }

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 8 additions & 7 deletions
@@ -218,13 +218,14 @@ def process_getattr_node(self, node: Node) -> None:
         self.create_tensor_values(node)
 
     def process_output_node(self, node: Node) -> None:
-        if node.all_input_nodes[0] not in self.node_to_value_ids:
-            raise AssertionError(
-                "Cannot find input to output node in node_to_value_ids. This means the "
-                "output node is being serialized before its corresponding internal node "
-                "which is not allowed."
-            )
-        self.output_ids.append(self.node_to_value_ids[node.all_input_nodes[0]])
+        for out_node in node.all_input_nodes:
+            if out_node not in self.node_to_value_ids:
+                raise AssertionError(
+                    "Cannot find input to output node in node_to_value_ids. This means "
+                    "the output node is being serialized before its corresponding "
+                    "internal node which is not allowed."
+                )
+            self.output_ids.append(self.node_to_value_ids[out_node])
 
     def process_node(self, node: Node) -> None:
         if node.op == "placeholder":
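The loop above is what lets graphs with more than one output serialize: an fx `output` node lists every returned value in `all_input_nodes`, so indexing only element 0 recorded just the first output. A small hedged sketch (model and names are illustrative only):

```python
# Hedged sketch: shows why iterating over all_input_nodes is needed for
# graphs that return more than one value.
import torch


class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2


exported = torch.export.export(TwoOutputs(), (torch.randn(2),))
output_node = next(n for n in exported.graph_module.graph.nodes if n.op == "output")

# Both the add and the mul node feed the single output node; a serializer that
# only reads all_input_nodes[0] would record just one of the two graph outputs.
print([n.name for n in output_node.all_input_nodes])
```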

backends/vulkan/targets.bzl

Lines changed: 17 additions & 15 deletions
@@ -1,23 +1,22 @@
-load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_fbcode")
-load("@fbsource//tools/build_defs:glob_defs.bzl", "subdir_glob")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
-def get_glsl_image_format():
-    if native.read_config("pt", "vulkan_full_precision", "0") == "0":
-        return "rgba16f"
-    return "rgba32f"
-
-def vulkan_spv_shader_lib(name, spv_filegroup):
+def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False):
     gen_aten_vulkan_spv_target = "//caffe2/tools:gen_aten_vulkan_spv_bin"
     glslc_path = "//caffe2/fb/vulkan/dotslash:glslc"
-    if is_fbcode():
+    if is_fbcode:
         gen_aten_vulkan_spv_target = "//caffe2:gen_vulkan_spv_bin"
         glslc_path = "//caffe2/fb/vulkan/tools:glslc"
 
+    glsl_paths = []
+
+    # TODO(ssjia): remove the need for subpath once subdir_glob is enabled in OSS
+    for target, subpath in spv_filegroups.items():
+        glsl_paths.append("$(location {})/{}".format(target, subpath))
+
     genrule_cmd = [
         "$(exe {})".format(gen_aten_vulkan_spv_target),
-        "--glsl-paths $(location {})".format(spv_filegroup),
-        "--output-path $OUT --env FLOAT_IMAGE_FORMAT={}".format(get_glsl_image_format()),
+        "--glsl-paths {}".format(" ".join(glsl_paths)),
+        "--output-path $OUT",
         "--glslc-path=$(exe {})".format(glslc_path),
         "--tmp-dir-path=$OUT",
     ]
@@ -49,7 +48,7 @@ def vulkan_spv_shader_lib(name, spv_filegroup):
         ],
     )
 
-def define_common_targets():
+def define_common_targets(is_fbcode = False):
     runtime.genrule(
         name = "gen_vk_delegate_schema",
         srcs = [
@@ -89,14 +88,17 @@ def define_common_targets():
 
     runtime.filegroup(
         name = "vulkan_graph_runtime_shaders",
-        srcs = subdir_glob([
-            ("runtime/graph/ops/glsl", "*"),
+        srcs = native.glob([
+            "runtime/graph/ops/glsl/*",
         ]),
     )
 
     vulkan_spv_shader_lib(
         name = "vulkan_graph_runtime_shaderlib",
-        spv_filegroup = ":vulkan_graph_runtime_shaders",
+        spv_filegroups = {
+            ":vulkan_graph_runtime_shaders": "runtime/graph/ops/glsl",
+        },
+        is_fbcode = is_fbcode,
     )
 
     runtime.cxx_library(
