Skip to content

Commit 9a80b47

Browse files
Merge pull request #1498 from IntelPython/fix-for-complex-sorting-test-failure
Fix for complex sorting test failure
2 parents b84d5f8 + af3ed63 commit 9a80b47

File tree

2 files changed

+31
-33
lines changed

2 files changed

+31
-33
lines changed

dpctl/tensor/libtensor/include/kernels/sorting.hpp

Lines changed: 25 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -539,34 +539,32 @@ sort_over_work_group_contig_impl(sycl::queue &q,
539539
sycl::group_barrier(it.get_group());
540540

541541
bool data_in_temp = false;
542-
size_t sorted_size = 1;
543-
while (true) {
544-
const size_t nelems_sorted_so_far = sorted_size * chunk;
545-
if (nelems_sorted_so_far < wg_chunk_size) {
546-
const size_t q = (lid / sorted_size);
547-
const size_t start_1 =
548-
sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size);
549-
const size_t end_1 = sycl::min(
550-
start_1 + nelems_sorted_so_far, wg_chunk_size);
551-
const size_t end_2 =
552-
sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size);
553-
const size_t offset = chunk * (lid - q * sorted_size);
554-
555-
if (data_in_temp) {
556-
merge_impl(offset, scratch_space, work_space, start_1,
557-
end_1, end_2, start_1, comp, chunk);
558-
}
559-
else {
560-
merge_impl(offset, work_space, scratch_space, start_1,
561-
end_1, end_2, start_1, comp, chunk);
562-
}
563-
sycl::group_barrier(it.get_group());
564-
565-
data_in_temp = !data_in_temp;
566-
sorted_size *= 2;
542+
size_t n_chunks_merged = 1;
543+
544+
// merge chunk while n_chunks_merged * chunk < wg_chunk_size
545+
const size_t max_chunks_merged = 1 + ((wg_chunk_size - 1) / chunk);
546+
for (; n_chunks_merged < max_chunks_merged;
547+
data_in_temp = !data_in_temp, n_chunks_merged *= 2)
548+
{
549+
const size_t nelems_sorted_so_far = n_chunks_merged * chunk;
550+
const size_t q = (lid / n_chunks_merged);
551+
const size_t start_1 =
552+
sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size);
553+
const size_t end_1 =
554+
sycl::min(start_1 + nelems_sorted_so_far, wg_chunk_size);
555+
const size_t end_2 =
556+
sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size);
557+
const size_t offset = chunk * (lid - q * n_chunks_merged);
558+
559+
if (data_in_temp) {
560+
merge_impl(offset, scratch_space, work_space, start_1,
561+
end_1, end_2, start_1, comp, chunk);
562+
}
563+
else {
564+
merge_impl(offset, work_space, scratch_space, start_1,
565+
end_1, end_2, start_1, comp, chunk);
567566
}
568-
else
569-
break;
567+
sycl::group_barrier(it.get_group());
570568
}
571569

572570
const auto &out_src = (data_in_temp) ? scratch_space : work_space;

dpctl/tensor/libtensor/source/sorting/sorting_common.hpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -41,15 +41,15 @@ template <typename fpT> struct ExtendedRealFPLess
4141
/* [R, nan] */
4242
bool operator()(const fpT v1, const fpT v2) const
4343
{
44-
return (!sycl::isnan(v1) && (sycl::isnan(v2) || (v1 < v2)));
44+
return (!std::isnan(v1) && (std::isnan(v2) || (v1 < v2)));
4545
}
4646
};
4747

4848
template <typename fpT> struct ExtendedRealFPGreater
4949
{
5050
bool operator()(const fpT v1, const fpT v2) const
5151
{
52-
return (!sycl::isnan(v2) && (sycl::isnan(v1) || (v2 < v1)));
52+
return (!std::isnan(v2) && (std::isnan(v1) || (v2 < v1)));
5353
}
5454
};
5555

@@ -64,14 +64,14 @@ template <typename cT> struct ExtendedComplexFPLess
6464
const realT real1 = std::real(v1);
6565
const realT real2 = std::real(v2);
6666

67-
const bool r1_nan = sycl::isnan(real1);
68-
const bool r2_nan = sycl::isnan(real2);
67+
const bool r1_nan = std::isnan(real1);
68+
const bool r2_nan = std::isnan(real2);
6969

7070
const realT imag1 = std::imag(v1);
7171
const realT imag2 = std::imag(v2);
7272

73-
const bool i1_nan = sycl::isnan(imag1);
74-
const bool i2_nan = sycl::isnan(imag2);
73+
const bool i1_nan = std::isnan(imag1);
74+
const bool i2_nan = std::isnan(imag2);
7575

7676
const int idx1 = ((r1_nan) ? 2 : 0) + ((i1_nan) ? 1 : 0);
7777
const int idx2 = ((r2_nan) ? 2 : 0) + ((i2_nan) ? 1 : 0);

0 commit comments

Comments
 (0)