Skip to content

Commit 091481d

Browse files
[SYCL][Reduction] Add missing known identities for bool operations (#8428)
This commit adds identities for arithmetic and integral operations on bool in accordance with the SYCL 2020 specification. Additionally it adds a check for the specification defined identities. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 8689420 commit 091481d

File tree

4 files changed

+148
-17
lines changed

4 files changed

+148
-17
lines changed

sycl/include/sycl/detail/generic_type_lists.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,16 @@ using marray_integer_list =
500500
using integer_list =
501501
type_list<scalar_integer_list, vector_integer_list, marray_integer_list>;
502502

503+
// bool types
504+
505+
using marray_bool_list =
506+
type_list<marray<bool, 1>, marray<bool, 2>, marray<bool, 3>,
507+
marray<bool, 4>, marray<bool, 8>, marray<bool, 16>>;
508+
509+
using scalar_bool_list = type_list<bool>;
510+
511+
using bool_list = type_list<scalar_bool_list, marray_bool_list>;
512+
503513
// basic types
504514
using scalar_signed_basic_list =
505515
type_list<scalar_floating_list, scalar_signed_integer_list>;

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ using is_vigeninteger = is_contained<T, gtl::vector_signed_integer_list>;
181181
template <typename T>
182182
using is_vugeninteger = is_contained<T, gtl::vector_unsigned_integer_list>;
183183

184+
template <typename T> using is_genbool = is_contained<T, gtl::bool_list>;
185+
184186
template <typename T> using is_gentype = is_contained<T, gtl::basic_list>;
185187

186188
template <typename T>

sycl/include/sycl/known_identity.hpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,43 +64,43 @@ using IsLogicalOR =
6464
std::is_same<BinaryOperation, sycl::logical_or<void>>::value>;
6565

6666
template <typename T>
67-
using isComplex =
68-
bool_constant<std::is_same<T, std::complex<float>>::value ||
69-
std::is_same<T, std::complex<double>>::value>;
67+
using isComplex = bool_constant<std::is_same<T, std::complex<float>>::value ||
68+
std::is_same<T, std::complex<double>>::value>;
7069

7170
// Identity = 0
7271
template <typename T, class BinaryOperation>
73-
using IsZeroIdentityOp =
74-
bool_constant<(is_geninteger<T>::value &&
75-
(IsPlus<T, BinaryOperation>::value ||
76-
IsBitOR<T, BinaryOperation>::value ||
77-
IsBitXOR<T, BinaryOperation>::value)) ||
78-
(is_genfloat<T>::value &&
79-
IsPlus<T, BinaryOperation>::value) ||
80-
(isComplex<T>::value &&
81-
IsPlus<T, BinaryOperation>::value)>;
72+
using IsZeroIdentityOp = bool_constant<
73+
((is_genbool<T>::value || is_geninteger<T>::value) &&
74+
(IsPlus<T, BinaryOperation>::value || IsBitOR<T, BinaryOperation>::value ||
75+
IsBitXOR<T, BinaryOperation>::value)) ||
76+
(is_genfloat<T>::value && IsPlus<T, BinaryOperation>::value) ||
77+
(isComplex<T>::value && IsPlus<T, BinaryOperation>::value)>;
8278

8379
// Identity = 1
8480
template <typename T, class BinaryOperation>
8581
using IsOneIdentityOp =
86-
bool_constant<(is_geninteger<T>::value || is_genfloat<T>::value) &&
82+
bool_constant<(is_genbool<T>::value || is_geninteger<T>::value ||
83+
is_genfloat<T>::value) &&
8784
IsMultiplies<T, BinaryOperation>::value>;
8885

8986
// Identity = ~0
9087
template <typename T, class BinaryOperation>
91-
using IsOnesIdentityOp = bool_constant<is_geninteger<T>::value &&
92-
IsBitAND<T, BinaryOperation>::value>;
88+
using IsOnesIdentityOp =
89+
bool_constant<(is_genbool<T>::value || is_geninteger<T>::value) &&
90+
IsBitAND<T, BinaryOperation>::value>;
9391

9492
// Identity = <max possible value>
9593
template <typename T, class BinaryOperation>
9694
using IsMinimumIdentityOp =
97-
bool_constant<(is_geninteger<T>::value || is_genfloat<T>::value) &&
95+
bool_constant<(is_genbool<T>::value || is_geninteger<T>::value ||
96+
is_genfloat<T>::value) &&
9897
IsMinimum<T, BinaryOperation>::value>;
9998

10099
// Identity = <min possible value>
101100
template <typename T, class BinaryOperation>
102101
using IsMaximumIdentityOp =
103-
bool_constant<(is_geninteger<T>::value || is_genfloat<T>::value) &&
102+
bool_constant<(is_genbool<T>::value || is_geninteger<T>::value ||
103+
is_genfloat<T>::value) &&
104104
IsMaximum<T, BinaryOperation>::value>;
105105

106106
// Identity = false
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// RUN: %clangxx -fsycl -fsyntax-only %s
2+
3+
// Tests the existance and specializations of known identities.
4+
5+
#include <sycl/sycl.hpp>
6+
7+
#include <cassert>
8+
#include <limits>
9+
#include <type_traits>
10+
11+
template <typename BinOp, typename OperandT> constexpr void checkNoIdentity() {
12+
static_assert(!sycl::has_known_identity<BinOp, OperandT>::value,
13+
"Operation should not have a known identity!");
14+
static_assert(!sycl::has_known_identity_v<BinOp, OperandT>,
15+
"Operation should not have a known identity!");
16+
}
17+
18+
#define CHECK_IDENTITY(BINOP, OPERAND, EXPECTED) \
19+
static_assert(sycl::has_known_identity<BINOP, OPERAND>::value, \
20+
"No trait specialization for known identity!"); \
21+
static_assert(sycl::has_known_identity_v<BINOP, OPERAND>, \
22+
"No trait specialization for known identity!"); \
23+
static_assert(sycl::known_identity<BINOP, OPERAND>::value == EXPECTED, \
24+
"Identity does not match expected."); \
25+
static_assert(sycl::known_identity_v<BINOP, OPERAND> == EXPECTED, \
26+
"Identity does not match expected.");
27+
28+
template <typename OperandT> constexpr void checkAll() {
29+
if constexpr (std::is_arithmetic_v<OperandT> ||
30+
std::is_same_v<std::remove_cv_t<OperandT>, sycl::half>) {
31+
CHECK_IDENTITY(sycl::plus<OperandT>, OperandT, OperandT{});
32+
CHECK_IDENTITY(sycl::plus<>, OperandT, OperandT{});
33+
CHECK_IDENTITY(sycl::multiplies<OperandT>, OperandT, OperandT{1});
34+
CHECK_IDENTITY(sycl::multiplies<>, OperandT, OperandT{1});
35+
} else {
36+
checkNoIdentity<sycl::plus<OperandT>, OperandT>();
37+
checkNoIdentity<sycl::plus<>, OperandT>();
38+
checkNoIdentity<sycl::multiplies<OperandT>, OperandT>();
39+
checkNoIdentity<sycl::multiplies<>, OperandT>();
40+
}
41+
42+
if constexpr (std::is_integral_v<OperandT>) {
43+
CHECK_IDENTITY(sycl::bit_and<OperandT>, OperandT,
44+
static_cast<OperandT>(~OperandT{}));
45+
CHECK_IDENTITY(sycl::bit_and<>, OperandT,
46+
static_cast<OperandT>(~OperandT{}));
47+
CHECK_IDENTITY(sycl::bit_or<OperandT>, OperandT, OperandT{});
48+
CHECK_IDENTITY(sycl::bit_or<>, OperandT, OperandT{});
49+
CHECK_IDENTITY(sycl::bit_xor<OperandT>, OperandT, OperandT{});
50+
CHECK_IDENTITY(sycl::bit_xor<>, OperandT, OperandT{});
51+
CHECK_IDENTITY(sycl::minimum<OperandT>, OperandT,
52+
std::numeric_limits<OperandT>::max());
53+
CHECK_IDENTITY(sycl::minimum<>, OperandT,
54+
std::numeric_limits<OperandT>::max());
55+
CHECK_IDENTITY(sycl::maximum<OperandT>, OperandT,
56+
std::numeric_limits<OperandT>::lowest());
57+
CHECK_IDENTITY(sycl::maximum<>, OperandT,
58+
std::numeric_limits<OperandT>::lowest());
59+
} else {
60+
checkNoIdentity<sycl::bit_and<OperandT>, OperandT>();
61+
checkNoIdentity<sycl::bit_and<>, OperandT>();
62+
checkNoIdentity<sycl::bit_or<OperandT>, OperandT>();
63+
checkNoIdentity<sycl::bit_or<>, OperandT>();
64+
checkNoIdentity<sycl::bit_xor<OperandT>, OperandT>();
65+
checkNoIdentity<sycl::bit_xor<>, OperandT>();
66+
}
67+
68+
// The implementation is relaxed about logical operators to allow implicit
69+
// conversions for logical operators, so negative checks are not used for this
70+
// case.
71+
if constexpr (std::is_same_v<std::remove_cv_t<OperandT>, bool>) {
72+
CHECK_IDENTITY(sycl::logical_and<OperandT>, OperandT, true);
73+
CHECK_IDENTITY(sycl::logical_and<>, OperandT, true);
74+
CHECK_IDENTITY(sycl::logical_or<OperandT>, OperandT, false);
75+
CHECK_IDENTITY(sycl::logical_or<>, OperandT, false);
76+
}
77+
78+
if constexpr (std::is_floating_point_v<OperandT> ||
79+
std::is_same_v<std::remove_cv_t<OperandT>, sycl::half>) {
80+
CHECK_IDENTITY(sycl::minimum<OperandT>, OperandT,
81+
std::numeric_limits<OperandT>::infinity());
82+
CHECK_IDENTITY(sycl::minimum<>, OperandT,
83+
std::numeric_limits<OperandT>::infinity());
84+
CHECK_IDENTITY(sycl::maximum<OperandT>, OperandT,
85+
-std::numeric_limits<OperandT>::infinity());
86+
CHECK_IDENTITY(sycl::maximum<>, OperandT,
87+
-std::numeric_limits<OperandT>::infinity());
88+
}
89+
90+
if constexpr (!std::is_integral_v<OperandT> &&
91+
!std::is_floating_point_v<OperandT> &&
92+
!std::is_same_v<std::remove_cv_t<OperandT>, sycl::half>) {
93+
checkNoIdentity<sycl::minimum<OperandT>, OperandT>();
94+
checkNoIdentity<sycl::minimum<>, OperandT>();
95+
checkNoIdentity<sycl::maximum<OperandT>, OperandT>();
96+
checkNoIdentity<sycl::maximum<>, OperandT>();
97+
}
98+
}
99+
100+
struct CustomType {};
101+
102+
int main() {
103+
checkAll<bool>();
104+
checkAll<char>();
105+
checkAll<short>();
106+
checkAll<int>();
107+
checkAll<long>();
108+
checkAll<long long>();
109+
checkAll<signed char>();
110+
checkAll<unsigned char>();
111+
checkAll<unsigned int>();
112+
checkAll<unsigned long>();
113+
checkAll<unsigned long long>();
114+
checkAll<float>();
115+
checkAll<double>();
116+
checkAll<sycl::half>();
117+
checkAll<CustomType>();
118+
return 0;
119+
}

0 commit comments

Comments
 (0)