
Commit ad9f186

SS-JIA authored and facebook-github-bot committed
Ensure descriptor set pools don't run out of memory (#2398)
Ensure descriptor set pools don't run out of memory (#2398)

Summary: Pull Request resolved: #2398

## Context

While testing a toy model with a large number of operators, I ran into an issue on my local Pixel 6 Android device where the descriptor pool was running out of memory. This changeset implements a simple fix to ensure that descriptor pools do not run into this issue. A longer-term solution is to implement layout-specific descriptor pools, but that is much more technically complex, so we go with this for now.

## Problem Details

#2285 made it so that `ComputeGraph` could tally up the total number of descriptors needed and size the descriptor pools appropriately, but this appears to be incompatible with certain Vulkan drivers. In the toy model, 1000 binary operators were added. Counting the descriptors required for the graph produces these descriptor counts:

```
descriptorPoolMaxSets: 1255
descriptorUniformBufferCount: 5013
descriptorStorageBufferCount: 4
descriptorCombinedSamplerCount: 2504
descriptorStorageImageCount: 1254
```

These counts appear to be correct; however, the descriptor pool still runs out of memory due to an insufficient `descriptorStorageBufferCount`. The `descriptorStorageBufferCount` needs to be set surprisingly high (approximately 1000) before the descriptor pool stops running out of memory. I'm not sure exactly what causes this behaviour, but it could be due to implementation details of the driver.

## Solution

Ensuring that all descriptor counts are greater than or equal to the maximum number of descriptor sets seems to work. Implement this as a temporary solution.

ghstack-source-id: 218509680
bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: jorgep31415

Differential Revision: D54853788

fbshipit-source-id: 391e3d10a678672df9af96e3b6a8484453b039f1
1 parent 3507412 commit ad9f186

File tree

2 files changed: +61 −7 lines


backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 6 additions & 5 deletions
```diff
@@ -191,12 +191,13 @@ void ComputeGraph::prepare() {
           prepack_descriptor_counts_.field) * \
       config_.descriptorPoolSafetyFactor))
 
+  uint32_t max_sets = MERGE_FIELD(descriptorPoolMaxSets);
   api::DescriptorPoolConfig config{
-      MERGE_FIELD(descriptorPoolMaxSets),
-      MERGE_FIELD(descriptorUniformBufferCount),
-      MERGE_FIELD(descriptorStorageBufferCount),
-      MERGE_FIELD(descriptorCombinedSamplerCount),
-      MERGE_FIELD(descriptorStorageImageCount),
+      max_sets,
+      std::max(MERGE_FIELD(descriptorUniformBufferCount), max_sets),
+      std::max(MERGE_FIELD(descriptorStorageBufferCount), max_sets),
+      std::max(MERGE_FIELD(descriptorCombinedSamplerCount), max_sets),
+      std::max(MERGE_FIELD(descriptorStorageImageCount), max_sets),
       1u,
   };
```

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 55 additions & 2 deletions
```diff
@@ -671,10 +671,63 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
     EXTRACT_TENSOR(out);
 
     // Sanity check that the values are correct
-    int i = 0;
     for (const auto& val : data_out) {
       ASSERT_TRUE(val == val_out);
-      ++i;
+    }
+  }
+}
+
+TEST(VulkanComputeGraphTest, test_large_graph) {
+  GraphConfig config;
+  ComputeGraph graph(config);
+
+  int64_t input_w = 256;
+  int64_t input_h = 256;
+  int64_t input_c = 8;
+
+  std::vector<int64_t> size_big = {input_c, input_h, input_w};
+  std::vector<int64_t> size_small = {input_c, input_h, 1};
+
+  // Build graph
+
+  IOValueRef a = graph.add_input_tensor(size_big, api::kFloat, 2);
+  IOValueRef b = graph.add_input_tensor(size_small, api::kFloat, 4);
+
+  ValueRef c = graph.add_tensor(size_big, api::kFloat, 6);
+
+  auto addFn = VK_GET_OP_FN("aten.add.Tensor");
+  addFn(graph, {a.value, b.value, kDummyValueRef, c});
+
+  int n = 100;
+
+  for (int i = 0; i < n; i++) {
+    addFn(graph, {c, b.value, kDummyValueRef, a.value});
+
+    addFn(graph, {a.value, b.value, kDummyValueRef, c});
+  }
+
+  IOValueRef out = {};
+  out.value = c;
+  out.staging = graph.set_output_tensor(out.value);
+
+  graph.prepare();
+  graph.encode_execute();
+
+  for (int i = 0; i < 10; i++) {
+    float val_a = 1.0f;
+    float val_b = 2.0f;
+
+    float val_e = val_a + val_b * (2 * n + 1);
+
+    fill_vtensor(graph, a, val_a);
+    fill_vtensor(graph, b, val_b);
+
+    graph.execute();
+
+    EXTRACT_TENSOR(out);
+
+    for (const auto& val : data_out) {
+      EXPECT_TRUE(val == val_e);
     }
   }
 }
```
