@@ -52,18 +52,21 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
52
52
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
53
53
return oneapi::detail::bitsToBfloat16 (__clc_fabs (XBits));
54
54
#else
55
- std::ignore = x;
56
- throw runtime_error (
57
- " bfloat16 math functions are not currently supported on the host device." ,
58
- PI_ERROR_INVALID_DEVICE);
55
+ if (!isnan (x)) {
56
+ const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000 ;
57
+ oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
58
+ x = ((XBits & SignMask) == SignMask)
59
+ ? oneapi::detail::bitsToBfloat16 (XBits & ~SignMask)
60
+ : x;
61
+ }
62
+ return x;
59
63
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
60
64
}
61
65
62
66
template <size_t N>
63
67
sycl::marray<bfloat16, N> fabs (sycl::marray<bfloat16, N> x) {
64
- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
65
68
sycl::marray<bfloat16, N> res;
66
-
69
+ # if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
67
70
for (size_t i = 0 ; i < N / 2 ; i++) {
68
71
auto partial_res = __clc_fabs (detail::to_uint32_t (x, i * 2 ));
69
72
std::memcpy (&res[i * 2 ], &partial_res, sizeof (uint32_t ));
@@ -74,13 +77,12 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
74
77
oneapi::detail::bfloat16ToBits (x[N - 1 ]);
75
78
res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fabs (XBits));
76
79
}
77
- return res;
78
80
#else
79
- std::ignore = x;
80
- throw runtime_error (
81
- " bfloat16 math functions are not currently supported on the host device." ,
82
- PI_ERROR_INVALID_DEVICE);
81
+ for (size_t i = 0 ; i < N; i++) {
82
+ res[i] = fabs (x[i]);
83
+ }
83
84
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
85
+ return res;
84
86
}
85
87
86
88
template <typename T>
0 commit comments