Skip to content

[ESIMD] Support bfloat16 simd vector element type. #6664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ static const char *LegalSYCLFunctions[] = {
"^sycl::_V1::sin<.+>",
"^sycl::_V1::log<.+>",
"^sycl::_V1::exp<.+>",
"^sycl::_V1::bit_cast<.+>",
"^sycl::_V1::operator.+<.+>",
"^sycl::_V1::ext::oneapi::sub_group::.+",
"^sycl::_V1::ext::oneapi::experimental::spec_constant<.+>::.+",
"^sycl::_V1::ext::oneapi::experimental::this_sub_group"};
"^sycl::_V1::ext::oneapi::experimental::this_sub_group",
"^sycl::_V1::ext::oneapi::experimental::bfloat16::.+"};

static const char *LegalSYCLFunctionsInStatelessMode[] = {
"^sycl::_V1::multi_ptr<.+>::get",
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ class ESIMDIntrinDescTable {
{"test_src_tmpl_arg",
{"test.src.tmpl.arg", {t(0), t1(1), t8(2), t16(3), t32(4), c8(17)}}},
{"slm_init", {"slm.init", {a(0)}}},
};
{"bf_cvt", {"bf.cvt", {a(0)}}}};
}

const IntrinTable &getTable() { return Table; }
Expand Down
1 change: 1 addition & 0 deletions sycl/include/sycl/ext/intel/esimd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@

#include <sycl/ext/intel/esimd/alt_ui.hpp>
#include <sycl/ext/intel/esimd/common.hpp>
#include <sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp>
#include <sycl/ext/intel/esimd/detail/half_type_traits.hpp>
#include <sycl/ext/intel/esimd/simd.hpp>
#include <sycl/ext/intel/esimd/simd_view.hpp>
Expand Down
111 changes: 111 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
//==-------------- bfloat16_type_traits.hpp - DPC++ Explicit SIMD API ------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Implementation of SIMD element type traits for the bfloat16 type.
//===----------------------------------------------------------------------===//

#pragma once

#include <sycl/ext/intel/esimd/detail/elem_type_traits.hpp>
#include <sycl/ext/intel/esimd/detail/intrin.hpp>

#include <sycl/ext/oneapi/experimental/bfloat16.hpp>

/// @cond ESIMD_DETAIL

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext::intel::esimd::detail {

using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;

template <> struct element_type_traits<bfloat16> {
// TODO map the raw type to __bf16 once SPIRV target supports it:
using RawT = uint_type_t<sizeof(bfloat16)>;
// Nearest standard enclosing C++ type to delegate natively unsupported
// operations to:
using EnclosingCppT = float;
// Can't map bfloat16 operations to opertations on RawT:
static inline constexpr bool use_native_cpp_ops = false;
static inline constexpr bool is_floating_point = true;
};

#ifdef __SYCL_DEVICE_ONLY__
// VC BE-specific glitch
// @llvm.genx.bf.cvt uses half (_Float16) as bit representation for bfloat16
using vc_be_bfloat16_raw_t = _Float16;
#endif // __SYCL_DEVICE_ONLY__

// ------------------- Type conversion traits

template <int N> struct vector_conversion_traits<bfloat16, N> {
using StdT = __cpp_t<bfloat16>;
using StdVecT = vector_type_t<StdT, N>;
using RawT = __raw_t<bfloat16>;

static ESIMD_INLINE vector_type_t<RawT, N>
convert_to_raw(vector_type_t<StdT, N> Val) {
#ifdef __SYCL_DEVICE_ONLY__
using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
RawVecT ConvVal = __esimd_bf_cvt<vc_be_bfloat16_raw_t, StdT, N>(Val);
// cast from _Float16 to int16_t:
return sycl::bit_cast<vector_type_t<RawT, N>>(ConvVal);
#else
vector_type_t<RawT, N> Output = 0;

for (int i = 0; i < N; i++) {
Output[i] = sycl::bit_cast<RawT>(static_cast<bfloat16>(Val[i]));
}
return Output;
#endif // __SYCL_DEVICE_ONLY__
}

static ESIMD_INLINE vector_type_t<StdT, N>
convert_to_cpp(vector_type_t<RawT, N> Val) {
#ifdef __SYCL_DEVICE_ONLY__
using RawVecT = vector_type_t<vc_be_bfloat16_raw_t, N>;
RawVecT Bits = sycl::bit_cast<RawVecT>(Val);
return __esimd_bf_cvt<StdT, vc_be_bfloat16_raw_t, N>(Bits);
#else
vector_type_t<StdT, N> Output;

for (int i = 0; i < N; i++) {
Output[i] = sycl::bit_cast<bfloat16>(Val[i]);
}
return Output;
#endif // __SYCL_DEVICE_ONLY__
}
};

// TODO: remove bitcasts from the scalar_conversion_traits, and replace with
// sycl::bit_cast directly
template <> struct scalar_conversion_traits<bfloat16> {
using RawT = __raw_t<bfloat16>;

static ESIMD_INLINE RawT bitcast_to_raw(bfloat16 Val) {
return sycl::bit_cast<RawT>(Val);
}

static ESIMD_INLINE bfloat16 bitcast_to_wrapper(RawT Val) {
return sycl::bit_cast<bfloat16>(Val);
}
};

// bfloat16 uses default inefficient implementations of std C++ operations,
// hence no specializations of other traits.

// Misc
inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
O << static_cast<float>(rhs);
return O;
}

} // namespace ext::intel::esimd::detail
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
} // namespace sycl

/// @endcond ESIMD_DETAIL
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,7 @@ template <int N> struct vector_conversion_traits<sycl::half, N> {
};

// Proxy class to access bit representation of a wrapper type both on host and
// device. Declared as friend to the wrapper types (e.g. sycl::half).
// Specific type traits implementations (scalar_conversion_traits) can use
// concrete wrapper type specializations of the static functions in this class
// to access private fields in the wrapper type (e.g. sycl::half).
// device. Declared as friend to the sycl::half.
// TODO add this functionality to sycl type implementation? With C++20,
// std::bit_cast should be a good replacement.
class WrapperElementTypeProxy {
Expand Down
7 changes: 7 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/detail/intrin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,14 @@ __esimd_wrindirect(__ESIMD_DNS::vector_type_t<T, N> OldVal,
}
return Result;
}
#endif // __SYCL_DEVICE_ONLY__

#ifdef __SYCL_DEVICE_ONLY__
// This intrinsic requires one of the types to be _Float16, which is absent on
// host, so it can't be represented on host. Callers must emulate it.
template <class To, class From, int N>
__ESIMD_INTRIN __ESIMD_DNS::vector_type_t<To, N>
__esimd_bf_cvt(__ESIMD_DNS::vector_type_t<From, N> Val);
#endif // __SYCL_DEVICE_ONLY__

/// @endcond ESIMD_DETAIL