Skip to content

Commit 8c92df9

Browse files
authored
[SYCL][ESIMD]Limit bfloat16 operators to scalars to enable operations with simd vectors (#12089)
The purpose of this change is to limit operators defined for bfloat16 to scalar types to allow arithmetic operations between bfloat16 scalars and simd vectors. This allows to use simd operators that are defined separately and support operations between vectors and scalars
1 parent a90aaa7 commit 8c92df9

File tree

5 files changed

+282
-61
lines changed

5 files changed

+282
-61
lines changed

sycl/include/sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
9494
return O;
9595
}
9696

97+
template <> struct is_esimd_arithmetic_type<bfloat16, void> : std::true_type {};
98+
9799
} // namespace ext::intel::esimd::detail
98100
} // namespace _V1
99101
} // namespace sycl

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 164 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -132,69 +132,175 @@ class bfloat16 {
132132
#endif
133133
}
134134

135-
// Increment and decrement operators overloading
135+
bfloat16 &operator+=(const bfloat16 &rhs) {
136+
value = from_float(to_float(value) + to_float(rhs.value));
137+
return *this;
138+
}
139+
140+
bfloat16 &operator-=(const bfloat16 &rhs) {
141+
value = from_float(to_float(value) - to_float(rhs.value));
142+
return *this;
143+
}
144+
145+
bfloat16 &operator*=(const bfloat16 &rhs) {
146+
value = from_float(to_float(value) * to_float(rhs.value));
147+
return *this;
148+
}
149+
150+
bfloat16 &operator/=(const bfloat16 &rhs) {
151+
value = from_float(to_float(value) / to_float(rhs.value));
152+
return *this;
153+
}
154+
155+
// Operator ++, --
156+
bfloat16 &operator++() {
157+
float f = to_float(value);
158+
value = from_float(++f);
159+
return *this;
160+
}
161+
162+
bfloat16 operator++(int) {
163+
bfloat16 ret(*this);
164+
operator++();
165+
return ret;
166+
}
167+
168+
bfloat16 &operator--() {
169+
float f = to_float(value);
170+
value = from_float(--f);
171+
return *this;
172+
}
173+
174+
bfloat16 operator--(int) {
175+
bfloat16 ret(*this);
176+
operator--();
177+
return ret;
178+
}
179+
180+
// Operator +, -, *, /
136181
#define OP(op) \
137-
friend bfloat16 &operator op(bfloat16 &lhs) { \
138-
float f = to_float(lhs.value); \
139-
lhs.value = from_float(op f); \
140-
return lhs; \
141-
} \
142-
friend bfloat16 operator op(bfloat16 &lhs, int) { \
143-
bfloat16 old = lhs; \
144-
operator op(lhs); \
145-
return old; \
146-
}
147-
OP(++)
148-
OP(--)
182+
friend bfloat16 operator op(const bfloat16 lhs, const bfloat16 rhs) { \
183+
return to_float(lhs.value) op to_float(rhs.value); \
184+
} \
185+
friend double operator op(const bfloat16 lhs, const double rhs) { \
186+
return to_float(lhs.value) op rhs; \
187+
} \
188+
friend double operator op(const double lhs, const bfloat16 rhs) { \
189+
return lhs op to_float(rhs.value); \
190+
} \
191+
friend float operator op(const bfloat16 lhs, const float rhs) { \
192+
return to_float(lhs.value) op rhs; \
193+
} \
194+
friend float operator op(const float lhs, const bfloat16 rhs) { \
195+
return lhs op to_float(rhs.value); \
196+
} \
197+
friend bfloat16 operator op(const bfloat16 lhs, const int rhs) { \
198+
return to_float(lhs.value) op rhs; \
199+
} \
200+
friend bfloat16 operator op(const int lhs, const bfloat16 rhs) { \
201+
return lhs op to_float(rhs.value); \
202+
} \
203+
friend bfloat16 operator op(const bfloat16 lhs, const long rhs) { \
204+
return to_float(lhs.value) op rhs; \
205+
} \
206+
friend bfloat16 operator op(const long lhs, const bfloat16 rhs) { \
207+
return lhs op to_float(rhs.value); \
208+
} \
209+
friend bfloat16 operator op(const bfloat16 lhs, const long long rhs) { \
210+
return to_float(lhs.value) op rhs; \
211+
} \
212+
friend bfloat16 operator op(const long long lhs, const bfloat16 rhs) { \
213+
return lhs op to_float(rhs.value); \
214+
} \
215+
friend bfloat16 operator op(const bfloat16 &lhs, const unsigned int &rhs) { \
216+
return to_float(lhs.value) op rhs; \
217+
} \
218+
friend bfloat16 operator op(const unsigned int &lhs, const bfloat16 &rhs) { \
219+
return lhs op to_float(rhs.value); \
220+
} \
221+
friend bfloat16 operator op(const bfloat16 &lhs, const unsigned long &rhs) { \
222+
return to_float(lhs.value) op rhs; \
223+
} \
224+
friend bfloat16 operator op(const unsigned long &lhs, const bfloat16 &rhs) { \
225+
return lhs op to_float(rhs.value); \
226+
} \
227+
friend bfloat16 operator op(const bfloat16 &lhs, \
228+
const unsigned long long &rhs) { \
229+
return to_float(lhs.value) op rhs; \
230+
} \
231+
friend bfloat16 operator op(const unsigned long long &lhs, \
232+
const bfloat16 &rhs) { \
233+
return lhs op to_float(rhs.value); \
234+
}
235+
OP(+)
236+
OP(-)
237+
OP(*)
238+
OP(/)
239+
149240
#undef OP
150241

151-
// Assignment operators overloading
242+
// Operator ==, !=, <, >, <=, >=
152243
#define OP(op) \
153-
friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \
154-
float f = static_cast<float>(lhs); \
155-
f op static_cast<float>(rhs); \
156-
return lhs = f; \
157-
} \
158-
template <typename T> \
159-
friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \
160-
float f = static_cast<float>(lhs); \
161-
f op static_cast<float>(rhs); \
162-
return lhs = f; \
163-
} \
164-
template <typename T> friend T &operator op(T &lhs, const bfloat16 &rhs) { \
165-
float f = static_cast<float>(lhs); \
166-
f op static_cast<float>(rhs); \
167-
return lhs = f; \
168-
}
169-
OP(+=)
170-
OP(-=)
171-
OP(*=)
172-
OP(/=)
173-
#undef OP
244+
friend bool operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
245+
return to_float(lhs.value) op to_float(rhs.value); \
246+
} \
247+
friend bool operator op(const bfloat16 &lhs, const double &rhs) { \
248+
return to_float(lhs.value) op rhs; \
249+
} \
250+
friend bool operator op(const double &lhs, const bfloat16 &rhs) { \
251+
return lhs op to_float(rhs.value); \
252+
} \
253+
friend bool operator op(const bfloat16 &lhs, const float &rhs) { \
254+
return to_float(lhs.value) op rhs; \
255+
} \
256+
friend bool operator op(const float &lhs, const bfloat16 &rhs) { \
257+
return lhs op to_float(rhs.value); \
258+
} \
259+
friend bool operator op(const bfloat16 &lhs, const int &rhs) { \
260+
return to_float(lhs.value) op rhs; \
261+
} \
262+
friend bool operator op(const int &lhs, const bfloat16 &rhs) { \
263+
return lhs op to_float(rhs.value); \
264+
} \
265+
friend bool operator op(const bfloat16 &lhs, const long &rhs) { \
266+
return to_float(lhs.value) op rhs; \
267+
} \
268+
friend bool operator op(const long &lhs, const bfloat16 &rhs) { \
269+
return lhs op to_float(rhs.value); \
270+
} \
271+
friend bool operator op(const bfloat16 &lhs, const long long &rhs) { \
272+
return to_float(lhs.value) op rhs; \
273+
} \
274+
friend bool operator op(const long long &lhs, const bfloat16 &rhs) { \
275+
return lhs op to_float(rhs.value); \
276+
} \
277+
friend bool operator op(const bfloat16 &lhs, const unsigned int &rhs) { \
278+
return to_float(lhs.value) op rhs; \
279+
} \
280+
friend bool operator op(const unsigned int &lhs, const bfloat16 &rhs) { \
281+
return lhs op to_float(rhs.value); \
282+
} \
283+
friend bool operator op(const bfloat16 &lhs, const unsigned long &rhs) { \
284+
return to_float(lhs.value) op rhs; \
285+
} \
286+
friend bool operator op(const unsigned long &lhs, const bfloat16 &rhs) { \
287+
return lhs op to_float(rhs.value); \
288+
} \
289+
friend bool operator op(const bfloat16 &lhs, \
290+
const unsigned long long &rhs) { \
291+
return to_float(lhs.value) op rhs; \
292+
} \
293+
friend bool operator op(const unsigned long long &lhs, \
294+
const bfloat16 &rhs) { \
295+
return lhs op to_float(rhs.value); \
296+
}
297+
OP(==)
298+
OP(!=)
299+
OP(<)
300+
OP(>)
301+
OP(<=)
302+
OP(>=)
174303

175-
// Binary operators overloading
176-
#define OP(type, op) \
177-
friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
178-
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
179-
} \
180-
template <typename T> \
181-
friend type operator op(const bfloat16 &lhs, const T &rhs) { \
182-
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
183-
} \
184-
template <typename T> \
185-
friend type operator op(const T &lhs, const bfloat16 &rhs) { \
186-
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
187-
}
188-
OP(bfloat16, +)
189-
OP(bfloat16, -)
190-
OP(bfloat16, *)
191-
OP(bfloat16, /)
192-
OP(bool, ==)
193-
OP(bool, !=)
194-
OP(bool, <)
195-
OP(bool, >)
196-
OP(bool, <=)
197-
OP(bool, >=)
198304
#undef OP
199305

200306
// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported

sycl/test-e2e/ESIMD/regression/bfloat16_half_vector_plus_eq_scalar.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,11 @@ int main() {
9191
}
9292

9393
#ifdef USE_BF16
94-
// TODO: Reenable once the issue with bfloat16 is resolved
95-
// Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
94+
Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
9695
#endif
9796
#ifdef USE_TF32
9897
Passed &= test<sycl::ext::intel::experimental::esimd::tfloat32>(Q);
9998
#endif
10099
std::cout << (Passed ? "Passed\n" : "FAILED\n");
101100
return Passed ? 0 : 1;
102-
}
101+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
//==- bfloat16_vector_plus_scalar.cpp - Test for bfloat16 operators ------==//
4+
//
5+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6+
// See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
#include "../esimd_test_utils.hpp"
11+
#include <iostream>
12+
#include <sycl/ext/intel/esimd.hpp>
13+
#include <sycl/sycl.hpp>
14+
15+
using namespace sycl;
16+
using namespace sycl::ext::intel::esimd;
17+
using namespace sycl::ext::intel::experimental::esimd;
18+
19+
template <typename T> ESIMD_NOINLINE bool test(queue Q) {
20+
std::cout << "Testing T=" << esimd_test::type_name<T>() << "...\n";
21+
22+
constexpr int N = 8;
23+
24+
constexpr int NumOps = 4;
25+
constexpr int CSize = NumOps * N;
26+
27+
T *Mem = malloc_shared<T>(CSize, Q);
28+
T TOne = static_cast<T>(1);
29+
T TTen = static_cast<T>(10);
30+
31+
Q.single_task([=]() SYCL_ESIMD_KERNEL {
32+
{
33+
simd<T, N> Vec(TOne);
34+
Vec = Vec + TTen;
35+
Vec.copy_to(Mem);
36+
}
37+
{
38+
simd<T, N> Vec(TOne);
39+
Vec = Vec - TTen;
40+
Vec.copy_to(Mem + N);
41+
}
42+
{
43+
simd<T, N> Vec(TOne);
44+
Vec = Vec * TTen;
45+
Vec.copy_to(Mem + 2 * N);
46+
}
47+
{
48+
simd<T, N> Vec(TOne);
49+
Vec = Vec / TTen;
50+
Vec.copy_to(Mem + 3 * N);
51+
}
52+
}).wait();
53+
54+
bool ReturnValue = true;
55+
for (int i = 0; i < N; ++i) {
56+
if (Mem[i] != TOne + TTen) {
57+
ReturnValue = false;
58+
break;
59+
}
60+
if (Mem[i + N] != TOne - TTen) {
61+
ReturnValue = false;
62+
break;
63+
}
64+
if (Mem[i + 2 * N] != TOne * TTen) {
65+
ReturnValue = false;
66+
break;
67+
}
68+
if (!((Mem[i + 3 * N] == (TOne / TTen)) ||
69+
(std::abs((double)(Mem[i + 3 * N] - (TOne / TTen)) /
70+
(double)(TOne / TTen)) <= 0.001))) {
71+
ReturnValue = false;
72+
break;
73+
}
74+
}
75+
76+
free(Mem, Q);
77+
return ReturnValue;
78+
}
79+
80+
int main() {
81+
queue Q;
82+
esimd_test::printTestLabel(Q);
83+
84+
bool SupportsHalf = Q.get_device().has(aspect::fp16);
85+
86+
bool Passed = true;
87+
Passed &= test<int>(Q);
88+
Passed &= test<float>(Q);
89+
if (SupportsHalf) {
90+
Passed &= test<sycl::half>(Q);
91+
}
92+
#ifdef USE_BF16
93+
Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
94+
#endif
95+
#ifdef USE_TF32
96+
Passed &= test<sycl::ext::intel::experimental::esimd::tfloat32>(Q);
97+
#endif
98+
std::cout << (Passed ? "Passed\n" : "FAILED\n");
99+
return Passed ? 0 : 1;
100+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//==- bfloat16_vector_plus_scalar_pvc.cpp - Test for bfloat16 operators -==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: gpu-intel-pvc
9+
// RUN: %{build} -o %t.out
10+
// RUN: %{run} %t.out
11+
12+
#define USE_BF16
13+
#define USE_TF32
14+
#include "bfloat16_vector_plus_scalar.cpp"

0 commit comments

Comments
 (0)