@@ -539,34 +539,32 @@ sort_over_work_group_contig_impl(sycl::queue &q,
539
539
sycl::group_barrier (it.get_group ());
540
540
541
541
bool data_in_temp = false ;
542
- size_t sorted_size = 1 ;
543
- while (true ) {
544
- const size_t nelems_sorted_so_far = sorted_size * chunk;
545
- if (nelems_sorted_so_far < wg_chunk_size) {
546
- const size_t q = (lid / sorted_size);
547
- const size_t start_1 =
548
- sycl::min (2 * nelems_sorted_so_far * q, wg_chunk_size);
549
- const size_t end_1 = sycl::min (
550
- start_1 + nelems_sorted_so_far, wg_chunk_size);
551
- const size_t end_2 =
552
- sycl::min (end_1 + nelems_sorted_so_far, wg_chunk_size);
553
- const size_t offset = chunk * (lid - q * sorted_size);
554
-
555
- if (data_in_temp) {
556
- merge_impl (offset, scratch_space, work_space, start_1,
557
- end_1, end_2, start_1, comp, chunk);
558
- }
559
- else {
560
- merge_impl (offset, work_space, scratch_space, start_1,
561
- end_1, end_2, start_1, comp, chunk);
562
- }
563
- sycl::group_barrier (it.get_group ());
564
-
565
- data_in_temp = !data_in_temp;
566
- sorted_size *= 2 ;
542
+ size_t n_chunks_merged = 1 ;
543
+
544
+ // keep merging chunks while n_chunks_merged * chunk < wg_chunk_size
545
+ const size_t max_chunks_merged = 1 + ((wg_chunk_size - 1 ) / chunk);
546
+ for (; n_chunks_merged < max_chunks_merged;
547
+ data_in_temp = !data_in_temp, n_chunks_merged *= 2 )
548
+ {
549
+ const size_t nelems_sorted_so_far = n_chunks_merged * chunk;
550
+ const size_t q = (lid / n_chunks_merged);
551
+ const size_t start_1 =
552
+ sycl::min (2 * nelems_sorted_so_far * q, wg_chunk_size);
553
+ const size_t end_1 =
554
+ sycl::min (start_1 + nelems_sorted_so_far, wg_chunk_size);
555
+ const size_t end_2 =
556
+ sycl::min (end_1 + nelems_sorted_so_far, wg_chunk_size);
557
+ const size_t offset = chunk * (lid - q * n_chunks_merged);
558
+
559
+ if (data_in_temp) {
560
+ merge_impl (offset, scratch_space, work_space, start_1,
561
+ end_1, end_2, start_1, comp, chunk);
562
+ }
563
+ else {
564
+ merge_impl (offset, work_space, scratch_space, start_1,
565
+ end_1, end_2, start_1, comp, chunk);
567
566
}
568
- else
569
- break ;
567
+ sycl::group_barrier (it.get_group ());
570
568
}
571
569
572
570
const auto &out_src = (data_in_temp) ? scratch_space : work_space;
0 commit comments