Skip to content

Commit 2b00cf9

Browse files
authored
[SYCL] Add generic fabs implementation for bf16 (#8143)
Signed-off-by: jinge90 <[email protected]>
1 parent c64f88e commit 2b00cf9

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -52,18 +52,21 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
5252
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
5353
return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
5454
#else
55-
std::ignore = x;
56-
throw runtime_error(
57-
"bfloat16 math functions are not currently supported on the host device.",
58-
PI_ERROR_INVALID_DEVICE);
55+
if (!isnan(x)) {
56+
const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000;
57+
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
58+
x = ((XBits & SignMask) == SignMask)
59+
? oneapi::detail::bitsToBfloat16(XBits & ~SignMask)
60+
: x;
61+
}
62+
return x;
5963
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
6064
}
6165

6266
template <size_t N>
6367
sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
64-
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
6568
sycl::marray<bfloat16, N> res;
66-
69+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
6770
for (size_t i = 0; i < N / 2; i++) {
6871
auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
6972
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
@@ -74,13 +77,12 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
7477
oneapi::detail::bfloat16ToBits(x[N - 1]);
7578
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
7679
}
77-
return res;
7880
#else
79-
std::ignore = x;
80-
throw runtime_error(
81-
"bfloat16 math functions are not currently supported on the host device.",
82-
PI_ERROR_INVALID_DEVICE);
81+
for (size_t i = 0; i < N; i++) {
82+
res[i] = fabs(x[i]);
83+
}
8384
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
85+
return res;
8486
}
8587

8688
template <typename T>

0 commit comments

Comments
 (0)