@@ -62,13 +62,26 @@ template <size_t N> sycl::marray<bool, N> isnan(sycl::marray<bfloat16, N> x) {
62
62
template <typename T, int N = num_elements_v<T>>
63
63
std::enable_if_t <is_vec_or_swizzle_bf16_v<T>, sycl::vec<int16_t , N>>
64
64
isnan (T x) {
65
+
66
+ #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
67
+ // Convert BFloat16 vector to float vec and call isnan().
68
+ sycl::vec<float , N> FVec =
69
+ x.template convert <float , sycl::rounding_mode::automatic>();
70
+ auto Res = isnan (FVec);
71
+
72
+ // For vec<float>, the return type of isnan is vec<int32_t> so,
73
+ // an explicit conversion is required to vec<int16_t>.
74
+ return Res.template convert <int16_t >();
75
+ #else
76
+
65
77
sycl::vec<int16_t , N> res;
66
78
for (size_t i = 0 ; i < N; i++) {
67
79
// The result of isnan is 0 or 1 but SPEC requires
68
80
// isnan() of vec/swizzle to return -1 or 0.
69
81
res[i] = isnan (x[i]) ? -1 : 0 ;
70
82
}
71
83
return res;
84
+ #endif
72
85
}
73
86
74
87
/* ****************** fabs ********************/
@@ -120,11 +133,19 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
120
133
template <typename T, int N = num_elements_v<T>>
121
134
std::enable_if_t <is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>>
122
135
fabs (T x) {
136
+ #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
137
+ // Convert BFloat16 vector to float vec.
138
+ sycl::vec<float , N> FVec =
139
+ x.template convert <float , sycl::rounding_mode::automatic>();
140
+ auto Res = fabs (FVec);
141
+ return Res.template convert <bfloat16>();
142
+ #else
123
143
sycl::vec<bfloat16, N> res;
124
144
for (size_t i = 0 ; i < N; i++) {
125
145
res[i] = fabs (x[i]);
126
146
}
127
147
return res;
148
+ #endif
128
149
}
129
150
130
151
/* ****************** fmin ********************/
@@ -193,11 +214,21 @@ std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
193
214
N1 == N2,
194
215
sycl::vec<bfloat16, N1>>
195
216
fmin (T1 x, T2 y) {
217
+ #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
218
+ // Convert BFloat16 vectors to float vecs.
219
+ sycl::vec<float , N1> FVecX =
220
+ x.template convert <float , sycl::rounding_mode::automatic>();
221
+ sycl::vec<float , N1> FVecY =
222
+ y.template convert <float , sycl::rounding_mode::automatic>();
223
+ auto Res = fmin (FVecX, FVecY);
224
+ return Res.template convert <bfloat16>();
225
+ #else
196
226
sycl::vec<bfloat16, N1> res;
197
227
for (size_t i = 0 ; i < N1; i++) {
198
228
res[i] = fmin (x[i], y[i]);
199
229
}
200
230
return res;
231
+ #endif
201
232
}
202
233
203
234
/* ****************** fmax ********************/
@@ -265,11 +296,21 @@ std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
265
296
N1 == N2,
266
297
sycl::vec<bfloat16, N1>>
267
298
fmax (T1 x, T2 y) {
299
+ #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
300
+ // Convert BFloat16 vectors to float vecs.
301
+ sycl::vec<float , N1> FVecX =
302
+ x.template convert <float , sycl::rounding_mode::automatic>();
303
+ sycl::vec<float , N1> FVecY =
304
+ y.template convert <float , sycl::rounding_mode::automatic>();
305
+ auto Res = fmax (FVecX, FVecY);
306
+ return Res.template convert <bfloat16>();
307
+ #else
268
308
sycl::vec<bfloat16, N1> res;
269
309
for (size_t i = 0 ; i < N1; i++) {
270
310
res[i] = fmax (x[i], y[i]);
271
311
}
272
312
return res;
313
+ #endif
273
314
}
274
315
275
316
/* ****************** fma *********************/
@@ -327,11 +368,24 @@ std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
327
368
is_vec_or_swizzle_bf16_v<T3> && N1 == N2 && N2 == N3,
328
369
sycl::vec<bfloat16, N1>>
329
370
fma (T1 x, T2 y, T3 z) {
371
+ #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
372
+ // Convert BFloat16 vectors to float vecs.
373
+ sycl::vec<float , N1> FVecX =
374
+ x.template convert <float , sycl::rounding_mode::automatic>();
375
+ sycl::vec<float , N1> FVecY =
376
+ y.template convert <float , sycl::rounding_mode::automatic>();
377
+ sycl::vec<float , N1> FVecZ =
378
+ z.template convert <float , sycl::rounding_mode::automatic>();
379
+
380
+ auto Res = fma (FVecX, FVecY, FVecZ);
381
+ return Res.template convert <bfloat16>();
382
+ #else
330
383
sycl::vec<bfloat16, N1> res;
331
384
for (size_t i = 0 ; i < N1; i++) {
332
385
res[i] = fma (x[i], y[i], z[i]);
333
386
}
334
387
return res;
388
+ #endif
335
389
}
336
390
337
391
/* ****************** unary math operations ********************/
@@ -352,6 +406,18 @@ fma(T1 x, T2 y, T3 z) {
352
406
return res; \
353
407
}
354
408
409
+ #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
410
+ #define BFLOAT16_MATH_FP32_WRAPPERS_VEC (op ) \
411
+ /* Overload for BF16 vec and swizzles. */ \
412
+ template <typename T, int N = num_elements_v<T>> \
413
+ std::enable_if_t <is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>> op ( \
414
+ T x) { \
415
+ sycl::vec<float , N> FVec = \
416
+ x.template convert <float , sycl::rounding_mode::automatic>(); \
417
+ auto Res = op (FVec); \
418
+ return Res.template convert <bfloat16>(); \
419
+ }
420
+ #else
355
421
#define BFLOAT16_MATH_FP32_WRAPPERS_VEC (op ) \
356
422
/* Overload for BF16 vec and swizzles. */ \
357
423
template <typename T, int N = num_elements_v<T>> \
@@ -363,6 +429,7 @@ fma(T1 x, T2 y, T3 z) {
363
429
} \
364
430
return res; \
365
431
}
432
+ #endif
366
433
367
434
BFLOAT16_MATH_FP32_WRAPPERS (ceil)
368
435
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY (ceil)
0 commit comments