Skip to content

[SYCL] Add missing marray binary operator overloads #8276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 45 additions & 42 deletions sycl/include/sycl/marray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,28 +208,37 @@ template <typename Type, std::size_t NumElements> class marray {
return Ret; \
} \
template <typename T> \
friend typename std::enable_if< \
std::is_convertible<DataT, T>::value && \
(std::is_fundamental<T>::value || \
std::is_same<typename std::remove_const<T>::type, half>::value), \
marray>::type \
friend typename std::enable_if_t< \
std::is_convertible_v<DataT, T> && \
(std::is_fundamental_v<T> || \
std::is_same_v<typename std::remove_const<T>::type, half>), \
marray> \
operator BINOP(const marray &Lhs, const T &Rhs) { \
return Lhs BINOP marray(static_cast<DataT>(Rhs)); \
} \
template <typename T> \
friend typename std::enable_if_t< \
std::is_convertible_v<DataT, T> && \
(std::is_fundamental_v<T> || \
std::is_same_v<typename std::remove_const<T>::type, half>), \
marray> \
operator BINOP(const T &Lhs, const marray &Rhs) { \
return marray(static_cast<DataT>(Lhs)) BINOP Rhs; \
} \
friend marray &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
Lhs = Lhs BINOP Rhs; \
return Lhs; \
} \
template <std::size_t Num = NumElements> \
friend typename std::enable_if<Num != 1, marray &>::type operator OPASSIGN( \
friend typename std::enable_if_t<Num != 1, marray &> operator OPASSIGN( \
marray &Lhs, const DataT &Rhs) { \
Lhs = Lhs BINOP marray(Rhs); \
return Lhs; \
}

#define __SYCL_BINOP_INTEGRAL(BINOP, OPASSIGN) \
template <typename T = DataT, \
typename = std::enable_if<std::is_integral<T>::value, marray>> \
typename = std::enable_if<std::is_integral_v<T>, marray>> \
friend marray operator BINOP(const marray &Lhs, const marray &Rhs) { \
marray Ret; \
for (size_t I = 0; I < NumElements; ++I) { \
Expand All @@ -238,23 +247,31 @@ template <typename Type, std::size_t NumElements> class marray {
return Ret; \
} \
template <typename T, typename BaseT = DataT> \
friend typename std::enable_if<std::is_convertible<T, DataT>::value && \
std::is_integral<T>::value && \
std::is_integral<BaseT>::value, \
marray>::type \
friend typename std::enable_if_t<std::is_convertible_v<T, DataT> && \
std::is_integral_v<T> && \
std::is_integral_v<BaseT>, \
marray> \
operator BINOP(const marray &Lhs, const T &Rhs) { \
return Lhs BINOP marray(static_cast<DataT>(Rhs)); \
} \
template <typename T, typename BaseT = DataT> \
friend typename std::enable_if_t<std::is_convertible_v<T, DataT> && \
std::is_integral_v<T> && \
std::is_integral_v<BaseT>, \
marray> \
operator BINOP(const T &Lhs, const marray &Rhs) { \
return marray(static_cast<DataT>(Lhs)) BINOP Rhs; \
} \
template <typename T = DataT, \
typename = std::enable_if<std::is_integral<T>::value, marray>> \
typename = std::enable_if<std::is_integral_v<T>, marray>> \
friend marray &operator OPASSIGN(marray &Lhs, const marray &Rhs) { \
Lhs = Lhs BINOP Rhs; \
return Lhs; \
} \
template <std::size_t Num = NumElements, typename T = DataT> \
friend typename std::enable_if<Num != 1 && std::is_integral<T>::value, \
marray &>::type \
operator OPASSIGN(marray &Lhs, const DataT &Rhs) { \
friend \
typename std::enable_if_t<Num != 1 && std::is_integral_v<T>, marray &> \
operator OPASSIGN(marray &Lhs, const DataT &Rhs) { \
Lhs = Lhs BINOP marray(Rhs); \
return Lhs; \
}
Expand Down Expand Up @@ -291,32 +308,20 @@ template <typename Type, std::size_t NumElements> class marray {
return Ret; \
} \
template <typename T> \
friend typename std::enable_if<std::is_convertible<T, DataT>::value && \
(std::is_fundamental<T>::value || \
std::is_same<T, half>::value), \
marray<bool, NumElements>>::type \
friend typename std::enable_if_t<std::is_convertible_v<T, DataT> && \
(std::is_fundamental_v<T> || \
std::is_same_v<T, half>), \
marray<bool, NumElements>> \
operator RELLOGOP(const marray &Lhs, const T &Rhs) { \
return Lhs RELLOGOP marray(static_cast<const DataT &>(Rhs)); \
}

#define __SYCL_RELLOGOP_INTEGRAL(RELLOGOP) \
template <typename T = DataT> \
friend typename std::enable_if<std::is_integral<T>::value, \
marray<bool, NumElements>>::type \
operator RELLOGOP(const marray &Lhs, const marray &Rhs) { \
marray<bool, NumElements> Ret; \
for (size_t I = 0; I < NumElements; ++I) { \
Ret[I] = Lhs[I] RELLOGOP Rhs[I]; \
} \
return Ret; \
} \
template <typename T, typename BaseT = DataT> \
friend typename std::enable_if<std::is_convertible<T, DataT>::value && \
std::is_integral<T>::value && \
std::is_integral<BaseT>::value, \
marray<bool, NumElements>>::type \
operator RELLOGOP(const marray &Lhs, const T &Rhs) { \
return Lhs RELLOGOP marray(static_cast<const DataT &>(Rhs)); \
template <typename T> \
friend typename std::enable_if_t<std::is_convertible_v<T, DataT> && \
(std::is_fundamental_v<T> || \
std::is_same_v<T, half>), \
marray<bool, NumElements>> \
operator RELLOGOP(const T &Lhs, const marray &Rhs) { \
return marray(static_cast<const DataT &>(Lhs)) RELLOGOP Rhs; \
}

__SYCL_RELLOGOP(==)
Expand All @@ -325,12 +330,10 @@ template <typename Type, std::size_t NumElements> class marray {
__SYCL_RELLOGOP(<)
__SYCL_RELLOGOP(>=)
__SYCL_RELLOGOP(<=)

__SYCL_RELLOGOP_INTEGRAL(&&)
__SYCL_RELLOGOP_INTEGRAL(||)
__SYCL_RELLOGOP(&&)
__SYCL_RELLOGOP(||)

#undef __SYCL_RELLOGOP
#undef __SYCL_RELLOGOP_INTEGRAL

#ifdef __SYCL_UOP
#error "Undefine __SYCL_UOP macro"
Expand Down
52 changes: 52 additions & 0 deletions sycl/test/basic_tests/marray/marray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,49 @@ using namespace sycl;
CHECK_ALIAS_BY_SIZE(ALIAS_MTYPE, ELEM_TYPE, 8) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed with #8449.

CHECK_ALIAS_BY_SIZE(ALIAS_MTYPE, ELEM_TYPE, 16)

// Check different combinations of the given binary operation. Some compare
// scalar values with the marrays, which is valid as all elements in the marrays
// should be the same.
#define CHECK_BINOP(OP, LHS, RHS) \
assert((LHS[0] OP RHS) == (LHS OP RHS) && (LHS OP RHS[0]) == (LHS OP RHS) && \
(LHS[0] OP RHS[0]) == (LHS OP RHS));

struct NotDefaultConstructible {
NotDefaultConstructible() = delete;
constexpr NotDefaultConstructible(int){};
};

template <typename T> void CheckBinOps() {
sycl::marray<T, 3> ref_arr0{0};
sycl::marray<T, 3> ref_arr1{1};
sycl::marray<T, 3> ref_arr2{2};
sycl::marray<T, 3> ref_arr3{3};

CHECK_BINOP(+, ref_arr1, ref_arr2)
CHECK_BINOP(-, ref_arr1, ref_arr2)
CHECK_BINOP(*, ref_arr1, ref_arr2)
CHECK_BINOP(/, ref_arr1, ref_arr2)
CHECK_BINOP(&&, ref_arr0, ref_arr2)
CHECK_BINOP(||, ref_arr0, ref_arr2)
CHECK_BINOP(==, ref_arr1, ref_arr2)
CHECK_BINOP(!=, ref_arr1, ref_arr2)
CHECK_BINOP(<, ref_arr1, ref_arr2)
CHECK_BINOP(>, ref_arr1, ref_arr2)
CHECK_BINOP(<=, ref_arr1, ref_arr2)
CHECK_BINOP(>=, ref_arr1, ref_arr2)

if constexpr (!std::is_same_v<T, sycl::half> && !std::is_same_v<T, float> &&
!std::is_same_v<T, double>) {
// Operators not supported on sycl::half, float, and double.
CHECK_BINOP(%, ref_arr1, ref_arr2)
CHECK_BINOP(&, ref_arr1, ref_arr3)
CHECK_BINOP(|, ref_arr1, ref_arr3)
CHECK_BINOP(^, ref_arr1, ref_arr3)
CHECK_BINOP(>>, ref_arr1, ref_arr2)
CHECK_BINOP(<<, ref_arr1, ref_arr2)
}
}

template <typename DataT> void CheckConstexprVariadicCtors() {
constexpr DataT default_val{1};

Expand Down Expand Up @@ -145,6 +183,20 @@ int main() {
b___ = !mint3{0, 1, 2};
assert(b___[0] == true && b___[1] == false && b___[2] == false);

// Check direct binary operators
CheckBinOps<bool>();
CheckBinOps<std::int8_t>();
CheckBinOps<std::uint8_t>();
CheckBinOps<std::int16_t>();
CheckBinOps<std::uint16_t>();
CheckBinOps<std::int32_t>();
CheckBinOps<std::uint32_t>();
CheckBinOps<std::int64_t>();
CheckBinOps<std::uint64_t>();
CheckBinOps<sycl::half>();
CheckBinOps<float>();
CheckBinOps<double>();

// check copyability
constexpr sycl::marray<double, 5> ma;
constexpr sycl::marray<double, 5> mb(ma);
Expand Down