Skip to content

Commit 3ddce46

Browse files
committed
[ET-VK] Allow overwriting local workgroup size
Introduce a `GraphConfig` toggle following the convention of `storage_type` and `memory_layout`. Differential Revision: [D58957058](https://our.internmc.facebook.com/intern/diff/D58957058/) [ghstack-poisoned]
1 parent 039b258 commit 3ddce46

File tree

3 files changed

+17
-0
lines changed

3 files changed

+17
-0
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,10 @@ api::utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) {
320320
}
321321

322322
api::utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
323+
if (config_.enable_local_wg_size_override) {
324+
return config_.local_wg_size_override;
325+
}
326+
323327
if (is_buffer_storage(idx)) {
324328
return {64u, 1u, 1u};
325329
}

backends/vulkan/runtime/graph/GraphConfig.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ GraphConfig::GraphConfig() {
6060
// QueryPool objects are used to measure execution times of individual shader
6161
// dispatches. By default, this functionality is disabled.
6262
enable_querypool = false;
63+
64+
enable_local_wg_size_override = false;
65+
local_wg_size_override = {};
6366
}
6467

6568
void GraphConfig::set_storage_type_override(api::StorageType storage_type) {
@@ -73,4 +76,10 @@ void GraphConfig::set_memory_layout_override(
7376
memory_layout_override = memory_layout;
7477
}
7578

79+
void GraphConfig::set_local_wg_size_override(
80+
const api::utils::uvec3& local_wg_size) {
81+
enable_local_wg_size_override = true;
82+
local_wg_size_override = local_wg_size;
83+
}
84+
7685
} // namespace vkcompute

backends/vulkan/runtime/graph/GraphConfig.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,15 @@ struct GraphConfig final {
3030

3131
bool enable_querypool;
3232

33+
bool enable_local_wg_size_override;
34+
api::utils::uvec3 local_wg_size_override;
35+
3336
// Generate a default graph config with pre-configured settings
3437
explicit GraphConfig();
3538

3639
void set_storage_type_override(api::StorageType storage_type);
3740
void set_memory_layout_override(api::GPUMemoryLayout memory_layout);
41+
void set_local_wg_size_override(const api::utils::uvec3& local_wg_size);
3842
};
3943

4044
} // namespace vkcompute

0 commit comments

Comments
 (0)