
Commit aed32c4

SS-JIA authored and facebook-github-bot committed
Use lazy descriptor pool allocation (#2285)
Summary:
Pull Request resolved: #2285

## Context

In Vulkan, memory for Descriptor Sets (which are used to bind data to shader arguments) must be pre-allocated. Previously, the convention was to allocate a large number of descriptor sets upon creation of a Vulkan Context. While this worked well in the Lite Interpreter, where only a single global Vulkan context is used, it leads to over-allocating descriptor sets in the Vulkan Delegate, where every `ComputeGraph` has its own dedicated Context. pytorch/pytorch#121134 allows the Descriptor Set pool to be initialized in a deferred fashion, which means a `ComputeGraph` can count the total number of descriptors needed across all the compute shaders that will be encoded, and then allocate a Descriptor Set Pool of the appropriate size.

## Implementation Overview

1. When constructing a `ComputeGraph`, make sure that the descriptor pool config contains 0 for the maximum number of sets. This ensures that no descriptor pool is initialized when constructing the graph's `api::Context` instance.
2. When building the graph, `ExecuteNode` and `PrepackNode` call `graph.update_descriptor_counts(shader)` upon construction, which allows `ComputeGraph` to count the total number of descriptor sets needed.
3. There is a separate descriptor count object for prepack and execute, since they correspond to different command buffers.
4. Before encoding any command buffers, call `graph.prepare()`, which constructs a descriptor pool config from the descriptor counts.

## Notes

One interesting finding is that I had to apply a safety factor to the descriptor counts to prevent the pool from running out of memory. This was reproducible on both Linux and Android. A more robust design, e.g. as discussed [here](https://www.reddit.com/r/vulkan/comments/17v66fi/question_about_descriptor_pool_allocations/), may be to maintain a separate descriptor pool for each layout type. We should revisit this refactor at a later time.

bypass-github-export-checks

Reviewed By: jorgep31415

Differential Revision: D54603935

fbshipit-source-id: eb04403b5f0967d69b390153c778b58bd940004e
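
To make the flow concrete, here is a minimal sketch of the call sequence described above, using only the APIs that appear in the diffs below; `build_nodes_from_flatbuffer` is a hypothetical stand-in for the existing `GraphBuilder` step, and the counts in the comments are illustrative:

```cpp
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

using namespace at::native::vulkan;

void sketch_lazy_descriptor_pool_flow() {
  // GraphConfig() now requests 0 for descriptorPoolMaxSets, so the api::Context
  // owned by the graph allocates no descriptor pool at construction time.
  ComputeGraph graph{GraphConfig()};

  // Building the graph constructs PrepackNode / ExecuteNode objects; each node
  // constructor calls graph.update_descriptor_counts(shader, /*execute=*/...),
  // adding one descriptor set plus one descriptor per kernel_layout entry to
  // either the prepack tally or the execute tally.
  build_nodes_from_flatbuffer(graph); // hypothetical stand-in for GraphBuilder

  // prepare() merges the two tallies field-wise (std::max), scales the result
  // by descriptorPoolSafetyFactor (1.25 by default, e.g. 100 uniform buffers
  // -> ceil(100 * 1.25) = 125), and initializes the descriptor pool.
  graph.prepare();

  // Command buffers are encoded only after the pool exists.
  graph.encode_prepack();
  graph.prepack();
}
```

The same ordering appears in the VulkanBackend.cpp diff below: build the graph, then `prepare()`, then `encode_prepack()` and `prepack()`.
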
1 parent 0570294 commit aed32c4

12 files changed: +198 -88 lines


backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 3 additions & 34 deletions
@@ -62,39 +62,6 @@ api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
   }
 }
 
-GraphConfig generate_config() {
-  const uint32_t submit_frequency = UINT32_MAX;
-
-  const api::CommandPoolConfig cmd_config{
-      4u, // cmdPoolInitialSize
-      2u, // cmdPoolBatchSize
-  };
-
-  const api::DescriptorPoolConfig descriptor_pool_config{
-      1024u, // descriptorPoolMaxSets
-      1024u, // descriptorUniformBufferCount
-      1024u, // descriptorStorageBufferCount
-      1024u, // descriptorCombinedSamplerCount
-      1024u, // descriptorStorageImageCount
-      32u, // descriptorPileSizes
-  };
-
-  const api::QueryPoolConfig query_pool_config{};
-
-  const api::ContextConfig context_config{
-      submit_frequency, // cmdSubmitFrequency
-      cmd_config, // cmdPoolConfig
-      descriptor_pool_config, // descriptorPoolConfig
-      query_pool_config, // queryPoolConfig
-  };
-
-  const GraphConfig graph_config{
-      context_config,
-  };
-
-  return graph_config;
-}
-
 class GraphBuilder {
   ComputeGraph* compute_graph_;
   VkGraphPtr flatbuffer_;
@@ -269,6 +236,8 @@ class VulkanBackend final : public PyTorchBackendInterface {
 
   builder.build_graph();
 
+  compute_graph->prepare();
+
   compute_graph->encode_prepack();
   compute_graph->prepack();
 
@@ -284,7 +253,7 @@ class VulkanBackend final : public PyTorchBackendInterface {
   ComputeGraph* compute_graph = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(
       context.get_runtime_allocator(), ComputeGraph);
 
-  new (compute_graph) ComputeGraph(generate_config());
+  new (compute_graph) ComputeGraph(GraphConfig());
 
   Error err = compileModel(processed->data(), compute_graph);

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 65 additions & 0 deletions
@@ -18,6 +18,8 @@ namespace vulkan {
 
 ComputeGraph::ComputeGraph(GraphConfig config)
     : config_{config},
+      prepack_descriptor_counts_{},
+      execute_descriptor_counts_{},
       context_{new api::Context(
           api::runtime()->default_adapter_i(),
           config_.contextConfig)},
@@ -27,6 +29,19 @@ ComputeGraph::ComputeGraph(GraphConfig config)
       execute_nodes_{},
       inputs_{},
       outputs_{} {
+  // Ensure that descriptor counts are initialized to 0
+  prepack_descriptor_counts_.descriptorPoolMaxSets = 0;
+  prepack_descriptor_counts_.descriptorUniformBufferCount = 0;
+  prepack_descriptor_counts_.descriptorStorageBufferCount = 0;
+  prepack_descriptor_counts_.descriptorCombinedSamplerCount = 0;
+  prepack_descriptor_counts_.descriptorStorageImageCount = 0;
+
+  execute_descriptor_counts_.descriptorPoolMaxSets = 0;
+  execute_descriptor_counts_.descriptorUniformBufferCount = 0;
+  execute_descriptor_counts_.descriptorStorageBufferCount = 0;
+  execute_descriptor_counts_.descriptorCombinedSamplerCount = 0;
+  execute_descriptor_counts_.descriptorStorageImageCount = 0;
+
   context_->set_cmd(/*reusable = */ true);
 }
 
@@ -39,6 +54,33 @@ ComputeGraph::~ComputeGraph() {
   context_->flush();
 }
 
+void ComputeGraph::update_descriptor_counts(
+    const api::ShaderInfo& shader_info,
+    bool execute) {
+  api::DescriptorPoolConfig* config =
+      execute ? &execute_descriptor_counts_ : &prepack_descriptor_counts_;
+
+  config->descriptorPoolMaxSets += 1;
+  for (const VkDescriptorType arg_type : shader_info.kernel_layout) {
+    switch (arg_type) {
+      case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
+        config->descriptorUniformBufferCount += 1;
+        break;
+      case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
+        config->descriptorStorageBufferCount += 1;
+        break;
+      case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
+        config->descriptorCombinedSamplerCount += 1;
+        break;
+      case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
+        config->descriptorStorageImageCount += 1;
+        break;
+      default:
+        VK_THROW("Unsupported descriptor type!");
+    }
+  }
+}
+
 ValueRef ComputeGraph::add_tensor(
     const std::vector<int64_t>& sizes,
     const api::ScalarType dtype,
@@ -138,6 +180,29 @@ void ComputeGraph::copy_from_staging(
   copy_staging_to_ptr(staging, data, nbytes);
 }
 
+void ComputeGraph::prepare() {
+#define MERGE_FIELD(field)                      \
+  static_cast<uint32_t>(std::ceil(              \
+      std::max(                                 \
+          execute_descriptor_counts_.field,     \
+          prepack_descriptor_counts_.field) *   \
+      config_.descriptorPoolSafetyFactor))
+
+  api::DescriptorPoolConfig config{
+      MERGE_FIELD(descriptorPoolMaxSets),
+      MERGE_FIELD(descriptorUniformBufferCount),
+      MERGE_FIELD(descriptorStorageBufferCount),
+      MERGE_FIELD(descriptorCombinedSamplerCount),
+      MERGE_FIELD(descriptorStorageImageCount),
+      1u,
+  };
+
+  if (!context_->descriptor_pool()) {
+    context_->descriptor_pool().init(config);
+  }
+#undef MERGE_FIELD
+}
+
 void ComputeGraph::encode_prepack() {
   for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
     node->encode(this);

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 13 additions & 0 deletions
@@ -60,6 +60,9 @@ class ComputeGraph final {
 
  private:
   GraphConfig config_;
+  api::DescriptorPoolConfig prepack_descriptor_counts_;
+  api::DescriptorPoolConfig execute_descriptor_counts_;
+
   std::unique_ptr<api::Context> context_;
   std::vector<SharedObject> shared_objects_;
   std::vector<Value> values_;
@@ -87,6 +90,10 @@
     return outputs_;
   }
 
+  void update_descriptor_counts(
+      const api::ShaderInfo& shader_info,
+      bool execute);
+
   /*
    * Returns the value at a particular reference
    */
@@ -163,6 +170,12 @@
 
   SharedObject& get_shared_object(const int64_t idx);
 
+  //
+  // Graph Preparation
+  //
+
+  void prepare();
+
   //
   // Input/Output
   //

backends/vulkan/runtime/graph/GraphConfig.cpp

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/GraphConfig.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+
+GraphConfig::GraphConfig() {
+  // No automatic submissions
+  const uint32_t submit_frequency = UINT32_MAX;
+
+  // Only one command buffer will be encoded at a time
+  const api::CommandPoolConfig cmd_config{
+      1u, // cmdPoolInitialSize
+      1u, // cmdPoolBatchSize
+  };
+
+  // Use lazy descriptor pool initialization by default; the graph runtime will
+  // tally up the number of descriptor sets needed while building the graph and
+  // trigger descriptor pool initialization with exact sizes before encoding the
+  // command buffer.
+  const api::DescriptorPoolConfig descriptor_pool_config{
+      0u, // descriptorPoolMaxSets
+      0u, // descriptorUniformBufferCount
+      0u, // descriptorStorageBufferCount
+      0u, // descriptorCombinedSamplerCount
+      0u, // descriptorStorageImageCount
+      0u, // descriptorPileSizes
+  };
+
+  const api::QueryPoolConfig query_pool_config{};
+
+  const api::ContextConfig context_config{
+      submit_frequency, // cmdSubmitFrequency
+      cmd_config, // cmdPoolConfig
+      descriptor_pool_config, // descriptorPoolConfig
+      query_pool_config, // queryPoolConfig
+  };
+
+  contextConfig = context_config;
+
+  // Empirically selected safety factor. If descriptor pools start running out
+  // of memory, increase this safety factor.
+  descriptorPoolSafetyFactor = 1.25;
+}
+
+} // namespace vulkan
+} // namespace native
+} // namespace at

backends/vulkan/runtime/graph/GraphConfig.h

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,16 @@ namespace vulkan {
 
 struct GraphConfig final {
   api::ContextConfig contextConfig;
+
+  // Creating a descriptor pool with exactly the number of descriptors tallied
+  // by iterating through the shader layouts of shaders used in the graph risks
+  // the descriptor pool running out of memory, therefore apply a safety factor
+  // to descriptor counts when creating the descriptor pool to mitigate this
+  // risk.
+  float descriptorPoolSafetyFactor;
+
+  // Generate a default graph config with pre-configured settings
+  explicit GraphConfig();
 };
 
 } // namespace vulkan
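
Because `descriptorPoolSafetyFactor` is a public member of `GraphConfig`, a caller that still exhausts its descriptor pool could presumably raise the factor before constructing the graph. A minimal sketch with a purely illustrative value (this override is not part of the commit):

```cpp
// Hypothetical tuning: start from the defaults (lazy pool, safety factor 1.25)
// and raise the factor if the descriptor pool still runs out of memory.
GraphConfig config;
config.descriptorPoolSafetyFactor = 1.5f; // illustrative value only
ComputeGraph graph{config};
```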

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 15 additions & 0 deletions
@@ -16,6 +16,21 @@ namespace at {
 namespace native {
 namespace vulkan {
 
+ExecuteNode::ExecuteNode(
+    ComputeGraph& graph,
+    const api::ShaderInfo& shader,
+    const api::utils::uvec3& global_workgroup_size,
+    const api::utils::uvec3& local_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    api::UniformParamsBuffer&& params)
+    : shader_(shader),
+      global_workgroup_size_(global_workgroup_size),
+      local_workgroup_size_(local_workgroup_size),
+      args_(args),
+      params_(std::move(params)) {
+  graph.update_descriptor_counts(shader, /*execute = */ true);
+}
+
 void ExecuteNode::encode(ComputeGraph* graph) {
   api::Context* const context = graph->context();
   api::PipelineBarrier pipeline_barrier{};

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 2 additions & 6 deletions
@@ -50,16 +50,12 @@ class ExecuteNode final {
 
  public:
   ExecuteNode(
+      ComputeGraph& graph,
       const api::ShaderInfo& shader,
       const api::utils::uvec3& global_workgroup_size,
       const api::utils::uvec3& local_workgroup_size,
       const std::vector<ArgGroup>& args,
-      api::UniformParamsBuffer&& params)
-      : shader_(shader),
-        global_workgroup_size_(global_workgroup_size),
-        local_workgroup_size_(local_workgroup_size),
-        args_(args),
-        params_(std::move(params)) {}
+      api::UniformParamsBuffer&& params);
 
   ~ExecuteNode() = default;

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 17 additions & 0 deletions
@@ -17,6 +17,23 @@ namespace at {
 namespace native {
 namespace vulkan {
 
+PrepackNode::PrepackNode(
+    ComputeGraph& graph,
+    const api::ShaderInfo& shader,
+    const api::utils::uvec3& global_workgroup_size,
+    const api::utils::uvec3& local_workgroup_size,
+    const ValueRef tref,
+    const ValueRef packed,
+    api::UniformParamsBuffer&& params)
+    : shader_(shader),
+      global_workgroup_size_(global_workgroup_size),
+      local_workgroup_size_(local_workgroup_size),
+      tref_(tref),
+      packed_(packed),
+      params_(std::move(params)) {
+  graph.update_descriptor_counts(shader, /*execute = */ false);
+}
+
 void PrepackNode::encode(ComputeGraph* graph) {
   api::Context* const context = graph->context();
   api::PipelineBarrier pipeline_barrier{};

backends/vulkan/runtime/graph/ops/PrepackNode.h

Lines changed: 2 additions & 7 deletions
@@ -33,18 +33,13 @@ class PrepackNode final {
 
  public:
   PrepackNode(
+      ComputeGraph& graph,
       const api::ShaderInfo& shader,
       const api::utils::uvec3& global_workgroup_size,
       const api::utils::uvec3& local_workgroup_size,
       const ValueRef tref,
       const ValueRef packed,
-      api::UniformParamsBuffer&& params)
-      : shader_(shader),
-        global_workgroup_size_(global_workgroup_size),
-        local_workgroup_size_(local_workgroup_size),
-        tref_(tref),
-        packed_(packed),
-        params_(std::move(params)) {}
+      api::UniformParamsBuffer&& params);
 
   ~PrepackNode() = default;

backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ void add_arithmetic_node(
   api::UniformParamsBuffer params(graph.context(), block);
 
   graph.execute_nodes().emplace_back(new ExecuteNode(
+      graph,
       shader,
       global_size,
       local_size,

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 3 additions & 1 deletion
@@ -48,6 +48,7 @@ void add_staging_to_tensor_node(
       graph.context(), create_staging_params(t_out));
 
   graph.execute_nodes().emplace_back(new ExecuteNode(
+      graph,
       shader,
       global_size,
       local_size,
@@ -90,6 +91,7 @@ void add_tensor_to_staging_node(
   }
 
   graph.execute_nodes().emplace_back(new ExecuteNode(
+      graph,
       shader,
       global_size,
       local_size,
@@ -112,7 +114,7 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
   api::UniformParamsBuffer params(graph.context(), sp);
 
   graph.prepack_nodes().emplace_back(new PrepackNode(
-      shader, global_size, local_size, vref, v, std::move(params)));
+      graph, shader, global_size, local_size, vref, v, std::move(params)));
 
   return v;
 }
