Skip to content

Commit a1787de

Browse files
[SYCL] Add missing marray binary operator overloads (#8276)
This commit does the following for marray: 1. Add overloads on all binary operators with scalars as the left operand. 2. Allow half, float, double in && and || operators. fixes #8331 --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 60e97e7 commit a1787de

File tree

2 files changed

+97
-42
lines changed

2 files changed

+97
-42
lines changed

sycl/include/sycl/marray.hpp

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -208,28 +208,37 @@ template <typename Type, std::size_t NumElements> class marray {
208208
return Ret; \
209209
} \
210210
template <typename T> \
211-
friend typename std::enable_if< \
212-
std::is_convertible<DataT, T>::value && \
213-
(std::is_fundamental<T>::value || \
214-
std::is_same<typename std::remove_const<T>::type, half>::value), \
215-
marray>::type \
211+
friend typename std::enable_if_t< \
212+
std::is_convertible_v<DataT, T> && \
213+
(std::is_fundamental_v<T> || \
214+
std::is_same_v<typename std::remove_const<T>::type, half>), \
215+
marray> \
216216
operator BINOP(const marray &Lhs, const T &Rhs) { \
217217
return Lhs BINOP marray(static_cast<DataT>(Rhs)); \
218218
} \
219+
template <typename T> \
220+
friend typename std::enable_if_t< \
221+
std::is_convertible_v<DataT, T> && \
222+
(std::is_fundamental_v<T> || \
223+
std::is_same_v<typename std::remove_const<T>::type, half>), \
224+
marray> \
225+
operator BINOP(const T &Lhs, const marray &Rhs) { \
226+
return marray(static_cast<DataT>(Lhs)) BINOP Rhs; \
227+
} \
219228
friend marray &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
220229
Lhs = Lhs BINOP Rhs; \
221230
return Lhs; \
222231
} \
223232
template <std::size_t Num = NumElements> \
224-
friend typename std::enable_if<Num != 1, marray &>::type operator OPASSIGN( \
233+
friend typename std::enable_if_t<Num != 1, marray &> operator OPASSIGN( \
225234
marray &Lhs, const DataT &Rhs) { \
226235
Lhs = Lhs BINOP marray(Rhs); \
227236
return Lhs; \
228237
}
229238

230239
#define __SYCL_BINOP_INTEGRAL(BINOP, OPASSIGN) \
231240
template <typename T = DataT, \
232-
typename = std::enable_if<std::is_integral<T>::value, marray>> \
241+
typename = std::enable_if<std::is_integral_v<T>, marray>> \
233242
friend marray operator BINOP(const marray &Lhs, const marray &Rhs) { \
234243
marray Ret; \
235244
for (size_t I = 0; I < NumElements; ++I) { \
@@ -238,23 +247,31 @@ template <typename Type, std::size_t NumElements> class marray {
238247
return Ret; \
239248
} \
240249
template <typename T, typename BaseT = DataT> \
241-
friend typename std::enable_if<std::is_convertible<T, DataT>::value && \
242-
std::is_integral<T>::value && \
243-
std::is_integral<BaseT>::value, \
244-
marray>::type \
250+
friend typename std::enable_if_t<std::is_convertible_v<T, DataT> && \
251+
std::is_integral_v<T> && \
252+
std::is_integral_v<BaseT>, \
253+
marray> \
245254
operator BINOP(const marray &Lhs, const T &Rhs) { \
246255
return Lhs BINOP marray(static_cast<DataT>(Rhs)); \
247256
} \
257+
template <typename T, typename BaseT = DataT> \
258+
friend typename std::enable_if_t<std::is_convertible_v<T, DataT> && \
259+
std::is_integral_v<T> && \
260+
std::is_integral_v<BaseT>, \
261+
marray> \
262+
operator BINOP(const T &Lhs, const marray &Rhs) { \
263+
return marray(static_cast<DataT>(Lhs)) BINOP Rhs; \
264+
} \
248265
template <typename T = DataT, \
249-
typename = std::enable_if<std::is_integral<T>::value, marray>> \
266+
typename = std::enable_if<std::is_integral_v<T>, marray>> \
250267
friend marray &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
251268
Lhs = Lhs BINOP Rhs; \
252269
return Lhs; \
253270
} \
254271
template <std::size_t Num = NumElements, typename T = DataT> \
255-
friend typename std::enable_if<Num != 1 && std::is_integral<T>::value, \
256-
marray &>::type \
257-
operator OPASSIGN(marray &Lhs, const DataT &Rhs) { \
272+
friend \
273+
typename std::enable_if_t<Num != 1 && std::is_integral_v<T>, marray &> \
274+
operator OPASSIGN(marray &Lhs, const DataT &Rhs) { \
258275
Lhs = Lhs BINOP marray(Rhs); \
259276
return Lhs; \
260277
}
@@ -291,32 +308,20 @@ template <typename Type, std::size_t NumElements> class marray {
291308
return Ret; \
292309
} \
293310
template <typename T> \
294-
friend typename std::enable_if<std::is_convertible<T, DataT>::value && \
295-
(std::is_fundamental<T>::value || \
296-
std::is_same<T, half>::value), \
297-
marray<bool, NumElements>>::type \
311+
friend typename std::enable_if_t<std::is_convertible_v<T, DataT> && \
312+
(std::is_fundamental_v<T> || \
313+
std::is_same_v<T, half>), \
314+
marray<bool, NumElements>> \
298315
operator RELLOGOP(const marray &Lhs, const T &Rhs) { \
299316
return Lhs RELLOGOP marray(static_cast<const DataT &>(Rhs)); \
300-
}
301-
302-
#define __SYCL_RELLOGOP_INTEGRAL(RELLOGOP) \
303-
template <typename T = DataT> \
304-
friend typename std::enable_if<std::is_integral<T>::value, \
305-
marray<bool, NumElements>>::type \
306-
operator RELLOGOP(const marray &Lhs, const marray &Rhs) { \
307-
marray<bool, NumElements> Ret; \
308-
for (size_t I = 0; I < NumElements; ++I) { \
309-
Ret[I] = Lhs[I] RELLOGOP Rhs[I]; \
310-
} \
311-
return Ret; \
312317
} \
313-
template <typename T, typename BaseT = DataT> \
314-
friend typename std::enable_if<std::is_convertible<T, DataT>::value && \
315-
std::is_integral<T>::value && \
316-
std::is_integral<BaseT>::value, \
317-
marray<bool, NumElements>>::type \
318-
operator RELLOGOP(const marray &Lhs, const T &Rhs) { \
319-
return Lhs RELLOGOP marray(static_cast<const DataT &>(Rhs)); \
318+
template <typename T> \
319+
friend typename std::enable_if_t<std::is_convertible_v<T, DataT> && \
320+
(std::is_fundamental_v<T> || \
321+
std::is_same_v<T, half>), \
322+
marray<bool, NumElements>> \
323+
operator RELLOGOP(const T &Lhs, const marray &Rhs) { \
324+
return marray(static_cast<const DataT &>(Lhs)) RELLOGOP Rhs; \
320325
}
321326

322327
__SYCL_RELLOGOP(==)
@@ -325,12 +330,10 @@ template <typename Type, std::size_t NumElements> class marray {
325330
__SYCL_RELLOGOP(<)
326331
__SYCL_RELLOGOP(>=)
327332
__SYCL_RELLOGOP(<=)
328-
329-
__SYCL_RELLOGOP_INTEGRAL(&&)
330-
__SYCL_RELLOGOP_INTEGRAL(||)
333+
__SYCL_RELLOGOP(&&)
334+
__SYCL_RELLOGOP(||)
331335

332336
#undef __SYCL_RELLOGOP
333-
#undef __SYCL_RELLOGOP_INTEGRAL
334337

335338
#ifdef __SYCL_UOP
336339
#error "Undefine __SYCL_UOP macro"

sycl/test/basic_tests/marray/marray.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,49 @@ using namespace sycl;
2323
CHECK_ALIAS_BY_SIZE(ALIAS_MTYPE, ELEM_TYPE, 8) \
2424
CHECK_ALIAS_BY_SIZE(ALIAS_MTYPE, ELEM_TYPE, 16)
2525

26+
// Check different combinations of the given binary operation. Some compare
27+
// scalar values with the marrays, which is valid as all elements in the marrays
28+
// should be the same.
29+
#define CHECK_BINOP(OP, LHS, RHS) \
30+
assert((LHS[0] OP RHS) == (LHS OP RHS) && (LHS OP RHS[0]) == (LHS OP RHS) && \
31+
(LHS[0] OP RHS[0]) == (LHS OP RHS));
32+
2633
struct NotDefaultConstructible {
2734
NotDefaultConstructible() = delete;
2835
constexpr NotDefaultConstructible(int){};
2936
};
3037

38+
template <typename T> void CheckBinOps() {
39+
sycl::marray<T, 3> ref_arr0{0};
40+
sycl::marray<T, 3> ref_arr1{1};
41+
sycl::marray<T, 3> ref_arr2{2};
42+
sycl::marray<T, 3> ref_arr3{3};
43+
44+
CHECK_BINOP(+, ref_arr1, ref_arr2)
45+
CHECK_BINOP(-, ref_arr1, ref_arr2)
46+
CHECK_BINOP(*, ref_arr1, ref_arr2)
47+
CHECK_BINOP(/, ref_arr1, ref_arr2)
48+
CHECK_BINOP(&&, ref_arr0, ref_arr2)
49+
CHECK_BINOP(||, ref_arr0, ref_arr2)
50+
CHECK_BINOP(==, ref_arr1, ref_arr2)
51+
CHECK_BINOP(!=, ref_arr1, ref_arr2)
52+
CHECK_BINOP(<, ref_arr1, ref_arr2)
53+
CHECK_BINOP(>, ref_arr1, ref_arr2)
54+
CHECK_BINOP(<=, ref_arr1, ref_arr2)
55+
CHECK_BINOP(>=, ref_arr1, ref_arr2)
56+
57+
if constexpr (!std::is_same_v<T, sycl::half> && !std::is_same_v<T, float> &&
58+
!std::is_same_v<T, double>) {
59+
// Operators not supported on sycl::half, float, and double.
60+
CHECK_BINOP(%, ref_arr1, ref_arr2)
61+
CHECK_BINOP(&, ref_arr1, ref_arr3)
62+
CHECK_BINOP(|, ref_arr1, ref_arr3)
63+
CHECK_BINOP(^, ref_arr1, ref_arr3)
64+
CHECK_BINOP(>>, ref_arr1, ref_arr2)
65+
CHECK_BINOP(<<, ref_arr1, ref_arr2)
66+
}
67+
}
68+
3169
template <typename DataT> void CheckConstexprVariadicCtors() {
3270
constexpr DataT default_val{1};
3371

@@ -145,6 +183,20 @@ int main() {
145183
b___ = !mint3{0, 1, 2};
146184
assert(b___[0] == true && b___[1] == false && b___[2] == false);
147185

186+
// Check direct binary operators
187+
CheckBinOps<bool>();
188+
CheckBinOps<std::int8_t>();
189+
CheckBinOps<std::uint8_t>();
190+
CheckBinOps<std::int16_t>();
191+
CheckBinOps<std::uint16_t>();
192+
CheckBinOps<std::int32_t>();
193+
CheckBinOps<std::uint32_t>();
194+
CheckBinOps<std::int64_t>();
195+
CheckBinOps<std::uint64_t>();
196+
CheckBinOps<sycl::half>();
197+
CheckBinOps<float>();
198+
CheckBinOps<double>();
199+
148200
// check copyability
149201
constexpr sycl::marray<double, 5> ma;
150202
constexpr sycl::marray<double, 5> mb(ma);

0 commit comments

Comments
 (0)