Skip to content

Commit 025cf7e

Browse files
committed
Added bfloat16 support for cuda backend.
Added bfloat16 in oneapi experimental namespace. Signed-off-by: jack.kirk <[email protected]>
1 parent fd73a49 commit 025cf7e

File tree

1 file changed

+161
-0
lines changed

1 file changed

+161
-0
lines changed
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
//==--------- bfloat16.hpp ------- SYCL bfloat16 conversion ----------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <CL/__spirv/spirv_ops.hpp>
12+
#include <CL/sycl/half_type.hpp>
13+
14+
__SYCL_INLINE_NAMESPACE(cl) {
15+
namespace sycl {
16+
namespace ext {
17+
namespace oneapi {
18+
namespace experimental {
19+
20+
class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
21+
using storage_t = uint16_t;
22+
storage_t value;
23+
24+
public:
25+
bfloat16() = default;
26+
bfloat16(const bfloat16 &) = default;
27+
~bfloat16() = default;
28+
29+
// Explicit conversion functions
30+
static storage_t from_float(const float &a) {
31+
#if defined(__SYCL_DEVICE_ONLY__)
32+
#if defined(__NVPTX__)
33+
return __nvvm_f2bf16_rn(a);
34+
#else
35+
return __spirv_ConvertFToBF16INTEL(a);
36+
#endif
37+
#else
38+
throw exception{errc::feature_not_supported,
39+
"Bfloat16 conversion is not supported on host device"};
40+
#endif
41+
}
42+
static float to_float(const storage_t &a) {
43+
#if defined(__SYCL_DEVICE_ONLY__)
44+
#if defined(__NVPTX__)
45+
unsigned int y = a;
46+
y = y << 16;
47+
float *res = reinterpret_cast<float *>(&y);
48+
return *res;
49+
#else
50+
return __spirv_ConvertBF16ToFINTEL(a);
51+
#endif
52+
#else
53+
throw exception{errc::feature_not_supported,
54+
"Bfloat16 conversion is not supported on host device"};
55+
#endif
56+
}
57+
58+
static bfloat16 from_bits(const storage_t &a) {
59+
bfloat16 res;
60+
res.value = a;
61+
return res;
62+
}
63+
64+
// Implicit conversion from float to bfloat16
65+
bfloat16(const float &a) { value = from_float(a); }
66+
67+
bfloat16 &operator=(const float &rhs) {
68+
value = from_float(rhs);
69+
return *this;
70+
}
71+
72+
// Implicit conversion from bfloat16 to float
73+
operator float() const { return to_float(value); }
74+
operator sycl::half() const { return to_float(value); }
75+
76+
// Get raw bits representation of bfloat16
77+
storage_t raw() const { return value; }
78+
79+
// Logical operators (!,||,&&) are covered if we can cast to bool
80+
explicit operator bool() { return to_float(value) != 0.0f; }
81+
82+
// Unary minus operator overloading
83+
friend bfloat16 operator-(bfloat16 &lhs) {
84+
return bfloat16{-to_float(lhs.value)};
85+
}
86+
87+
// Increment and decrement operators overloading
88+
#define OP(op) \
89+
friend bfloat16 &operator op(bfloat16 &lhs) { \
90+
float f = to_float(lhs.value); \
91+
lhs.value = from_float(op f); \
92+
return lhs; \
93+
} \
94+
friend bfloat16 operator op(bfloat16 &lhs, int) { \
95+
bfloat16 old = lhs; \
96+
operator op(lhs); \
97+
return old; \
98+
}
99+
OP(++)
100+
OP(--)
101+
#undef OP
102+
103+
// Assignment operators overloading
104+
#define OP(op) \
105+
friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
106+
float f = static_cast<float>(lhs); \
107+
f op static_cast<float>(rhs); \
108+
return lhs = f; \
109+
} \
110+
template <typename T> \
111+
friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \
112+
float f = static_cast<float>(lhs); \
113+
f op static_cast<float>(rhs); \
114+
return lhs = f; \
115+
} \
116+
template <typename T> friend T &operator op(T &lhs, const bfloat16 &rhs) { \
117+
float f = static_cast<float>(lhs); \
118+
f op static_cast<float>(rhs); \
119+
return lhs = f; \
120+
}
121+
OP(+=)
122+
OP(-=)
123+
OP(*=)
124+
OP(/=)
125+
#undef OP
126+
127+
// Binary operators overloading
128+
#define OP(type, op) \
129+
friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
130+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
131+
} \
132+
template <typename T> \
133+
friend type operator op(const bfloat16 &lhs, const T &rhs) { \
134+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
135+
} \
136+
template <typename T> \
137+
friend type operator op(const T &lhs, const bfloat16 &rhs) { \
138+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
139+
}
140+
OP(bfloat16, +)
141+
OP(bfloat16, -)
142+
OP(bfloat16, *)
143+
OP(bfloat16, /)
144+
OP(bool, ==)
145+
OP(bool, !=)
146+
OP(bool, <)
147+
OP(bool, >)
148+
OP(bool, <=)
149+
OP(bool, >=)
150+
#undef OP
151+
152+
// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
153+
// for floating-point types.
154+
};
155+
156+
} // namespace experimental
157+
} // namespace oneapi
158+
} // namespace ext
159+
160+
} // namespace sycl
161+
} // __SYCL_INLINE_NAMESPACE(cl)

0 commit comments

Comments
 (0)