Skip to content

Commit 378678a

Browse files
authored
[SYCL] Add bfloat16 generic impl for fma (#7863)
1 parent 40d80c0 commit 378678a

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp

Lines changed: 6 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -188,22 +188,16 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) {
188188
oneapi::detail::Bfloat16StorageT ZBits = oneapi::detail::bfloat16ToBits(z);
189189
return oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
190190
#else
191-
std::ignore = x;
192-
std::ignore = y;
193-
std::ignore = z;
194-
throw runtime_error(
195-
"bfloat16 math functions are not currently supported on the host device.",
196-
PI_ERROR_INVALID_DEVICE);
191+
return sycl::ext::oneapi::bfloat16{sycl::fma(float{x}, float{y}, float{z})};
197192
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
198193
}
199194

200195
template <size_t N>
201196
sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
202197
sycl::marray<bfloat16, N> y,
203198
sycl::marray<bfloat16, N> z) {
204-
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
205199
sycl::marray<bfloat16, N> res;
206-
200+
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
207201
for (size_t i = 0; i < N / 2; i++) {
208202
auto partial_res =
209203
__clc_fma(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2),
@@ -220,15 +214,12 @@ sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
220214
oneapi::detail::bfloat16ToBits(z[N - 1]);
221215
res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fma(XBits, YBits, ZBits));
222216
}
223-
return res;
224217
#else
225-
std::ignore = x;
226-
std::ignore = y;
227-
std::ignore = z;
228-
throw runtime_error(
229-
"bfloat16 math functions are not currently supported on the host device.",
230-
PI_ERROR_INVALID_DEVICE);
218+
for (size_t i = 0; i < N; i++) {
219+
res[i] = fma(x[i], y[i], z[i]);
220+
}
231221
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
222+
return res;
232223
}
233224

234225
} // namespace ext::oneapi::experimental

0 commit comments

Comments
 (0)