Skip to content

Commit 77259a9

Browse files
committed
Converting some operators to hidden friends
Also reorder member functions and improve the test. Signed-off-by: Alexey Sotkin <[email protected]>
1 parent 688bf73 commit 77259a9

File tree

2 files changed

+33
-30
lines changed

2 files changed

+33
-30
lines changed

sycl/include/sycl/ext/intel/experimental/bfloat16.hpp

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -60,22 +60,45 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
6060
// Get raw bits representation of bfloat16
6161
operator storage_t() const { return value; }
6262

63-
// Assignment operators overloading
63+
// Logical operators (!,||,&&) are covered if we can cast to bool
64+
explicit operator bool() { return to_float(value) != 0.0f; }
65+
66+
// Unary minus operator overloading
67+
friend bfloat16 operator-(bfloat16 &lhs) {
68+
return bfloat16{-to_float(lhs.value)};
69+
}
70+
71+
// Increment and decrement operators overloading
72+
#define OP(op) \
73+
friend bfloat16 &operator op(bfloat16 &lhs) { \
74+
float f = to_float(lhs.value); \
75+
lhs.value = from_float(op f); \
76+
return lhs; \
77+
} \
78+
friend bfloat16 operator op(bfloat16 &lhs, int) { \
79+
bfloat16 old = lhs; \
80+
operator op(lhs); \
81+
return old; \
82+
}
83+
OP(++)
84+
OP(--)
85+
#undef OP
86+
87+
// Assignment operators overloading
6488
#define OP(op) \
65-
friend bfloat16 operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
89+
friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
6690
float f = static_cast<float>(lhs); \
6791
f op static_cast<float>(rhs); \
6892
return lhs = f; \
6993
} \
7094
\
7195
template <typename T> \
72-
friend bfloat16 operator op(bfloat16 &lhs, const T &rhs) { \
96+
friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \
7397
float f = static_cast<float>(lhs); \
74-
\
7598
f op static_cast<float>(rhs); \
7699
return lhs = f; \
77100
} \
78-
template <typename T> friend T operator op(T &lhs, const bfloat16 &rhs) { \
101+
template <typename T> friend T &operator op(T &lhs, const bfloat16 &rhs) { \
79102
float f = static_cast<float>(lhs); \
80103
f op static_cast<float>(rhs); \
81104
return lhs = f; \
@@ -86,28 +109,6 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
86109
OP(/=)
87110
#undef OP
88111

89-
// Increment and decrement operators overloading
90-
#define OP(op) \
91-
bfloat16 &operator op() { \
92-
float f = to_float(value); \
93-
value = from_float(op f); \
94-
return *this; \
95-
} \
96-
bfloat16 operator op(int) { \
97-
bfloat16 old = *this; \
98-
operator op(); \
99-
return old; \
100-
}
101-
OP(++)
102-
OP(--)
103-
#undef OP
104-
105-
// Unary minus operator overloading
106-
bfloat16 operator-() { return bfloat16{-to_float(value)}; }
107-
108-
// Logical operators (!,||,&&) are covered if we can cast to bool
109-
explicit operator bool() { return to_float(value) != 0.0f; }
110-
111112
// Binary operators overloading
112113
#define OP(type, op) \
113114
friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
@@ -132,6 +133,7 @@ class [[sycl_detail::uses_aspects(ext_intel_bf16_conversion)]] bfloat16 {
132133
OP(bool, <=)
133134
OP(bool, >=)
134135
#undef OP
136+
135137
// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
136138
// for floating-point types.
137139
};

sycl/test/extensions/bfloat16.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using sycl::ext::intel::experimental::bfloat16;
77

88
SYCL_EXTERNAL uint16_t some_bf16_intrinsic(uint16_t x, uint16_t y);
99

10+
__attribute__((noinline))
1011
float op(float a, float b) {
1112
bfloat16 A {a};
1213
// CHECK: [[A:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %a)
@@ -16,11 +17,11 @@ float op(float a, float b) {
1617
// CHECK: [[B:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %b)
1718
// CHECK-NOT: fptoui
1819

19-
bfloat16 C = static_cast<float>(A) + static_cast<float>(B);
20+
bfloat16 C = A + B;
2021
// CHECK: [[A_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[A]])
2122
// CHECK: [[B_float:%.*]] = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELt(i16 zeroext [[B]])
22-
// CHECK: %add = fadd float [[A_float]], [[B_float]]
23-
// CHECK: [[C:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %add)
23+
// CHECK: [[Add:%.*]] = fadd float [[A_float]], [[B_float]]
24+
// CHECK: [[C:%.*]] = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float [[Add]])
2425
// CHECK-NOT: uitofp
2526
// CHECK-NOT: fptoui
2627

0 commit comments

Comments
 (0)