@@ -485,24 +485,45 @@ utils::uvec3 ComputeGraph::create_local_wg_size(
485
485
return config_.local_wg_size_override ;
486
486
}
487
487
488
- utils::uvec3 local_group_size = {4 , 4 , 4 };
488
+ // array containing axis index and global workgroup size
489
+ std::pair<uint32_t , uint32_t > global_wg_size_desc[] = {
490
+ {0u , global_wg_size[0 ]},
491
+ {1u , global_wg_size[1 ]},
492
+ {2u , global_wg_size[2 ]}};
493
+
494
+ // sort the global workgroup size in descending order
495
+ if (global_wg_size_desc[0 ].second < global_wg_size_desc[1 ].second ) {
496
+ std::swap (global_wg_size_desc[0 ], global_wg_size_desc[1 ]);
497
+ }
498
+ if (global_wg_size_desc[1 ].second < global_wg_size_desc[2 ].second ) {
499
+ std::swap (global_wg_size_desc[1 ], global_wg_size_desc[2 ]);
500
+ }
501
+ if (global_wg_size_desc[0 ].second < global_wg_size_desc[1 ].second ) {
502
+ std::swap (global_wg_size_desc[0 ], global_wg_size_desc[1 ]);
503
+ }
489
504
490
- if (global_wg_size[2u ] == 1 ) {
491
- if (global_wg_size[1u ] == 1 ) {
505
+ utils::uvec3 local_group_size = {
506
+ 8 ,
507
+ std::max (1u , std::min (4u , global_wg_size_desc[1 ].second )),
508
+ std::max (1u , std::min (2u , global_wg_size_desc[2 ].second ))};
509
+
510
+ if (global_wg_size_desc[2u ].second == 1 ) {
511
+ if (global_wg_size_desc[1u ].second == 1 ) {
492
512
local_group_size[0u ] = 64 ;
493
513
local_group_size[1u ] = 1 ;
494
- local_group_size[2u ] = 1 ;
495
- } else if (global_wg_size[1u ] < 8 ) {
514
+ } else if (global_wg_size_desc[1u ].second % 4 == 0 ) {
496
515
local_group_size[0u ] = 16 ;
497
516
local_group_size[1u ] = 4 ;
498
- local_group_size[2u ] = 1 ;
499
517
} else {
500
- local_group_size[0u ] = 8 ;
501
- local_group_size[1u ] = 8 ;
502
- local_group_size[2u ] = 1 ;
518
+ local_group_size[0u ] = 32 ;
519
+ local_group_size[1u ] = 2 ;
503
520
}
504
521
}
505
- return local_group_size;
522
+
523
+ return {
524
+ local_group_size[global_wg_size_desc[0 ].first ],
525
+ local_group_size[global_wg_size_desc[1 ].first ],
526
+ local_group_size[global_wg_size_desc[2 ].first ]};
506
527
}
507
528
508
529
utils::uvec3 ComputeGraph::create_local_wg_size (const ValueRef idx) {
0 commit comments