Skip to content

Commit 65a7365

Browse files
authored
[ESIMD] Support bfloat16 simd vector element type. (#6664)
* [ESIMD] Support bfloat16 simd vector element type. - Implement corresponding element type traits - Implement __esimd_bf_cvt intrinsic
1 parent 66e469b commit 65a7365

File tree

6 files changed

+124
-6
lines changed

6 files changed

+124
-6
lines changed

llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ static const char *LegalSYCLFunctions[] = {
4646
"^sycl::_V1::sin<.+>",
4747
"^sycl::_V1::log<.+>",
4848
"^sycl::_V1::exp<.+>",
49+
"^sycl::_V1::bit_cast<.+>",
4950
"^sycl::_V1::operator.+<.+>",
5051
"^sycl::_V1::ext::oneapi::sub_group::.+",
5152
"^sycl::_V1::ext::oneapi::experimental::spec_constant<.+>::.+",
52-
"^sycl::_V1::ext::oneapi::experimental::this_sub_group"};
53+
"^sycl::_V1::ext::oneapi::experimental::this_sub_group",
54+
"^sycl::_V1::ext::oneapi::experimental::bfloat16::.+"};
5355

5456
static const char *LegalSYCLFunctionsInStatelessMode[] = {
5557
"^sycl::_V1::multi_ptr<.+>::get",

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ class ESIMDIntrinDescTable {
654654
{"test_src_tmpl_arg",
655655
{"test.src.tmpl.arg", {t(0), t1(1), t8(2), t16(3), t32(4), c8(17)}}},
656656
{"slm_init", {"slm.init", {a(0)}}},
657-
};
657+
{"bf_cvt", {"bf.cvt", {a(0)}}}};
658658
}
659659

660660
const IntrinTable &getTable() { return Table; }

sycl/include/sycl/ext/intel/esimd.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181

8282
#include <sycl/ext/intel/esimd/alt_ui.hpp>
8383
#include <sycl/ext/intel/esimd/common.hpp>
84+
#include <sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp>
8485
#include <sycl/ext/intel/esimd/detail/half_type_traits.hpp>
8586
#include <sycl/ext/intel/esimd/simd.hpp>
8687
#include <sycl/ext/intel/esimd/simd_view.hpp>
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//==-------------- bfloat16_type_traits.hpp - DPC++ Explicit SIMD API ------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Implementation of SIMD element type traits for the bfloat16 type.
9+
//===----------------------------------------------------------------------===//
10+
11+
#pragma once
12+
13+
#include <sycl/ext/intel/esimd/detail/elem_type_traits.hpp>
14+
#include <sycl/ext/intel/esimd/detail/intrin.hpp>
15+
16+
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
17+
18+
/// @cond ESIMD_DETAIL
19+
20+
namespace sycl {
21+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
22+
namespace ext::intel::esimd::detail {
23+
24+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
25+
26+
// Element type traits specialization that lets bfloat16 be used as a simd
// vector element type. It names the raw storage type, the standard C++ type
// used to carry out operations, and two flags describing how the element
// behaves.
template <> struct element_type_traits<bfloat16> {
27+
// Raw storage type: an integer type of the same size as bfloat16.
// TODO: map the raw type to __bf16 once the SPIRV target supports it.
28+
using RawT = uint_type_t<sizeof(bfloat16)>;
29+
// Nearest standard enclosing C++ type to delegate natively unsupported
30+
// operations to:
31+
using EnclosingCppT = float;
32+
// Can't map bfloat16 operations to operations on RawT:
33+
static inline constexpr bool use_native_cpp_ops = false;
34+
static inline constexpr bool is_floating_point = true;
35+
};
36+
37+
#ifdef __SYCL_DEVICE_ONLY__
38+
// VC BE-specific glitch
39+
// @llvm.genx.bf.cvt uses half (_Float16) as bit representation for bfloat16
40+
using vc_be_bfloat16_raw_t = _Float16;
41+
#endif // __SYCL_DEVICE_ONLY__
42+
43+
// ------------------- Type conversion traits
44+
45+
// Conversions between an N-element vector of the standard C++ type (StdT)
// and an N-element vector of the raw bfloat16 storage type (RawT).
// On device the conversion goes through the __esimd_bf_cvt intrinsic; on host
// it is emulated element by element using the bfloat16 class.
template <int N> struct vector_conversion_traits<bfloat16, N> {
46+
using StdT = __cpp_t<bfloat16>;
47+
using StdVecT = vector_type_t<StdT, N>;
48+
using RawT = __raw_t<bfloat16>;
49+
50+
// Converts a vector of StdT values into a vector holding the bfloat16 bit
// patterns as RawT.
static ESIMD_INLINE vector_type_t<RawT, N>
51+
convert_to_raw(vector_type_t<StdT, N> Val) {
52+
#ifdef __SYCL_DEVICE_ONLY__
53+
// The intrinsic returns the bfloat16 bits typed as vc_be_bfloat16_raw_t
// (_Float16).
using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
54+
RawVecT ConvVal = __esimd_bf_cvt<vc_be_bfloat16_raw_t, StdT, N>(Val);
55+
// reinterpret the _Float16-typed result as a vector of RawT:
56+
return sycl::bit_cast<vector_type_t<RawT, N>>(ConvVal);
57+
#else
58+
// Host emulation: convert each element to bfloat16, then take its bits.
vector_type_t<RawT, N> Output = 0;
59+
60+
for (int i = 0; i < N; i++) {
61+
Output[i] = sycl::bit_cast<RawT>(static_cast<bfloat16>(Val[i]));
62+
}
63+
return Output;
64+
#endif // __SYCL_DEVICE_ONLY__
65+
}
66+
67+
// Converts a vector of raw bfloat16 bit patterns (RawT) into a vector of
// StdT values.
static ESIMD_INLINE vector_type_t<StdT, N>
68+
convert_to_cpp(vector_type_t<RawT, N> Val) {
69+
#ifdef __SYCL_DEVICE_ONLY__
70+
// Retype the raw bits as vc_be_bfloat16_raw_t (_Float16), which is the
// form the intrinsic accepts.
using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
71+
RawVecT Bits = sycl::bit_cast<RawVecT>(Val);
72+
return __esimd_bf_cvt<StdT, vc_be_bfloat16_raw_t, N>(Bits);
73+
#else
74+
// Host emulation: reinterpret each raw element as bfloat16; the assignment
// to the StdT element then converts it.
vector_type_t<StdT, N> Output;
75+
76+
for (int i = 0; i < N; i++) {
77+
Output[i] = sycl::bit_cast<bfloat16>(Val[i]);
78+
}
79+
return Output;
80+
#endif // __SYCL_DEVICE_ONLY__
81+
}
82+
};
83+
84+
// TODO: remove bitcasts from the scalar_conversion_traits, and replace with
85+
// sycl::bit_cast directly
86+
// Scalar bit-level conversions between a bfloat16 value and its raw storage
// type. Both directions are plain sycl::bit_cast calls.
template <> struct scalar_conversion_traits<bfloat16> {
87+
using RawT = __raw_t<bfloat16>;
88+
89+
// Returns the bit representation of Val as RawT.
static ESIMD_INLINE RawT bitcast_to_raw(bfloat16 Val) {
90+
return sycl::bit_cast<RawT>(Val);
91+
}
92+
93+
// Returns the bfloat16 whose bit representation is Val.
static ESIMD_INLINE bfloat16 bitcast_to_wrapper(RawT Val) {
94+
return sycl::bit_cast<bfloat16>(Val);
95+
}
96+
};
97+
98+
// bfloat16 uses default inefficient implementations of std C++ operations,
99+
// hence no specializations of other traits.
100+
101+
// Misc
102+
// Writes a bfloat16 value to the stream by first converting it to float.
// Returns the same stream so that insertions can be chained.
inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
103+
O << static_cast<float>(rhs);
104+
return O;
105+
}
106+
107+
} // namespace ext::intel::esimd::detail
108+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
109+
} // namespace sycl
110+
111+
/// @endcond ESIMD_DETAIL

sycl/include/sycl/ext/intel/esimd/detail/half_type_traits.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,7 @@ template <int N> struct vector_conversion_traits<sycl::half, N> {
9696
};
9797

9898
// Proxy class to access bit representation of a wrapper type both on host and
99-
// device. Declared as friend to the wrapper types (e.g. sycl::half).
100-
// Specific type traits implementations (scalar_conversion_traits) can use
101-
// concrete wrapper type specializations of the static functions in this class
102-
// to access private fields in the wrapper type (e.g. sycl::half).
99+
// device. Declared as a friend of sycl::half.
103100
// TODO add this functionality to sycl type implementation? With C++20,
104101
// std::bit_cast should be a good replacement.
105102
class WrapperElementTypeProxy {

sycl/include/sycl/ext/intel/esimd/detail/intrin.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,14 @@ __esimd_wrindirect(__ESIMD_DNS::vector_type_t<T, N> OldVal,
335335
}
336336
return Result;
337337
}
338+
#endif // __SYCL_DEVICE_ONLY__
338339

340+
#ifdef __SYCL_DEVICE_ONLY__
341+
// This intrinsic requires one of the types to be _Float16, which is absent on
342+
// host, so it can't be represented on host. Callers must emulate it.
343+
// Declaration only (device compilation): takes an N-element vector of From and
// returns an N-element vector of To.
template <class To, class From, int N>
344+
__ESIMD_INTRIN __ESIMD_DNS::vector_type_t<To, N>
345+
__esimd_bf_cvt(__ESIMD_DNS::vector_type_t<From, N> Val);
339346
#endif // __SYCL_DEVICE_ONLY__
340347

341348
/// @endcond ESIMD_DETAIL

0 commit comments

Comments
 (0)