Skip to content

Commit 1cb5cb0

Browse files
committed
[ESIMD] Support bfloat16 simd vector element type.
- Implement corresponding element type traits - Implement __esimd_bitcast and __esimd_bf_cvt intrinsics
1 parent 3dc891f commit 1cb5cb0

File tree

5 files changed

+145
-5
lines changed

5 files changed

+145
-5
lines changed

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ class ESIMDIntrinDescTable {
654654
{"test_src_tmpl_arg",
655655
{"test.src.tmpl.arg", {t(0), t1(1), t8(2), t16(3), t32(4), c8(17)}}},
656656
{"slm_init", {"slm.init", {a(0)}}},
657-
};
657+
{"bf_cvt", {"bf.cvt", {a(0)}}}};
658658
}
659659

660660
const IntrinTable &getTable() { return Table; }
@@ -1106,6 +1106,15 @@ static void translateGetSurfaceIndex(CallInst &CI) {
11061106
CI.replaceAllUsesWith(SI);
11071107
}
11081108

1109+
static void translateBitcast(CallInst &CI) {
1110+
auto opnd = CI.getArgOperand(0);
1111+
IRBuilder<> Builder(&CI);
1112+
auto BC = Builder.CreateBitCast(opnd, CI.getType());
1113+
auto *SI = cast<CastInst>(BC);
1114+
SI->setDebugLoc(CI.getDebugLoc());
1115+
CI.replaceAllUsesWith(SI);
1116+
}
1117+
11091118
// Newly created GenX intrinsic might have different return type than expected.
11101119
// This helper function creates cast operation from GenX intrinsic return type
11111120
// to currently expected. Returns pointer to created cast instruction if it
@@ -1766,6 +1775,11 @@ size_t SYCLLowerESIMDPass::runOnFunction(Function &F,
17661775
ToErase.push_back(CI);
17671776
continue;
17681777
}
1778+
if (Name.startswith("__esimd_bitcast")) {
1779+
translateBitcast(*CI);
1780+
ToErase.push_back(CI);
1781+
continue;
1782+
}
17691783
assert(!Name.startswith("__esimd_set_kernel_properties") &&
17701784
"__esimd_set_kernel_properties must have been lowered");
17711785

sycl/include/sycl/ext/intel/esimd.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181

8282
#include <sycl/ext/intel/esimd/alt_ui.hpp>
8383
#include <sycl/ext/intel/esimd/common.hpp>
84+
#include <sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp>
8485
#include <sycl/ext/intel/esimd/detail/half_type_traits.hpp>
8586
#include <sycl/ext/intel/esimd/simd.hpp>
8687
#include <sycl/ext/intel/esimd/simd_view.hpp>
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
//==-------------- bfloat16_type_traits.hpp - DPC++ Explicit SIMD API ------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Implementation of SIMD element type traits for the bfloat16 type.
9+
//===----------------------------------------------------------------------===//
10+
11+
#pragma once
12+
13+
#include <sycl/ext/intel/esimd/detail/elem_type_traits.hpp>
14+
#include <sycl/ext/intel/esimd/detail/intrin.hpp>
15+
16+
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
17+
18+
/// @cond ESIMD_DETAIL
19+
20+
namespace sycl {
21+
__SYCL_INLINE_VER_NAMESPACE(_V1) {
22+
namespace ext::intel::esimd::detail {
23+
24+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
25+
26+
template <> struct element_type_traits<bfloat16> {
27+
// TODO map the raw type to __bf16 once SPIRV target supports it:
28+
using RawT =
29+
typename std::invoke_result_t<decltype(&bfloat16::raw), bfloat16>;
30+
// Nearest standard enclosing C++ type to delegate natively unsupported
31+
// operations to:
32+
using EnclosingCppT = float;
33+
// Can't map bfloat16 operations to opertations on RawT:
34+
static inline constexpr bool use_native_cpp_ops = false;
35+
static inline constexpr bool is_floating_point = true;
36+
};
37+
38+
#ifdef __SYCL_DEVICE_ONLY__
39+
// VC BE-specific glitch
40+
// @llvm.genx.bf.cvt uses half (_Float16) as bit representation for bfloat16
41+
using vc_be_bfloat16_raw_t = _Float16;
42+
#endif // __SYCL_DEVICE_ONLY__
43+
44+
// ------------------- Type conversion traits
45+
46+
template <int N> struct vector_conversion_traits<bfloat16, N> {
47+
using StdT = __cpp_t<bfloat16>;
48+
using StdVecT = vector_type_t<StdT, N>;
49+
using RawT = __raw_t<bfloat16>;
50+
51+
static ESIMD_INLINE vector_type_t<RawT, N>
52+
convert_to_raw(vector_type_t<StdT, N> Val) {
53+
#ifdef __SYCL_DEVICE_ONLY__
54+
using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
55+
RawVecT ConvVal = __esimd_bf_cvt<vc_be_bfloat16_raw_t, StdT, N>(Val);
56+
// cast from _Float16 to int16_t:
57+
return __esimd_bitcast<vector_type_t<RawT, N>>(ConvVal);
58+
#else
59+
vector_type_t<RawT, N> Output = 0;
60+
61+
for (int i = 0; i < N; i++) {
62+
Output[i] = bfloat16(Val[i]).raw();
63+
}
64+
return Output;
65+
#endif // __SYCL_DEVICE_ONLY__
66+
}
67+
68+
static ESIMD_INLINE vector_type_t<StdT, N>
69+
convert_to_cpp(vector_type_t<RawT, N> Val) {
70+
#ifdef __SYCL_DEVICE_ONLY__
71+
using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
72+
RawVecT Bits = __esimd_bitcast<RawVecT>(Val);
73+
return __esimd_bf_cvt<StdT, vc_be_bfloat16_raw_t, N>(Bits);
74+
#else
75+
vector_type_t<StdT, N> Output;
76+
77+
for (int i = 0; i < N; i++) {
78+
Output[i] = bfloat16::from_bits(Val[i]);
79+
}
80+
return Output;
81+
#endif // __SYCL_DEVICE_ONLY__
82+
}
83+
};
84+
85+
template <> struct scalar_conversion_traits<bfloat16> {
86+
using RawT = __raw_t<bfloat16>;
87+
88+
static ESIMD_INLINE RawT bitcast_to_raw(bfloat16 Val) { return Val.raw(); }
89+
90+
static ESIMD_INLINE bfloat16 bitcast_to_wrapper(RawT Val) {
91+
return bfloat16::from_bits(Val);
92+
}
93+
};
94+
95+
// bfloat16 uses default inefficient implementations of std C++ operations,
96+
// hence no specializations of other traits.
97+
98+
// Misc
99+
inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
100+
O << static_cast<float>(rhs);
101+
return O;
102+
}
103+
104+
} // namespace ext::intel::esimd::detail
105+
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
106+
} // namespace sycl
107+
108+
/// @endcond ESIMD_DETAIL

sycl/include/sycl/ext/intel/esimd/detail/half_type_traits.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,7 @@ template <int N> struct vector_conversion_traits<sycl::half, N> {
9696
};
9797

9898
// Proxy class to access bit representation of a wrapper type both on host and
99-
// device. Declared as friend to the wrapper types (e.g. sycl::half).
100-
// Specific type traits implementations (scalar_conversion_traits) can use
101-
// concrete wrapper type specializations of the static functions in this class
102-
// to access private fields in the wrapper type (e.g. sycl::half).
99+
// device. Declared as friend to the sycl::half.
103100
// TODO add this functionality to sycl type implementation? With C++20,
104101
// std::bit_cast should be a good replacement.
105102
class WrapperElementTypeProxy {

sycl/include/sycl/ext/intel/esimd/detail/intrin.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,27 @@ __esimd_wrindirect(__ESIMD_DNS::vector_type_t<T, N> OldVal,
335335
}
336336
return Result;
337337
}
338+
#endif // __SYCL_DEVICE_ONLY__
338339

340+
// TODO should be replaced by std::bit_cast once C++20 is supported.
341+
template <class To, class From,
342+
class = std::enable_if_t<sizeof(From) == sizeof(To)>>
343+
__ESIMD_INTRIN To __esimd_bitcast(From Src)
344+
#ifdef __SYCL_DEVICE_ONLY__
345+
;
346+
#else
347+
{
348+
auto *Ptr = reinterpret_cast<To *>(&Src);
349+
return *Ptr;
350+
}
351+
#endif // __SYCL_DEVICE_ONLY__
352+
353+
#ifdef __SYCL_DEVICE_ONLY__
354+
// This intrinsic requires one of the types to be _Float16, which is absent on
355+
// host, so it can't be represented on host. Callers must emulate it.
356+
template <class To, class From, int N>
357+
__ESIMD_INTRIN __ESIMD_DNS::vector_type_t<To, N>
358+
__esimd_bf_cvt(__ESIMD_DNS::vector_type_t<From, N> Val);
339359
#endif // __SYCL_DEVICE_ONLY__
340360

341361
/// @endcond ESIMD_DETAIL

0 commit comments

Comments
 (0)