Skip to content

Commit 373b27f

Browse files
committed
[SYCL][CUDA] add bf16 builtins
1 parent d38b599 commit 373b27f

File tree

3 files changed

+102
-0
lines changed

3 files changed

+102
-0
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,28 @@ __spirv_ocl_printf(const __attribute__((opencl_constant)) char *Format, ...);
755755
extern SYCL_EXTERNAL int __spirv_ocl_printf(const char *Format, ...);
756756
#endif
757757

758+
extern SYCL_EXTERNAL __SYCL_EXPORT uint16_t __clc_fabs(uint16_t) noexcept;
759+
760+
#define __CLC_BF16(...) \
761+
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fabs(__VA_ARGS__) noexcept; \
762+
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmin(__VA_ARGS__, __VA_ARGS__) noexcept; \
763+
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fmax(__VA_ARGS__, __VA_ARGS__) noexcept; \
764+
extern SYCL_EXTERNAL __SYCL_EXPORT __VA_ARGS__ __clc_fma(__VA_ARGS__, __VA_ARGS__, __VA_ARGS__) noexcept;
765+
766+
#define __CLC_BF16_SCAL_VEC(TYPE) \
767+
__CLC_BF16(TYPE) \
768+
__CLC_BF16(__ocl_vec_t<TYPE, 2>) \
769+
__CLC_BF16(__ocl_vec_t<TYPE, 3>) \
770+
__CLC_BF16(__ocl_vec_t<TYPE, 4>) \
771+
__CLC_BF16(__ocl_vec_t<TYPE, 8>) \
772+
__CLC_BF16(__ocl_vec_t<TYPE, 16>)
773+
774+
__CLC_BF16_SCAL_VEC(uint16_t)
775+
__CLC_BF16_SCAL_VEC(uint32_t)
776+
777+
#undef __CLC_BF16_SCAL_VEC
778+
#undef __CLC_BF16
779+
758780
#else // if !__SYCL_DEVICE_ONLY__
759781

760782
template <typename dataT>

sycl/include/CL/sycl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#if SYCL_EXT_ONEAPI_BACKEND_LEVEL_ZERO
6161
#include <sycl/ext/oneapi/backend/level_zero.hpp>
6262
#endif
63+
#include <sycl/ext/oneapi/bf16_storage_builtins.hpp>
6364
#include <sycl/ext/oneapi/experimental/builtins.hpp>
6465
#include <sycl/ext/oneapi/filter_selector.hpp>
6566
#include <sycl/ext/oneapi/group_algorithm.hpp>
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#pragma once
2+
3+
#include <CL/__spirv/spirv_ops.hpp>
4+
#include <CL/sycl/builtins.hpp>
5+
#include <CL/sycl/detail/builtins.hpp>
6+
#include <CL/sycl/detail/generic_type_lists.hpp>
7+
#include <CL/sycl/detail/generic_type_traits.hpp>
8+
#include <CL/sycl/detail/type_traits.hpp>
9+
10+
__SYCL_INLINE_NAMESPACE(cl) {
11+
namespace sycl {
12+
namespace ext {
13+
namespace oneapi {
14+
15+
namespace detail {
16+
17+
template <typename T> struct is_bf16_storage_type {
18+
static constexpr int value = false;
19+
};
20+
21+
template <> struct is_bf16_storage_type<uint16_t> {
22+
static constexpr int value = true;
23+
};
24+
25+
template <> struct is_bf16_storage_type<uint32_t> {
26+
static constexpr int value = true;
27+
};
28+
29+
template <int N> struct is_bf16_storage_type<vec<uint16_t, N>> {
30+
static constexpr int value = true;
31+
};
32+
33+
template <int N> struct is_bf16_storage_type<vec<uint32_t, N>> {
34+
static constexpr int value = true;
35+
};
36+
37+
} // namespace detail
38+
39+
template <typename T>
40+
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fabs(T x) {
41+
#ifdef __SYCL_DEVICE_ONLY__
42+
return __clc_fabs(x);
43+
#else
44+
throw runtime_error("bf16 is not supported on host device.",
45+
PI_INVALID_DEVICE);
46+
#endif
47+
}
48+
template <typename T>
49+
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fmin(T x, T y) {
50+
#ifdef __SYCL_DEVICE_ONLY__
51+
return __clc_fmin(x, y);
52+
#else
53+
throw runtime_error("bf16 is not supported on host device.",
54+
PI_INVALID_DEVICE);
55+
#endif
56+
}
57+
template <typename T>
58+
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fmax(T x, T y) {
59+
#ifdef __SYCL_DEVICE_ONLY__
60+
return __clc_fmax(x, y);
61+
#else
62+
throw runtime_error("bf16 is not supported on host device.",
63+
PI_INVALID_DEVICE);
64+
#endif
65+
}
66+
template <typename T>
67+
std::enable_if_t<detail::is_bf16_storage_type<T>::value, T> fma(T x, T y, T z) {
68+
#ifdef __SYCL_DEVICE_ONLY__
69+
return __clc_fma(x, y, z);
70+
#else
71+
throw runtime_error("bf16 is not supported on host device.",
72+
PI_INVALID_DEVICE);
73+
#endif
74+
}
75+
76+
} // namespace oneapi
77+
} // namespace ext
78+
} // namespace sycl
79+
} // __SYCL_INLINE_NAMESPACE(cl)

0 commit comments

Comments
 (0)