@@ -30,6 +30,14 @@ uint32_t to_uint32_t(sycl::marray<bfloat16, N> x, size_t start) {
30
30
}
31
31
} // namespace detail
32
32
33
+ // According to bfloat16 format, NAN value's exponent field is 0xFF and
34
+ // significand has non-zero bits.
35
+ template <typename T>
36
+ std::enable_if_t <std::is_same<T, bfloat16>::value, bool > isnan (T x) {
37
+ oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
38
+ return (((XBits & 0x7F80 ) == 0x7F80 ) && (XBits & 0x7F )) ? true : false ;
39
+ }
40
+
33
41
template <typename T>
34
42
std::enable_if_t <std::is_same<T, bfloat16>::value, T> fabs (T x) {
35
43
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
@@ -74,20 +82,31 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
74
82
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
75
83
return oneapi::detail::bitsToBfloat16 (__clc_fmin (XBits, YBits));
76
84
#else
77
- std::ignore = x;
78
- std::ignore = y;
79
- throw runtime_error (
80
- " bfloat16 math functions are not currently supported on the host device." ,
81
- PI_ERROR_INVALID_DEVICE);
85
+ static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0 ;
86
+ oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
87
+ oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
88
+ if (isnan (x) && isnan (y))
89
+ return oneapi::detail::bitsToBfloat16 (CanonicalNan);
90
+
91
+ if (isnan (x))
92
+ return y;
93
+ if (isnan (y))
94
+ return x;
95
+ if (((XBits | YBits) ==
96
+ static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 )) &&
97
+ !(XBits & YBits))
98
+ return oneapi::detail::bitsToBfloat16 (
99
+ static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 ));
100
+
101
+ return (x < y) ? x : y;
82
102
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
83
103
}
84
104
85
105
template <size_t N>
86
106
sycl::marray<bfloat16, N> fmin (sycl::marray<bfloat16, N> x,
87
107
sycl::marray<bfloat16, N> y) {
88
- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
89
108
sycl::marray<bfloat16, N> res;
90
-
109
+ # if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
91
110
for (size_t i = 0 ; i < N / 2 ; i++) {
92
111
auto partial_res = __clc_fmin (detail::to_uint32_t (x, i * 2 ),
93
112
detail::to_uint32_t (y, i * 2 ));
@@ -101,15 +120,12 @@ sycl::marray<bfloat16, N> fmin(sycl::marray<bfloat16, N> x,
101
120
oneapi::detail::bfloat16ToBits (y[N - 1 ]);
102
121
res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fmin (XBits, YBits));
103
122
}
104
-
105
- return res;
106
123
#else
107
- std::ignore = x;
108
- std::ignore = y;
109
- throw runtime_error (
110
- " bfloat16 math functions are not currently supported on the host device." ,
111
- PI_ERROR_INVALID_DEVICE);
124
+ for (size_t i = 0 ; i < N; i++) {
125
+ res[i] = fmin (x[i], y[i]);
126
+ }
112
127
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
128
+ return res;
113
129
}
114
130
115
131
template <typename T>
@@ -119,20 +135,30 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
119
135
oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
120
136
return oneapi::detail::bitsToBfloat16 (__clc_fmax (XBits, YBits));
121
137
#else
122
- std::ignore = x;
123
- std::ignore = y;
124
- throw runtime_error (
125
- " bfloat16 math functions are not currently supported on the host device." ,
126
- PI_ERROR_INVALID_DEVICE);
138
+ static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0 ;
139
+ oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits (x);
140
+ oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits (y);
141
+ if (isnan (x) && isnan (y))
142
+ return oneapi::detail::bitsToBfloat16 (CanonicalNan);
143
+
144
+ if (isnan (x))
145
+ return y;
146
+ if (isnan (y))
147
+ return x;
148
+ if (((XBits | YBits) ==
149
+ static_cast <oneapi::detail::Bfloat16StorageT>(0x8000 )) &&
150
+ !(XBits & YBits))
151
+ return oneapi::detail::bitsToBfloat16 (0 );
152
+
153
+ return (x > y) ? x : y;
127
154
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
128
155
}
129
156
130
157
template <size_t N>
131
158
sycl::marray<bfloat16, N> fmax (sycl::marray<bfloat16, N> x,
132
159
sycl::marray<bfloat16, N> y) {
133
- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
134
160
sycl::marray<bfloat16, N> res;
135
-
161
+ # if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
136
162
for (size_t i = 0 ; i < N / 2 ; i++) {
137
163
auto partial_res = __clc_fmax (detail::to_uint32_t (x, i * 2 ),
138
164
detail::to_uint32_t (y, i * 2 ));
@@ -146,14 +172,12 @@ sycl::marray<bfloat16, N> fmax(sycl::marray<bfloat16, N> x,
146
172
oneapi::detail::bfloat16ToBits (y[N - 1 ]);
147
173
res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fmax (XBits, YBits));
148
174
}
149
- return res;
150
175
#else
151
- std::ignore = x;
152
- std::ignore = y;
153
- throw runtime_error (
154
- " bfloat16 math functions are not currently supported on the host device." ,
155
- PI_ERROR_INVALID_DEVICE);
176
+ for (size_t i = 0 ; i < N; i++) {
177
+ res[i] = fmax (x[i], y[i]);
178
+ }
156
179
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
180
+ return res;
157
181
}
158
182
159
183
template <typename T>
0 commit comments