Skip to content

Commit cd21221

Browse files
committed
Add explicit conversion functions and operator overloadings
Signed-off-by: Alexey Sotkin <[email protected]>
1 parent a3f17fa commit cd21221

File tree

1 file changed

+95
-11
lines changed

1 file changed

+95
-11
lines changed

sycl/include/sycl/ext/intel/experimental/bfloat16.hpp

Lines changed: 95 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,115 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
2525
bfloat16(const bfloat16 &) = default;
2626
~bfloat16() = default;
2727

28-
// Direct initialization
29-
bfloat16(const storage_t &a) : value(a) {}
30-
31-
// convert from float to bfloat16
32-
bfloat16(const float &a) {
28+
// Explicit conversion functions
29+
static storage_t from_float(const float &a) {
3330
#if defined(__SYCL_DEVICE_ONLY__)
34-
value = __spirv_ConvertFToBF16INTEL(a);
31+
return __spirv_ConvertFToBF16INTEL(a);
3532
#else
3633
throw runtime_error("Bfloat16 conversion is not supported on HOST device.",
3734
PI_INVALID_DEVICE);
3835
#endif
3936
}
40-
41-
// convert from bfloat16 to float
42-
operator float() const {
37+
static float to_float(const storage_t &a) {
4338
#if defined(__SYCL_DEVICE_ONLY__)
44-
return __spirv_ConvertBF16ToFINTEL(value);
39+
return __spirv_ConvertBF16ToFINTEL(a);
4540
#else
4641
throw runtime_error("Bfloat16 conversion is not supported on HOST device.",
4742
PI_INVALID_DEVICE);
4843
#endif
4944
}
5045

51-
// Get bfloat16 as uint16.
46+
// Direct initialization
47+
bfloat16(const storage_t &a) : value(a) {}
48+
49+
// Implicit conversion from float to bfloat16
50+
bfloat16(const float &a) { value = from_float(a); }
51+
52+
bfloat16 &operator=(const float &rhs) {
53+
value = from_float(rhs);
54+
return *this;
55+
}
56+
57+
// Implicit conversion from bfloat16 to float
58+
operator float() const { return to_float(value); }
59+
60+
// Get raw bits representation of bfloat16
5261
operator storage_t() const { return value; }
62+
63+
// Assignment operators overloading
64+
#define OP(op) \
65+
friend bfloat16 operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
66+
float f = static_cast<float>(lhs); \
67+
f op static_cast<float>(rhs); \
68+
return lhs = f; \
69+
} \
70+
\
71+
template <typename T> \
72+
friend bfloat16 operator op(bfloat16 &lhs, const T &rhs) { \
73+
float f = static_cast<float>(lhs); \
74+
\
75+
f op static_cast<float>(rhs); \
76+
return lhs = f; \
77+
} \
78+
template <typename T> friend T operator op(T &lhs, const bfloat16 &rhs) { \
79+
float f = static_cast<float>(lhs); \
80+
f op static_cast<float>(rhs); \
81+
return lhs = f; \
82+
}
83+
OP(+=)
84+
OP(-=)
85+
OP(*=)
86+
OP(/=)
87+
#undef OP
88+
89+
// Increment and decrement operators overloading
90+
#define OP(op) \
91+
bfloat16 &operator op() { \
92+
float f = to_float(value); \
93+
value = from_float(op f); \
94+
return *this; \
95+
} \
96+
bfloat16 operator op(int) { \
97+
bfloat16 old = *this; \
98+
operator op(); \
99+
return old; \
100+
}
101+
OP(++)
102+
OP(--)
103+
#undef OP
104+
105+
// Unary minus operator overloading
106+
bfloat16 operator-() { return bfloat16{-to_float(value)}; }
107+
108+
// Logical operators (!,||,&&) are covered if we can cast to bool
109+
explicit operator bool() { return to_float(value) != 0.0f; }
110+
111+
// Binary operators overloading
112+
#define OP(type, op) \
113+
friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
114+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
115+
} \
116+
template <typename T> \
117+
friend type operator op(const bfloat16 &lhs, const T &rhs) { \
118+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
119+
} \
120+
template <typename T> \
121+
friend type operator op(const T &lhs, const bfloat16 &rhs) { \
122+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
123+
}
124+
OP(bfloat16, +)
125+
OP(bfloat16, -)
126+
OP(bfloat16, *)
127+
OP(bfloat16, /)
128+
OP(bool, ==)
129+
OP(bool, !=)
130+
OP(bool, <)
131+
OP(bool, >)
132+
OP(bool, <=)
133+
OP(bool, >=)
134+
#undef OP
135+
// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
136+
// for floating-point types.
53137
};
54138

55139
} // namespace experimental

0 commit comments

Comments
 (0)