Skip to content

Commit 13a7b3a

Browse files
[SYCL] [libdevice] Add vector overloads of ConvertBFloat16ToFINTEL and ConvertFToBFloat16INTEL (#14085)
This PR adds vector overloads of `ConvertBFloat16ToFINTEL` and `ConvertFToBFloat16INTEL` to libdevice (SPEC: https://spec.oneapi.io/level-zero/latest/core/SPIRV.html#bfloat16-conversions) and a wrapper around it (`BF16VecToFloatVec` and `FloatVecToBF16Vec`) in `ext::oneapi::detail`. These overloads are intended to optimize BFloat16 `marray`, `vec` operations, for which we currently do element-by-element `bfloat16 -> float -> bfloat16` conversions.
1 parent fe8c284 commit 13a7b3a

File tree

7 files changed

+383
-2
lines changed

7 files changed

+383
-2
lines changed

libdevice/bfloat16_wrapper.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#if defined(__SPIR__) || defined(__SPIRV__)
1212

1313
#include <CL/__spirv/spirv_ops.hpp>
14+
#include <CL/__spirv/spirv_types.hpp>
15+
#include <cassert>
1416
#include <cstdint>
1517

1618
DEVICE_EXTERN_C_INLINE
@@ -23,4 +25,42 @@ float __devicelib_ConvertBF16ToFINTEL(const uint16_t &x) {
2325
return __spirv_ConvertBF16ToFINTEL(x);
2426
}
2527

28+
// For vector of size 1.
29+
DEVICE_EXTERN_C_INLINE
30+
void __devicelib_ConvertFToBF16INTELVec1(const float *src, uint16_t *dst) {
31+
dst[0] = __spirv_ConvertFToBF16INTEL(src[0]);
32+
}
33+
DEVICE_EXTERN_C_INLINE
34+
void __devicelib_ConvertBF16ToFINTELVec1(const uint16_t *src, float *dst) {
35+
dst[0] = __spirv_ConvertBF16ToFINTEL(src[0]);
36+
}
37+
38+
// Generate the conversion functions for vector of size 2, 3, 4, 8, 16.
39+
#define GenerateConvertFunctionForVec(size) \
40+
DEVICE_EXTERN_C_INLINE \
41+
void __devicelib_ConvertFToBF16INTELVec##size(const float *src, \
42+
uint16_t *dst) { \
43+
__ocl_vec_t<float, size> x = \
44+
*__builtin_bit_cast(const __ocl_vec_t<float, size> *, src); \
45+
__ocl_vec_t<uint16_t, size> y = __spirv_ConvertFToBF16INTEL(x); \
46+
*__builtin_bit_cast(__ocl_vec_t<uint16_t, size> *, dst) = y; \
47+
} \
48+
DEVICE_EXTERN_C_INLINE \
49+
void __devicelib_ConvertBF16ToFINTELVec##size(const uint16_t *src, \
50+
float *dst) { \
51+
__ocl_vec_t<uint16_t, size> x = \
52+
*__builtin_bit_cast(const __ocl_vec_t<uint16_t, size> *, src); \
53+
__ocl_vec_t<float, size> y = __spirv_ConvertBF16ToFINTEL(x); \
54+
*__builtin_bit_cast(__ocl_vec_t<float, size> *, dst) = y; \
55+
}
56+
57+
// clang-format off
58+
GenerateConvertFunctionForVec(2)
59+
GenerateConvertFunctionForVec(3)
60+
GenerateConvertFunctionForVec(4)
61+
GenerateConvertFunctionForVec(8)
62+
GenerateConvertFunctionForVec(16)
63+
// clang-format on
64+
#undef GenerateConvertFunctionForVec
65+
2666
#endif // __SPIR__ || __SPIRV__

libdevice/fallback-bfloat16.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,31 @@ __devicelib_ConvertBF16ToFINTEL(const uint16_t &a) {
4343
return floatValue;
4444
}
4545

46+
// Generate the conversion functions for vector of size 1, 2, 3, 4, 8, 16.
47+
#define GenerateConvertFunctionForVec(size) \
48+
DEVICE_EXTERN_C_INLINE \
49+
void __devicelib_ConvertFToBF16INTELVec##size(const float *src, \
50+
uint16_t *dst) { \
51+
for (int i = 0; i < size; ++i) { \
52+
dst[i] = __devicelib_ConvertFToBF16INTEL(src[i]); \
53+
} \
54+
} \
55+
DEVICE_EXTERN_C_INLINE \
56+
void __devicelib_ConvertBF16ToFINTELVec##size(const uint16_t *src, \
57+
float *dst) { \
58+
for (int i = 0; i < size; ++i) { \
59+
dst[i] = __devicelib_ConvertBF16ToFINTEL(src[i]); \
60+
} \
61+
}
62+
63+
// clang-format off
64+
GenerateConvertFunctionForVec(1)
65+
GenerateConvertFunctionForVec(2)
66+
GenerateConvertFunctionForVec(3)
67+
GenerateConvertFunctionForVec(4)
68+
GenerateConvertFunctionForVec(8)
69+
GenerateConvertFunctionForVec(16)
70+
// clang-format on
71+
#undef GenerateConvertFunctionForVec
72+
4673
#endif // __SPIR__ || __SPIRV__

llvm/tools/sycl-post-link/SYCLDeviceLibReqMask.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,30 @@ SYCLDeviceLibFuncMap SDLMap = {
668668
DeviceLibExt::cl_intel_devicelib_bfloat16},
669669
{"__devicelib_ConvertBF16ToFINTEL",
670670
DeviceLibExt::cl_intel_devicelib_bfloat16},
671+
{"__devicelib_ConvertFToBF16INTELVec1",
672+
DeviceLibExt::cl_intel_devicelib_bfloat16},
673+
{"__devicelib_ConvertBF16ToFINTELVec1",
674+
DeviceLibExt::cl_intel_devicelib_bfloat16},
675+
{"__devicelib_ConvertFToBF16INTELVec2",
676+
DeviceLibExt::cl_intel_devicelib_bfloat16},
677+
{"__devicelib_ConvertBF16ToFINTELVec2",
678+
DeviceLibExt::cl_intel_devicelib_bfloat16},
679+
{"__devicelib_ConvertFToBF16INTELVec3",
680+
DeviceLibExt::cl_intel_devicelib_bfloat16},
681+
{"__devicelib_ConvertBF16ToFINTELVec3",
682+
DeviceLibExt::cl_intel_devicelib_bfloat16},
683+
{"__devicelib_ConvertFToBF16INTELVec4",
684+
DeviceLibExt::cl_intel_devicelib_bfloat16},
685+
{"__devicelib_ConvertBF16ToFINTELVec4",
686+
DeviceLibExt::cl_intel_devicelib_bfloat16},
687+
{"__devicelib_ConvertFToBF16INTELVec8",
688+
DeviceLibExt::cl_intel_devicelib_bfloat16},
689+
{"__devicelib_ConvertBF16ToFINTELVec8",
690+
DeviceLibExt::cl_intel_devicelib_bfloat16},
691+
{"__devicelib_ConvertFToBF16INTELVec16",
692+
DeviceLibExt::cl_intel_devicelib_bfloat16},
693+
{"__devicelib_ConvertBF16ToFINTELVec16",
694+
DeviceLibExt::cl_intel_devicelib_bfloat16},
671695
};
672696

673697
// Each fallback device library corresponds to one bit in "require mask" which

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,10 +1019,16 @@ extern __DPCPP_SYCL_EXTERNAL void
10191019
__spirv_ocl_prefetch(const __attribute__((opencl_global)) char *Ptr,
10201020
size_t NumBytes) noexcept;
10211021

1022-
extern __DPCPP_SYCL_EXTERNAL uint16_t
1023-
__spirv_ConvertFToBF16INTEL(float) noexcept;
10241022
extern __DPCPP_SYCL_EXTERNAL float
10251023
__spirv_ConvertBF16ToFINTEL(uint16_t) noexcept;
1024+
extern __DPCPP_SYCL_EXTERNAL uint16_t
1025+
__spirv_ConvertFToBF16INTEL(float) noexcept;
1026+
template <int N>
1027+
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<float, N>
1028+
__spirv_ConvertBF16ToFINTEL(__ocl_vec_t<uint16_t, N>) noexcept;
1029+
template <int N>
1030+
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint16_t, N>
1031+
__spirv_ConvertFToBF16INTEL(__ocl_vec_t<float, N>) noexcept;
10261032

10271033
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
10281034
__SYCL_EXPORT __ocl_vec_t<uint32_t, 4>

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,30 @@ extern "C" __DPCPP_SYCL_EXTERNAL uint16_t
1818
__devicelib_ConvertFToBF16INTEL(const float &) noexcept;
1919
extern "C" __DPCPP_SYCL_EXTERNAL float
2020
__devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept;
21+
extern "C" __DPCPP_SYCL_EXTERNAL void
22+
__devicelib_ConvertFToBF16INTELVec1(const float *, uint16_t *) noexcept;
23+
extern "C" __DPCPP_SYCL_EXTERNAL void
24+
__devicelib_ConvertBF16ToFINTELVec1(const uint16_t *, float *) noexcept;
25+
extern "C" __DPCPP_SYCL_EXTERNAL void
26+
__devicelib_ConvertFToBF16INTELVec2(const float *, uint16_t *) noexcept;
27+
extern "C" __DPCPP_SYCL_EXTERNAL void
28+
__devicelib_ConvertBF16ToFINTELVec2(const uint16_t *, float *) noexcept;
29+
extern "C" __DPCPP_SYCL_EXTERNAL void
30+
__devicelib_ConvertFToBF16INTELVec3(const float *, uint16_t *) noexcept;
31+
extern "C" __DPCPP_SYCL_EXTERNAL void
32+
__devicelib_ConvertBF16ToFINTELVec3(const uint16_t *, float *) noexcept;
33+
extern "C" __DPCPP_SYCL_EXTERNAL void
34+
__devicelib_ConvertFToBF16INTELVec4(const float *, uint16_t *) noexcept;
35+
extern "C" __DPCPP_SYCL_EXTERNAL void
36+
__devicelib_ConvertBF16ToFINTELVec4(const uint16_t *, float *) noexcept;
37+
extern "C" __DPCPP_SYCL_EXTERNAL void
38+
__devicelib_ConvertFToBF16INTELVec8(const float *, uint16_t *) noexcept;
39+
extern "C" __DPCPP_SYCL_EXTERNAL void
40+
__devicelib_ConvertBF16ToFINTELVec8(const uint16_t *, float *) noexcept;
41+
extern "C" __DPCPP_SYCL_EXTERNAL void
42+
__devicelib_ConvertFToBF16INTELVec16(const float *, uint16_t *) noexcept;
43+
extern "C" __DPCPP_SYCL_EXTERNAL void
44+
__devicelib_ConvertBF16ToFINTELVec16(const uint16_t *, float *) noexcept;
2145

2246
namespace sycl {
2347
inline namespace _V1 {
@@ -30,6 +54,52 @@ using Bfloat16StorageT = uint16_t;
3054
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value);
3155
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value);
3256

57+
template <int N> void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) {
58+
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
59+
const uint16_t *src_i16 = sycl::bit_cast<const uint16_t *>(src);
60+
if constexpr (N == 1)
61+
__devicelib_ConvertBF16ToFINTELVec1(src_i16, dst);
62+
else if constexpr (N == 2)
63+
__devicelib_ConvertBF16ToFINTELVec2(src_i16, dst);
64+
else if constexpr (N == 3)
65+
__devicelib_ConvertBF16ToFINTELVec3(src_i16, dst);
66+
else if constexpr (N == 4)
67+
__devicelib_ConvertBF16ToFINTELVec4(src_i16, dst);
68+
else if constexpr (N == 8)
69+
__devicelib_ConvertBF16ToFINTELVec8(src_i16, dst);
70+
else if constexpr (N == 16)
71+
__devicelib_ConvertBF16ToFINTELVec16(src_i16, dst);
72+
#else
73+
for (int i = 0; i < N; ++i) {
74+
dst[i] = (float)src[i];
75+
}
76+
#endif
77+
}
78+
79+
template <int N> void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) {
80+
#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__))
81+
uint16_t *dst_i16 = sycl::bit_cast<uint16_t *>(dst);
82+
if constexpr (N == 1)
83+
__devicelib_ConvertFToBF16INTELVec1(src, dst_i16);
84+
else if constexpr (N == 2)
85+
__devicelib_ConvertFToBF16INTELVec2(src, dst_i16);
86+
else if constexpr (N == 3)
87+
__devicelib_ConvertFToBF16INTELVec3(src, dst_i16);
88+
else if constexpr (N == 4)
89+
__devicelib_ConvertFToBF16INTELVec4(src, dst_i16);
90+
else if constexpr (N == 8)
91+
__devicelib_ConvertFToBF16INTELVec8(src, dst_i16);
92+
else if constexpr (N == 16)
93+
__devicelib_ConvertFToBF16INTELVec16(src, dst_i16);
94+
#else
95+
for (int i = 0; i < N; ++i) {
96+
// No need to cast as bfloat16 has a assignment op overload that takes
97+
// a float.
98+
dst[i] = src[i];
99+
}
100+
#endif
101+
}
102+
33103
// sycl::vec support
34104
namespace bf16 {
35105
#ifdef __SYCL_DEVICE_ONLY__

sycl/test-e2e/BFloat16/bfloat16_conversions.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
#include <iostream>
2020
#include <sycl/detail/core.hpp>
2121

22+
#include <sycl/ext/oneapi/bfloat16.hpp>
23+
2224
using namespace sycl;
25+
using bfloat16 = sycl::ext::oneapi::bfloat16;
2326

2427
template <typename T> T calculate(T a, T b) {
2528
sycl::ext::oneapi::bfloat16 x = -a;
@@ -55,6 +58,82 @@ template <typename T> int test_host() {
5558
return 1;
5659
}
5760

61+
int test_host_vector_conversions() {
62+
bool Passed = true;
63+
std::cout << "float[4] -> bfloat16[4] -> float[4] conversion on host..."
64+
<< std::flush;
65+
66+
float FloatArray[4] = {1.0f, 2.0f, 3.0f, 4.0f};
67+
68+
// float[4] -> bfloat16[4]
69+
bfloat16 BFloatArray[4];
70+
sycl::ext::oneapi::detail::FloatVecToBF16Vec<4>(FloatArray, BFloatArray);
71+
72+
// bfloat16[4] -> float[4]
73+
float NewFloatArray[4];
74+
sycl::ext::oneapi::detail::BF16VecToFloatVec<4>(BFloatArray, NewFloatArray);
75+
76+
// Check results.
77+
for (int i = 0; i < 4; ++i)
78+
Passed &= (FloatArray[i] == NewFloatArray[i]);
79+
80+
if (Passed)
81+
std::cout << "passed\n";
82+
else
83+
std::cout << "failed\n";
84+
85+
return !Passed;
86+
}
87+
88+
int test_device_vector_conversions(queue Q) {
89+
int err = 0;
90+
buffer<int> err_buf(&err, 1);
91+
92+
std::cout << "float[4] -> bfloat16[4] conversion on device..." << std::flush;
93+
// Convert float array to bfloat16 array
94+
Q.submit([&](handler &CGH) {
95+
accessor<int, 1, access::mode::write, target::device> ERR(err_buf, CGH);
96+
CGH.single_task([=]() {
97+
float FloatArray[4] = {1.0f, -1.0f, 0.0f, 2.0f};
98+
bfloat16 BF16Array[4];
99+
sycl::ext::oneapi::detail::FloatVecToBF16Vec<4>(FloatArray, BF16Array);
100+
for (int i = 0; i < 4; i++) {
101+
if (FloatArray[i] != (float)BF16Array[i]) {
102+
ERR[0] = 1;
103+
}
104+
}
105+
});
106+
}).wait();
107+
108+
if (err)
109+
std::cout << "failed\n";
110+
else
111+
std::cout << "passed\n";
112+
113+
std::cout << "bfloat16[4] -> float[4] conversion on device..." << std::flush;
114+
// Convert bfloat16 array back to float array
115+
Q.submit([&](handler &CGH) {
116+
accessor<int, 1, access::mode::write, target::device> ERR(err_buf, CGH);
117+
CGH.single_task([=]() {
118+
bfloat16 BF16Array[3] = {1.0f, 0.0f, -1.0f};
119+
float FloatArray[3];
120+
sycl::ext::oneapi::detail::BF16VecToFloatVec<4>(BF16Array, FloatArray);
121+
for (int i = 0; i < 3; i++) {
122+
if (FloatArray[i] != (float)BF16Array[i]) {
123+
ERR[0] = 1;
124+
}
125+
}
126+
});
127+
}).wait();
128+
129+
if (err)
130+
std::cout << "failed\n";
131+
else
132+
std::cout << "passed\n";
133+
134+
return err;
135+
}
136+
58137
int main() {
59138
queue Q;
60139
int result;
@@ -63,6 +142,11 @@ int main() {
63142
if (Q.get_device().has(aspect::fp16))
64143
result |= test_device<sycl::half>(Q);
65144
result |= test_device<float>(Q);
145+
146+
// Test vector BF16 -> float conversion and vice versa.
147+
result |= test_host_vector_conversions();
148+
result |= test_device_vector_conversions(Q);
149+
66150
if (result)
67151
std::cout << "FAIL\n";
68152
else

0 commit comments

Comments
 (0)