Skip to content

Commit 0032a84

Browse files
[SYCL] Optimize vec<bfloat> math builtins (#14106)
Followup and blocked by: #14105 Currently, `vec<bfloat>` math builtins do element-by-element operations. This PR optimize `vec<bfloat>` math builtins by: (1) Converting `vec<bfloat>` to `vec<float>`. (2) Do the operation on `vec<float>` (which uses Spirv built-ins underneath for optimized vector operations). (3) Convert back the return value to `vec<bfloat>`. Look at the beautiful diff in `check_device_code/vector/vector_bf16_builtins.cpp` to visualize the device code generated before and after this optimization.
1 parent 157310a commit 0032a84

File tree

4 files changed

+297
-324
lines changed

4 files changed

+297
-324
lines changed

sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,26 @@ template <size_t N> sycl::marray<bool, N> isnan(sycl::marray<bfloat16, N> x) {
6262
template <typename T, int N = num_elements_v<T>>
6363
std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<int16_t, N>>
6464
isnan(T x) {
65+
66+
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
67+
// Convert BFloat16 vector to float vec and call isnan().
68+
sycl::vec<float, N> FVec =
69+
x.template convert<float, sycl::rounding_mode::automatic>();
70+
auto Res = isnan(FVec);
71+
72+
// For vec<float>, the return type of isnan is vec<int32_t> so,
73+
// an explicit conversion is required to vec<int16_t>.
74+
return Res.template convert<int16_t>();
75+
#else
76+
6577
sycl::vec<int16_t, N> res;
6678
for (size_t i = 0; i < N; i++) {
6779
// The result of isnan is 0 or 1 but SPEC requires
6880
// isnan() of vec/swizzle to return -1 or 0.
6981
res[i] = isnan(x[i]) ? -1 : 0;
7082
}
7183
return res;
84+
#endif
7285
}
7386

7487
/******************* fabs ********************/
@@ -120,11 +133,19 @@ sycl::marray<bfloat16, N> fabs(sycl::marray<bfloat16, N> x) {
120133
template <typename T, int N = num_elements_v<T>>
121134
std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>>
122135
fabs(T x) {
136+
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
137+
// Convert BFloat16 vector to float vec.
138+
sycl::vec<float, N> FVec =
139+
x.template convert<float, sycl::rounding_mode::automatic>();
140+
auto Res = fabs(FVec);
141+
return Res.template convert<bfloat16>();
142+
#else
123143
sycl::vec<bfloat16, N> res;
124144
for (size_t i = 0; i < N; i++) {
125145
res[i] = fabs(x[i]);
126146
}
127147
return res;
148+
#endif
128149
}
129150

130151
/******************* fmin ********************/
@@ -193,11 +214,21 @@ std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
193214
N1 == N2,
194215
sycl::vec<bfloat16, N1>>
195216
fmin(T1 x, T2 y) {
217+
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
218+
// Convert BFloat16 vectors to float vecs.
219+
sycl::vec<float, N1> FVecX =
220+
x.template convert<float, sycl::rounding_mode::automatic>();
221+
sycl::vec<float, N1> FVecY =
222+
y.template convert<float, sycl::rounding_mode::automatic>();
223+
auto Res = fmin(FVecX, FVecY);
224+
return Res.template convert<bfloat16>();
225+
#else
196226
sycl::vec<bfloat16, N1> res;
197227
for (size_t i = 0; i < N1; i++) {
198228
res[i] = fmin(x[i], y[i]);
199229
}
200230
return res;
231+
#endif
201232
}
202233

203234
/******************* fmax ********************/
@@ -265,11 +296,21 @@ std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
265296
N1 == N2,
266297
sycl::vec<bfloat16, N1>>
267298
fmax(T1 x, T2 y) {
299+
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
300+
// Convert BFloat16 vectors to float vecs.
301+
sycl::vec<float, N1> FVecX =
302+
x.template convert<float, sycl::rounding_mode::automatic>();
303+
sycl::vec<float, N1> FVecY =
304+
y.template convert<float, sycl::rounding_mode::automatic>();
305+
auto Res = fmax(FVecX, FVecY);
306+
return Res.template convert<bfloat16>();
307+
#else
268308
sycl::vec<bfloat16, N1> res;
269309
for (size_t i = 0; i < N1; i++) {
270310
res[i] = fmax(x[i], y[i]);
271311
}
272312
return res;
313+
#endif
273314
}
274315

275316
/******************* fma *********************/
@@ -327,11 +368,24 @@ std::enable_if_t<is_vec_or_swizzle_bf16_v<T1> && is_vec_or_swizzle_bf16_v<T2> &&
327368
is_vec_or_swizzle_bf16_v<T3> && N1 == N2 && N2 == N3,
328369
sycl::vec<bfloat16, N1>>
329370
fma(T1 x, T2 y, T3 z) {
371+
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
372+
// Convert BFloat16 vectors to float vecs.
373+
sycl::vec<float, N1> FVecX =
374+
x.template convert<float, sycl::rounding_mode::automatic>();
375+
sycl::vec<float, N1> FVecY =
376+
y.template convert<float, sycl::rounding_mode::automatic>();
377+
sycl::vec<float, N1> FVecZ =
378+
z.template convert<float, sycl::rounding_mode::automatic>();
379+
380+
auto Res = fma(FVecX, FVecY, FVecZ);
381+
return Res.template convert<bfloat16>();
382+
#else
330383
sycl::vec<bfloat16, N1> res;
331384
for (size_t i = 0; i < N1; i++) {
332385
res[i] = fma(x[i], y[i], z[i]);
333386
}
334387
return res;
388+
#endif
335389
}
336390

337391
/******************* unary math operations ********************/
@@ -352,6 +406,18 @@ fma(T1 x, T2 y, T3 z) {
352406
return res; \
353407
}
354408

409+
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
410+
#define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op) \
411+
/* Overload for BF16 vec and swizzles. */ \
412+
template <typename T, int N = num_elements_v<T>> \
413+
std::enable_if_t<is_vec_or_swizzle_bf16_v<T>, sycl::vec<bfloat16, N>> op( \
414+
T x) { \
415+
sycl::vec<float, N> FVec = \
416+
x.template convert<float, sycl::rounding_mode::automatic>(); \
417+
auto Res = op(FVec); \
418+
return Res.template convert<bfloat16>(); \
419+
}
420+
#else
355421
#define BFLOAT16_MATH_FP32_WRAPPERS_VEC(op) \
356422
/* Overload for BF16 vec and swizzles. */ \
357423
template <typename T, int N = num_elements_v<T>> \
@@ -363,6 +429,7 @@ fma(T1 x, T2 y, T3 z) {
363429
} \
364430
return res; \
365431
}
432+
#endif
366433

367434
BFLOAT16_MATH_FP32_WRAPPERS(ceil)
368435
BFLOAT16_MATH_FP32_WRAPPERS_MARRAY(ceil)

sycl/include/sycl/vector_preview.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1363,7 +1363,11 @@ class SwizzleOp {
13631363
template <typename convertT, rounding_mode roundingMode>
13641364
vec<convertT, sizeof...(Indexes)> convert() const {
13651365
// First materialize the swizzle to vec_t and then apply convert() to it.
1366-
vec_t Tmp = *this;
1366+
vec_t Tmp;
1367+
std::array<int, getNumElements()> Idxs{Indexes...};
1368+
for (size_t I = 0; I < Idxs.size(); ++I) {
1369+
Tmp[I] = (*m_Vector)[Idxs[I]];
1370+
}
13671371
return Tmp.template convert<convertT, roundingMode>();
13681372
}
13691373

sycl/test-e2e/BFloat16/bfloat16_vec_builtins.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,11 @@ bool check(bool a, bool b) { return (a != b); }
3131
for (int i = 0; i < SZ; i++) { \
3232
arg[i] = INPVAL; \
3333
} \
34-
/* Perform the operation. */ \
35-
vec<RETTY, SZ> \
36-
res = sycl::ext::oneapi::experimental::NAME(arg); \
34+
/* Perform the operation. */ \
35+
vec<RETTY, SZ> res = sycl::ext::oneapi::experimental::NAME(arg); \
3736
vec<RETTY, 2> res2 = \
3837
sycl::ext::oneapi::experimental::NAME(arg.template swizzle<0, 0>()); \
39-
/* Check the result. */ \
38+
/* Check the result. */ \
4039
if (res2[0] != res[0] || res2[1] != res[0]) { \
4140
ERR[0] += 1; \
4241
} \
@@ -56,9 +55,8 @@ bool check(bool a, bool b) { return (a != b); }
5655
arg[i] = INPVAL; \
5756
arg2[i] = inpVal2; \
5857
} \
59-
/* Perform the operation. */ \
60-
vec<RETTY, SZ> \
61-
res = sycl::ext::oneapi::experimental::NAME(arg, arg2); \
58+
/* Perform the operation. */ \
59+
vec<RETTY, SZ> res = sycl::ext::oneapi::experimental::NAME(arg, arg2); \
6260
/* Swizzle and vec different combination. */ \
6361
vec<RETTY, 2> res2 = sycl::ext::oneapi::experimental::NAME( \
6462
arg.template swizzle<0, 0>(), arg2.template swizzle<0, 0>()); \

0 commit comments

Comments
 (0)