@@ -188,22 +188,16 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) {
188
188
oneapi::detail::Bfloat16StorageT ZBits = oneapi::detail::bfloat16ToBits (z);
189
189
return oneapi::detail::bitsToBfloat16 (__clc_fma (XBits, YBits, ZBits));
190
190
#else
191
- std::ignore = x;
192
- std::ignore = y;
193
- std::ignore = z;
194
- throw runtime_error (
195
- " bfloat16 math functions are not currently supported on the host device." ,
196
- PI_ERROR_INVALID_DEVICE);
191
+ return sycl::ext::oneapi::bfloat16{sycl::fma (float {x}, float {y}, float {z})};
197
192
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
198
193
}
199
194
200
195
template <size_t N>
201
196
sycl::marray<bfloat16, N> fma (sycl::marray<bfloat16, N> x,
202
197
sycl::marray<bfloat16, N> y,
203
198
sycl::marray<bfloat16, N> z) {
204
- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
205
199
sycl::marray<bfloat16, N> res;
206
-
200
+ # if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
207
201
for (size_t i = 0 ; i < N / 2 ; i++) {
208
202
auto partial_res =
209
203
__clc_fma (detail::to_uint32_t (x, i * 2 ), detail::to_uint32_t (y, i * 2 ),
@@ -220,15 +214,12 @@ sycl::marray<bfloat16, N> fma(sycl::marray<bfloat16, N> x,
220
214
oneapi::detail::bfloat16ToBits (z[N - 1 ]);
221
215
res[N - 1 ] = oneapi::detail::bitsToBfloat16 (__clc_fma (XBits, YBits, ZBits));
222
216
}
223
- return res;
224
217
#else
225
- std::ignore = x;
226
- std::ignore = y;
227
- std::ignore = z;
228
- throw runtime_error (
229
- " bfloat16 math functions are not currently supported on the host device." ,
230
- PI_ERROR_INVALID_DEVICE);
218
+ for (size_t i = 0 ; i < N; i++) {
219
+ res[i] = fma (x[i], y[i], z[i]);
220
+ }
231
221
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
222
+ return res;
232
223
}
233
224
234
225
} // namespace ext::oneapi::experimental
0 commit comments