@@ -52,21 +52,18 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
52
52
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
53
53
return oneapi::detail::bitsToBfloat16 (__clc_fabs (XBits));
54
54
#else
55
- if (!isnan (x)) {
56
- const static oneapi::detail::Bfloat16StorageT SignMask = 0x8000 ;
57
- oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
58
- x = ((XBits & SignMask) == SignMask)
59
- ? oneapi::detail::bitsToBfloat16 (XBits & ~SignMask)
60
- : x;
61
- }
62
- return x;
55
+ std::ignore = x;
56
+ throw runtime_error (
57
+ " bfloat16 math functions are not currently supported on the host device." ,
58
+ PI_ERROR_INVALID_DEVICE);
63
59
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
64
60
}
65
61
66
62
template <size_t N>
67
63
sycl::marray<bfloat16, N> fabs (sycl::marray<bfloat16, N> x) {
68
- sycl::marray<bfloat16, N> res;
69
64
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
65
+ sycl::marray<bfloat16, N> res;
66
+
70
67
for (size_t i = 0 ; i < N / 2 ; i++) {
71
68
auto partial_res = __clc_fabs (detail::to_uint32_t (x, i * 2 ));
72
69
std::memcpy (&res[i * 2 ], &partial_res, sizeof (uint32_t ));
@@ -77,12 +74,13 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
77
74
oneapi::detail::bfloat16ToBits (x[N - 1 ]);
78
75
res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fabs (XBits));
79
76
}
77
+ return res;
80
78
#else
81
- for (size_t i = 0 ; i < N; i++) {
82
- res[i] = fabs (x[i]);
83
- }
79
+ std::ignore = x;
80
+ throw runtime_error (
81
+ " bfloat16 math functions are not currently supported on the host device." ,
82
+ PI_ERROR_INVALID_DEVICE);
84
83
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
85
- return res;
86
84
}
87
85
88
86
template <typename T>
0 commit comments