Skip to content

Commit eb1ed10

Browse files
jinge90 and bader
authored
[SYCL] Add bfloat16 generic implementation for fmax, fmin (#7732)
Signed-off-by: jinge90 <[email protected]> Co-authored-by: Alexey Bader <[email protected]>
1 parent 25d05f3 commit eb1ed10

File tree

2 files changed

+52
-28
lines changed

2 files changed

+52
-28
lines changed

sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
3030
}
3131
} // namespace detail
3232

33+
// According to bfloat16 format, NAN value's exponent field is 0xFF and
34+
// significand has non-zero bits.
35+
template <typename T>
36+
std::enable_if_t<std::is_same<T, bfloat16>::value, bool> isnan(T x) {
37+
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
38+
return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false;
39+
}
40+
3341
template <typename T>
3442
std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
3543
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
@@ -74,20 +82,31 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
7482
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
7583
return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
7684
#else
77-
std::ignore = x;
78-
std::ignore = y;
79-
throw runtime_error(
80-
"bfloat16 math functions are not currently supported on the host device.",
81-
PI_ERROR_INVALID_DEVICE);
85+
static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
86+
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
87+
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
88+
if (isnan(x) && isnan(y))
89+
return oneapi::detail::bitsToBfloat16(CanonicalNan);
90+
91+
if (isnan(x))
92+
return y;
93+
if (isnan(y))
94+
return x;
95+
if (((XBits | YBits) ==
96+
static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
97+
!(XBits & YBits))
98+
return oneapi::detail::bitsToBfloat16(
99+
static_cast<oneapi::detail::Bfloat16StorageT>(0x8000));
100+
101+
return (x < y) ? x : y;
82102
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
83103
}
84104

85105
template <size_t N>
86106
sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
87107
sycl::marray<bfloat16, N> y) {
88-
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
89108
sycl::marray<bfloat16, N> res;
90-
109+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
91110
for (size_t i = 0; i < N / 2; i++) {
92111
auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2),
93112
detail::to_uint32_t(y, i * 2));
@@ -101,15 +120,12 @@ sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
101120
oneapi::detail::bfloat16ToBits(y[N - 1]);
102121
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits));
103122
}
104-
105-
return res;
106123
#else
107-
std::ignore = x;
108-
std::ignore = y;
109-
throw runtime_error(
110-
"bfloat16 math functions are not currently supported on the host device.",
111-
PI_ERROR_INVALID_DEVICE);
124+
for (size_t i = 0; i < N; i++) {
125+
res[i] = fmin(x[i], y[i]);
126+
}
112127
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
128+
return res;
113129
}
114130

115131
template <typename T>
@@ -119,20 +135,30 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
119135
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
120136
return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
121137
#else
122-
std::ignore = x;
123-
std::ignore = y;
124-
throw runtime_error(
125-
"bfloat16 math functions are not currently supported on the host device.",
126-
PI_ERROR_INVALID_DEVICE);
138+
static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0;
139+
oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x);
140+
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y);
141+
if (isnan(x) && isnan(y))
142+
return oneapi::detail::bitsToBfloat16(CanonicalNan);
143+
144+
if (isnan(x))
145+
return y;
146+
if (isnan(y))
147+
return x;
148+
if (((XBits | YBits) ==
149+
static_cast<oneapi::detail::Bfloat16StorageT>(0x8000)) &&
150+
!(XBits & YBits))
151+
return oneapi::detail::bitsToBfloat16(0);
152+
153+
return (x > y) ? x : y;
127154
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
128155
}
129156

130157
template <size_t N>
131158
sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
132159
sycl::marray<bfloat16, N> y) {
133-
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
134160
sycl::marray<bfloat16, N> res;
135-
161+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
136162
for (size_t i = 0; i < N / 2; i++) {
137163
auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2),
138164
detail::to_uint32_t(y, i * 2));
@@ -146,14 +172,12 @@ sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
146172
oneapi::detail::bfloat16ToBits(y[N - 1]);
147173
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits));
148174
}
149-
return res;
150175
#else
151-
std::ignore = x;
152-
std::ignore = y;
153-
throw runtime_error(
154-
"bfloat16 math functions are not currently supported on the host device.",
155-
PI_ERROR_INVALID_DEVICE);
176+
for (size_t i = 0; i < N; i++) {
177+
res[i] = fmax(x[i], y[i]);
178+
}
156179
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
180+
return res;
157181
}
158182

159183
template <typename T>

sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,7 @@ SYCL_EXTERNAL complex<_Tp> acos(const complex<_Tp> &__x) {
12021202
}
12031203
if (sycl::isinf(__x.imag()))
12041204
return complex<_Tp>(__pi / _Tp(2), -__x.imag());
1205-
if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag())))
1205+
if (__x.real() == 0 && (__x.imag() == 0 || sycl::isnan(__x.imag())))
12061206
return complex<_Tp>(__pi / _Tp(2), -__x.imag());
12071207
complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1)));
12081208
if (sycl::signbit(__x.imag()))

0 commit comments

Comments
 (0)