Skip to content

Commit 0ab003d

Browse files
Replace `while` loop with `for` loop in sorting kernel for efficiency
The `for` loop generates less branching than the previous `while (true)` loop with its `if`/`else break` body.
1 parent 2cb368c commit 0ab003d

File tree

1 file changed

+25
-27
lines changed

1 file changed

+25
-27
lines changed

dpctl/tensor/libtensor/include/kernels/sorting.hpp

Lines changed: 25 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -539,34 +539,32 @@ sort_over_work_group_contig_impl(sycl::queue &q,
539 539
sycl::group_barrier(it.get_group());
540 540

541 541
bool data_in_temp = false;
542-
size_t sorted_size = 1;
543-
while (true) {
544-
const size_t nelems_sorted_so_far = sorted_size * chunk;
545-
if (nelems_sorted_so_far < wg_chunk_size) {
546-
const size_t q = (lid / sorted_size);
547-
const size_t start_1 =
548-
sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size);
549-
const size_t end_1 = sycl::min(
550-
start_1 + nelems_sorted_so_far, wg_chunk_size);
551-
const size_t end_2 =
552-
sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size);
553-
const size_t offset = chunk * (lid - q * sorted_size);
554-
555-
if (data_in_temp) {
556-
merge_impl(offset, scratch_space, work_space, start_1,
557-
end_1, end_2, start_1, comp, chunk);
558-
}
559-
else {
560-
merge_impl(offset, work_space, scratch_space, start_1,
561-
end_1, end_2, start_1, comp, chunk);
562-
}
563-
sycl::group_barrier(it.get_group());
564-
565-
data_in_temp = !data_in_temp;
566-
sorted_size *= 2;
542+
size_t n_chunks_merged = 1;
543+
544+
// merge chunk while n_chunks_merged * chunk < wg_chunk_size
545+
const size_t max_chunks_merged = 1 + ((wg_chunk_size - 1) / chunk);
546+
for (; n_chunks_merged < max_chunks_merged;
547+
data_in_temp = !data_in_temp, n_chunks_merged *= 2)
548+
{
549+
const size_t nelems_sorted_so_far = n_chunks_merged * chunk;
550+
const size_t q = (lid / n_chunks_merged);
551+
const size_t start_1 =
552+
sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size);
553+
const size_t end_1 =
554+
sycl::min(start_1 + nelems_sorted_so_far, wg_chunk_size);
555+
const size_t end_2 =
556+
sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size);
557+
const size_t offset = chunk * (lid - q * n_chunks_merged);
558+
559+
if (data_in_temp) {
560+
merge_impl(offset, scratch_space, work_space, start_1,
561+
end_1, end_2, start_1, comp, chunk);
562+
}
563+
else {
564+
merge_impl(offset, work_space, scratch_space, start_1,
565+
end_1, end_2, start_1, comp, chunk);
567 566
}
568-
else
569-
break;
567+
sycl::group_barrier(it.get_group());
570 568
}
571 569

572 570
const auto &out_src = (data_in_temp) ? scratch_space : work_space;

0 commit comments

Comments
 (0)