Skip to content

Commit 1f10f35

Browse files
[SYCL] Fix nextafter with half on host (#10558)
Currently the host-side implementation of sycl::nextafter with sycl::half uses the float variant of std::nextafter. However, due to the conversion between half and float, the result may be unexpected. Likewise, KhronosGroup/SYCL-Docs#440 removes the reference to single-precision floating point results. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 474461c commit 1f10f35

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

sycl/source/detail/builtins_math.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,23 @@ __SYCL_EXPORT s::cl_double sycl_host_nextafter(s::cl_double x,
538538
}
539539
__SYCL_EXPORT s::cl_half sycl_host_nextafter(s::cl_half x,
540540
s::cl_half y) __NOEXC {
541-
return std::nextafter(x, y);
541+
if (std::isnan(d::cast_if_host_half(x)))
542+
return x;
543+
if (std::isnan(d::cast_if_host_half(y)) || x == y)
544+
return y;
545+
546+
uint16_t x_bits = s::bit_cast<uint16_t>(x);
547+
uint16_t x_sign = x_bits & 0x8000;
548+
int16_t movement = (x > y ? -1 : 1) * (x_sign ? -1 : 1);
549+
if (x_bits == x_sign && movement == -1) {
550+
// Special case where we underflow in the decrement, in which case we turn
551+
// it around and flip the sign. The overflow case does not need special
552+
// handling.
553+
movement = 1;
554+
x_bits ^= 0x8000;
555+
}
556+
x_bits += movement;
557+
return s::bit_cast<s::cl_half>(x_bits);
542558
}
543559
MAKE_1V_2V(sycl_host_nextafter, s::cl_float, s::cl_float, s::cl_float)
544560
MAKE_1V_2V(sycl_host_nextafter, s::cl_double, s::cl_double, s::cl_double)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
//
4+
// Checks that sycl::nextafter with sycl::half on host correctly converts based
5+
// on half-precision.
6+
7+
#include <sycl/sycl.hpp>
8+
9+
void check(uint16_t x, uint16_t y, uint16_t ref) {
10+
assert(sycl::nextafter(sycl::bit_cast<sycl::half>(x),
11+
sycl::bit_cast<sycl::half>(y)) ==
12+
sycl::bit_cast<sycl::half>(ref));
13+
}
14+
15+
int main() {
16+
check(0x0, 0x0, 0x0);
17+
check(0x1, 0x1, 0x1);
18+
check(0x8001, 0x8001, 0x8001);
19+
check(0x0, 0x1, 0x1);
20+
check(0x8000, 0x8001, 0x8001);
21+
check(0x0, 0x8001, 0x8001);
22+
check(0x8000, 0x1, 0x1);
23+
check(0x8001, 0x0, 0x0);
24+
check(0x1, 0x8000, 0x8000);
25+
check(0x8001, 0x1, 0x0);
26+
check(0x1, 0x8001, 0x8000);
27+
28+
std::cout << "Passed!" << std::endl;
29+
return 0;
30+
}

0 commit comments

Comments
 (0)