Skip to content

Commit b808ae9

Browse files
Revert "[SYCL] Add generic impl for bf16 fabs" (#8061)
Reverts #7959 See #8040 (comment)
1 parent 3e4717a commit b808ae9

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,21 +52,18 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
5252
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
5353
return oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
5454
#else
55-
if (!isnan(x)) {
56-
const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000;
57-
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
58-
x = ((XBits & SignMask) == SignMask)
59-
? oneapi::detail::bitsToBfloat16(XBits & ~SignMask)
60-
: x;
61-
}
62-
return x;
55+
std::ignore = x;
56+
throw runtime_error(
57+
"bfloat16 math functions are not currently supported on the host device.",
58+
PI_ERROR_INVALID_DEVICE);
6359
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
6460
}
6561

6662
template <size_t N>
6763
sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
68-
sycl::marray<bfloat16, N> res;
6964
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
65+
sycl::marray<bfloat16, N> res;
66+
7067
for (size_t i = 0; i < N / 2; i++) {
7168
auto partial_res = __clc_fabs(detail::to_uint32_t(x, i * 2));
7269
std::memcpy(&res[i * 2], &partial_res, sizeof(uint32_t));
@@ -77,12 +74,13 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
7774
oneapi::detail::bfloat16ToBits(x[N - 1]);
7875
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fabs(XBits));
7976
}
77+
return res;
8078
#else
81-
for (size_t i = 0; i < N; i++) {
82-
res[i] = fabs(x[i]);
83-
}
79+
std::ignore = x;
80+
throw runtime_error(
81+
"bfloat16 math functions are not currently supported on the host device.",
82+
PI_ERROR_INVALID_DEVICE);
8483
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
85-
return res;
8684
}
8785

8886
template <typename T>

0 commit comments

Comments
 (0)