Skip to content

Commit fe20be9

Browse files
authored
Calculating axis's local wg size based on global workload and making it as close as possible to warp size of 32.
Differential Revision: D64418632 Pull Request resolved: #6409
1 parent 2c2e527 commit fe20be9

File tree

1 file changed

+31
-10
lines changed

1 file changed

+31
-10
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -485,24 +485,45 @@ utils::uvec3 ComputeGraph::create_local_wg_size(
485485
return config_.local_wg_size_override;
486486
}
487487

488-
utils::uvec3 local_group_size = {4, 4, 4};
488+
// array containing axis index and global workgroup size
489+
std::pair<uint32_t, uint32_t> global_wg_size_desc[] = {
490+
{0u, global_wg_size[0]},
491+
{1u, global_wg_size[1]},
492+
{2u, global_wg_size[2]}};
493+
494+
// sort the global workgroup size in descending order
495+
if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
496+
std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
497+
}
498+
if (global_wg_size_desc[1].second < global_wg_size_desc[2].second) {
499+
std::swap(global_wg_size_desc[1], global_wg_size_desc[2]);
500+
}
501+
if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
502+
std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
503+
}
489504

490-
if (global_wg_size[2u] == 1) {
491-
if (global_wg_size[1u] == 1) {
505+
utils::uvec3 local_group_size = {
506+
8,
507+
std::max(1u, std::min(4u, global_wg_size_desc[1].second)),
508+
std::max(1u, std::min(2u, global_wg_size_desc[2].second))};
509+
510+
if (global_wg_size_desc[2u].second == 1) {
511+
if (global_wg_size_desc[1u].second == 1) {
492512
local_group_size[0u] = 64;
493513
local_group_size[1u] = 1;
494-
local_group_size[2u] = 1;
495-
} else if (global_wg_size[1u] < 8) {
514+
} else if (global_wg_size_desc[1u].second % 4 == 0) {
496515
local_group_size[0u] = 16;
497516
local_group_size[1u] = 4;
498-
local_group_size[2u] = 1;
499517
} else {
500-
local_group_size[0u] = 8;
501-
local_group_size[1u] = 8;
502-
local_group_size[2u] = 1;
518+
local_group_size[0u] = 32;
519+
local_group_size[1u] = 2;
503520
}
504521
}
505-
return local_group_size;
522+
523+
return {
524+
local_group_size[global_wg_size_desc[0].first],
525+
local_group_size[global_wg_size_desc[1].first],
526+
local_group_size[global_wg_size_desc[2].first]};
506527
}
507528

508529
utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {

0 commit comments

Comments
 (0)