Skip to content

[SYCL] Add SYCL 2020 operators sycl::logical_and and sycl::logical_or #4476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sycl/include/CL/sycl/functional.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ namespace sycl {

template <typename T = void> using plus = std::plus<T>;
template <typename T = void> using multiplies = std::multiplies<T>;
template <typename T = void> using bit_and = std::bit_and<T>;
template <typename T = void> using bit_or = std::bit_or<T>;
template <typename T = void> using bit_xor = std::bit_xor<T>;
template <typename T = void> using bit_and = std::bit_and<T>;
template <typename T = void> using logical_and = std::logical_and<T>;
template <typename T = void> using logical_or = std::logical_or<T>;

template <typename T = void> struct minimum {
T operator()(const T &lhs, const T &rhs) const {
Expand Down
92 changes: 64 additions & 28 deletions sycl/include/CL/sycl/known_identity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ using IsMaximum =
bool_constant<std::is_same<BinaryOperation, sycl::maximum<T>>::value ||
std::is_same<BinaryOperation, sycl::maximum<void>>::value>;

template <typename T, class BinaryOperation>
using IsBitAND =
bool_constant<std::is_same<BinaryOperation, sycl::bit_and<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_and<void>>::value>;

template <typename T, class BinaryOperation>
using IsBitOR =
bool_constant<std::is_same<BinaryOperation, sycl::bit_or<T>>::value ||
Expand All @@ -48,9 +53,14 @@ using IsBitXOR =
std::is_same<BinaryOperation, sycl::bit_xor<void>>::value>;

template <typename T, class BinaryOperation>
using IsBitAND =
bool_constant<std::is_same<BinaryOperation, sycl::bit_and<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_and<void>>::value>;
using IsLogicalAND = bool_constant<
std::is_same<BinaryOperation, sycl::logical_and<T>>::value ||
std::is_same<BinaryOperation, sycl::logical_and<void>>::value>;

template <typename T, class BinaryOperation>
using IsLogicalOR =
bool_constant<std::is_same<BinaryOperation, sycl::logical_or<T>>::value ||
std::is_same<BinaryOperation, sycl::logical_or<void>>::value>;

// Identity = 0
template <typename T, class BinaryOperation>
Expand Down Expand Up @@ -83,13 +93,23 @@ using IsMaximumIdentityOp =
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
IsMaximum<T, BinaryOperation>::value>;

// Identity = false
template <typename T, class BinaryOperation>
using IsFalseIdentityOp = bool_constant<IsLogicalOR<T, BinaryOperation>::value>;

// Identity = true
template <typename T, class BinaryOperation>
using IsTrueIdentityOp = bool_constant<IsLogicalAND<T, BinaryOperation>::value>;

template <typename T, class BinaryOperation>
using IsKnownIdentityOp =
bool_constant<IsZeroIdentityOp<T, BinaryOperation>::value ||
IsOneIdentityOp<T, BinaryOperation>::value ||
IsOnesIdentityOp<T, BinaryOperation>::value ||
IsMinimumIdentityOp<T, BinaryOperation>::value ||
IsMaximumIdentityOp<T, BinaryOperation>::value>;
IsMaximumIdentityOp<T, BinaryOperation>::value ||
IsFalseIdentityOp<T, BinaryOperation>::value ||
IsTrueIdentityOp<T, BinaryOperation>::value>;

template <typename BinaryOperation, typename AccumulatorT>
struct has_known_identity_impl
Expand All @@ -101,16 +121,16 @@ struct known_identity_impl {};

/// Returns zero as identity for ADD, OR, XOR operations.
template <typename BinaryOperation, typename AccumulatorT>
struct known_identity_impl<BinaryOperation, AccumulatorT,
typename std::enable_if<IsZeroIdentityOp<
AccumulatorT, BinaryOperation>::value>::type> {
struct known_identity_impl<
BinaryOperation, AccumulatorT,
std::enable_if_t<IsZeroIdentityOp<AccumulatorT, BinaryOperation>::value>> {
static constexpr AccumulatorT value = 0;
};

template <typename BinaryOperation>
struct known_identity_impl<BinaryOperation, half,
typename std::enable_if<IsZeroIdentityOp<
half, BinaryOperation>::value>::type> {
struct known_identity_impl<
BinaryOperation, half,
std::enable_if_t<IsZeroIdentityOp<half, BinaryOperation>::value>> {
static constexpr half value =
#ifdef __SYCL_DEVICE_ONLY__
0;
Expand All @@ -121,16 +141,16 @@ struct known_identity_impl<BinaryOperation, half,

/// Returns one as identify for MULTIPLY operations.
template <typename BinaryOperation, typename AccumulatorT>
struct known_identity_impl<BinaryOperation, AccumulatorT,
typename std::enable_if<IsOneIdentityOp<
AccumulatorT, BinaryOperation>::value>::type> {
struct known_identity_impl<
BinaryOperation, AccumulatorT,
std::enable_if_t<IsOneIdentityOp<AccumulatorT, BinaryOperation>::value>> {
static constexpr AccumulatorT value = 1;
};

template <typename BinaryOperation>
struct known_identity_impl<BinaryOperation, half,
typename std::enable_if<IsOneIdentityOp<
half, BinaryOperation>::value>::type> {
struct known_identity_impl<
BinaryOperation, half,
std::enable_if_t<IsOneIdentityOp<half, BinaryOperation>::value>> {
static constexpr half value =
#ifdef __SYCL_DEVICE_ONLY__
1;
Expand All @@ -141,17 +161,17 @@ struct known_identity_impl<BinaryOperation, half,

/// Returns bit image consisting of all ones as identity for AND operations.
template <typename BinaryOperation, typename AccumulatorT>
struct known_identity_impl<BinaryOperation, AccumulatorT,
typename std::enable_if<IsOnesIdentityOp<
AccumulatorT, BinaryOperation>::value>::type> {
struct known_identity_impl<
BinaryOperation, AccumulatorT,
std::enable_if_t<IsOnesIdentityOp<AccumulatorT, BinaryOperation>::value>> {
static constexpr AccumulatorT value = ~static_cast<AccumulatorT>(0);
};

/// Returns maximal possible value as identity for MIN operations.
template <typename BinaryOperation, typename AccumulatorT>
struct known_identity_impl<BinaryOperation, AccumulatorT,
typename std::enable_if<IsMinimumIdentityOp<
AccumulatorT, BinaryOperation>::value>::type> {
std::enable_if_t<IsMinimumIdentityOp<
AccumulatorT, BinaryOperation>::value>> {
static constexpr AccumulatorT value =
std::numeric_limits<AccumulatorT>::has_infinity
? std::numeric_limits<AccumulatorT>::infinity()
Expand All @@ -161,22 +181,38 @@ struct known_identity_impl<BinaryOperation, AccumulatorT,
/// Returns minimal possible value as identity for MAX operations.
template <typename BinaryOperation, typename AccumulatorT>
struct known_identity_impl<BinaryOperation, AccumulatorT,
typename std::enable_if<IsMaximumIdentityOp<
AccumulatorT, BinaryOperation>::value>::type> {
std::enable_if_t<IsMaximumIdentityOp<
AccumulatorT, BinaryOperation>::value>> {
static constexpr AccumulatorT value =
std::numeric_limits<AccumulatorT>::has_infinity
? static_cast<AccumulatorT>(
-std::numeric_limits<AccumulatorT>::infinity())
: std::numeric_limits<AccumulatorT>::lowest();
};

/// Returns false as identity for LOGICAL OR operations.
template <typename BinaryOperation, typename AccumulatorT>
struct known_identity_impl<
BinaryOperation, AccumulatorT,
std::enable_if_t<IsFalseIdentityOp<AccumulatorT, BinaryOperation>::value>> {
static constexpr AccumulatorT value = false;
};

/// Returns true as identity for LOGICAL AND operations.
template <typename BinaryOperation, typename AccumulatorT>
struct known_identity_impl<
BinaryOperation, AccumulatorT,
std::enable_if_t<IsTrueIdentityOp<AccumulatorT, BinaryOperation>::value>> {
static constexpr AccumulatorT value = true;
};

} // namespace detail

// ---- has_known_identity
template <typename BinaryOperation, typename AccumulatorT>
struct has_known_identity : detail::has_known_identity_impl<
typename std::decay<BinaryOperation>::type,
typename std::decay<AccumulatorT>::type> {};
struct has_known_identity
: detail::has_known_identity_impl<std::decay_t<BinaryOperation>,
std::decay_t<AccumulatorT>> {};

template <typename BinaryOperation, typename AccumulatorT>
__SYCL_INLINE_CONSTEXPR bool has_known_identity_v =
Expand All @@ -185,8 +221,8 @@ __SYCL_INLINE_CONSTEXPR bool has_known_identity_v =
// ---- known_identity
template <typename BinaryOperation, typename AccumulatorT>
struct known_identity
: detail::known_identity_impl<typename std::decay<BinaryOperation>::type,
typename std::decay<AccumulatorT>::type> {};
: detail::known_identity_impl<std::decay_t<BinaryOperation>,
std::decay_t<AccumulatorT>> {};

template <typename BinaryOperation, typename AccumulatorT>
__SYCL_INLINE_CONSTEXPR AccumulatorT known_identity_v =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,31 @@

using namespace cl::sycl;

template <typename T> void checkCommonBasicKnownIdentity() {
template <typename T> void checkCommonKnownIdentity() {
static_assert(has_known_identity<sycl::maximum<>, T>::value);
static_assert(has_known_identity<sycl::maximum<T>, T>::value);
static_assert(has_known_identity<sycl::minimum<>, T>::value);
static_assert(has_known_identity<sycl::minimum<T>, T>::value);
}

template <typename T> void checkCommonKnownIdentity() {
checkCommonBasicKnownIdentity<T>();

static_assert(has_known_identity<std::plus<>, T>::value);
static_assert(has_known_identity<std::plus<T>, T>::value);
static_assert(known_identity<std::plus<>, T>::value == 0);
static_assert(known_identity<std::plus<T>, T>::value == 0);

static_assert(has_known_identity<sycl::plus<>, T>::value);
static_assert(has_known_identity<sycl::plus<T>, T>::value);
static_assert(known_identity<sycl::plus<>, T>::value == 0);
static_assert(known_identity<sycl::plus<T>, T>::value == 0);

static_assert(has_known_identity<std::multiplies<>, T>::value);
static_assert(has_known_identity<std::multiplies<T>, T>::value);
static_assert(known_identity<std::multiplies<>, T>::value == 1);
static_assert(known_identity<std::multiplies<T>, T>::value == 1);

static_assert(has_known_identity<sycl::multiplies<>, T>::value);
static_assert(has_known_identity<sycl::multiplies<T>, T>::value);
static_assert(known_identity<sycl::multiplies<>, T>::value == 1);
static_assert(known_identity<sycl::multiplies<T>, T>::value == 1);
}

template <typename T> void checkIntKnownIdentity() {
Expand All @@ -39,15 +45,52 @@ template <typename T> void checkIntKnownIdentity() {
static_assert(known_identity<std::bit_and<>, T>::value == Ones);
static_assert(known_identity<std::bit_and<T>, T>::value == Ones);

static_assert(has_known_identity<sycl::bit_and<>, T>::value);
static_assert(has_known_identity<sycl::bit_and<T>, T>::value);
static_assert(known_identity<sycl::bit_and<>, T>::value == Ones);
static_assert(known_identity<sycl::bit_and<T>, T>::value == Ones);

static_assert(has_known_identity<std::bit_or<>, T>::value);
static_assert(has_known_identity<std::bit_or<T>, T>::value);
static_assert(known_identity<std::bit_or<>, T>::value == 0);
static_assert(known_identity<std::bit_or<T>, T>::value == 0);

static_assert(has_known_identity<sycl::bit_or<>, T>::value);
static_assert(has_known_identity<sycl::bit_or<T>, T>::value);
static_assert(known_identity<sycl::bit_or<>, T>::value == 0);
static_assert(known_identity<sycl::bit_or<T>, T>::value == 0);

static_assert(has_known_identity<std::bit_xor<>, T>::value);
static_assert(has_known_identity<std::bit_xor<T>, T>::value);
static_assert(known_identity<std::bit_xor<>, T>::value == 0);
static_assert(known_identity<std::bit_xor<T>, T>::value == 0);

static_assert(has_known_identity<sycl::bit_xor<>, T>::value);
static_assert(has_known_identity<sycl::bit_xor<T>, T>::value);
static_assert(known_identity<sycl::bit_xor<>, T>::value == 0);
static_assert(known_identity<sycl::bit_xor<T>, T>::value == 0);
}

template <typename T> void checkBoolKnownIdentity() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it also be possible to use bit_and, bit_or, bit_xor with bool? I think these operations are well defined.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting question, but I afraid it may be implementation defined. At least this doc does not mention bool type: https://en.cppreference.com/w/cpp/utility/functional/bit_and

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Pennycook - This PR passed testing, it fixes an obvious miss. bit_and for bool is a separate and less obvious (imo). It can be supported by another PR. If you don't mind I'll merge the current PR as is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delayed response! Yes, I agree that this can be merged, and we can revisit the question of which type/operator combinations are supported later.

I think things like bit_and do work for bool, but possibly because of weird integer conversion rules in C++. It's not immediately obvious to me which combinations the SYCL specification expects to work; I'll try to get some more clarity on this point, and we can also wait to see if any users actually complain about not being able to use "weird" combinations.

static_assert(has_known_identity<std::logical_and<>, T>::value);
static_assert(has_known_identity<std::logical_and<T>, T>::value);
static_assert(known_identity<std::logical_and<>, T>::value == true);
static_assert(known_identity<std::logical_and<T>, T>::value == true);

static_assert(has_known_identity<sycl::logical_and<>, T>::value);
static_assert(has_known_identity<sycl::logical_and<T>, T>::value);
static_assert(known_identity<sycl::logical_and<>, T>::value == true);
static_assert(known_identity<sycl::logical_and<T>, T>::value == true);

static_assert(has_known_identity<std::logical_or<>, T>::value);
static_assert(has_known_identity<std::logical_or<T>, T>::value);
static_assert(known_identity<std::logical_or<>, T>::value == false);
static_assert(known_identity<std::logical_or<T>, T>::value == false);

static_assert(has_known_identity<sycl::logical_or<>, T>::value);
static_assert(has_known_identity<sycl::logical_or<T>, T>::value);
static_assert(known_identity<sycl::logical_or<>, T>::value == false);
static_assert(known_identity<sycl::logical_or<T>, T>::value == false);
}

int main() {
Expand Down Expand Up @@ -94,9 +137,11 @@ int main() {
checkCommonKnownIdentity<double>();
checkCommonKnownIdentity<cl_double>();

checkCommonBasicKnownIdentity<half>();
checkCommonBasicKnownIdentity<sycl::cl_half>();
checkCommonBasicKnownIdentity<::cl_half>();
checkCommonKnownIdentity<half>();
checkCommonKnownIdentity<sycl::cl_half>();
checkCommonKnownIdentity<::cl_half>();

checkBoolKnownIdentity<bool>();

// Few negative tests just to check that it does not always return true.
static_assert(!has_known_identity<std::minus<>, int>::value);
Expand Down