Skip to content

Commit 701377b

Browse files
committed
[SYCL] Fix the type trait 'known_identity'
- Move 'known_identity' and 'has_known_identity' from sycl::ONEAPI to sycl - Fix the compilation errors reported on attempts to use vector types like sycl::int4 as a reduction variable. Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent bdfad9e commit 701377b

File tree

3 files changed

+190
-86
lines changed

3 files changed

+190
-86
lines changed

sycl/include/CL/sycl/ONEAPI/reduction.hpp

Lines changed: 63 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -22,121 +22,100 @@ namespace ONEAPI {
2222

2323
namespace detail {
2424

25+
using cl::sycl::detail::bool_constant;
26+
using cl::sycl::detail::enable_if_t;
27+
using cl::sycl::detail::is_sgenfloat;
28+
using cl::sycl::detail::is_sgeninteger;
2529
using cl::sycl::detail::queue_impl;
30+
using cl::sycl::detail::remove_AS;
2631

2732
__SYCL_EXPORT size_t reduGetMaxWGSize(shared_ptr_class<queue_impl> Queue,
2833
size_t LocalMemBytesPerWorkItem);
2934
__SYCL_EXPORT size_t reduComputeWGSize(size_t NWorkItems, size_t MaxWGSize,
3035
size_t &NWorkGroups);
3136

32-
using cl::sycl::detail::bool_constant;
33-
using cl::sycl::detail::enable_if_t;
34-
using cl::sycl::detail::is_geninteger16bit;
35-
using cl::sycl::detail::is_geninteger32bit;
36-
using cl::sycl::detail::is_geninteger64bit;
37-
using cl::sycl::detail::is_geninteger8bit;
38-
using cl::sycl::detail::remove_AS;
39-
4037
template <typename T, class BinaryOperation>
41-
using IsReduPlus = detail::bool_constant<
42-
std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
43-
std::is_same<BinaryOperation, ONEAPI::plus<void>>::value>;
38+
using IsReduPlus =
39+
bool_constant<std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
40+
std::is_same<BinaryOperation, ONEAPI::plus<void>>::value>;
4441

4542
template <typename T, class BinaryOperation>
46-
using IsReduMultiplies = detail::bool_constant<
47-
std::is_same<BinaryOperation, std::multiplies<T>>::value ||
48-
std::is_same<BinaryOperation, std::multiplies<void>>::value>;
43+
using IsReduMultiplies =
44+
bool_constant<std::is_same<BinaryOperation, std::multiplies<T>>::value ||
45+
std::is_same<BinaryOperation, std::multiplies<void>>::value>;
4946

5047
template <typename T, class BinaryOperation>
51-
using IsReduMinimum = detail::bool_constant<
52-
std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
53-
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value>;
48+
using IsReduMinimum =
49+
bool_constant<std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
50+
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value>;
5451

5552
template <typename T, class BinaryOperation>
56-
using IsReduMaximum = detail::bool_constant<
57-
std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
58-
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value>;
53+
using IsReduMaximum =
54+
bool_constant<std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
55+
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value>;
5956

6057
template <typename T, class BinaryOperation>
61-
using IsReduBitOR = detail::bool_constant<
62-
std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
63-
std::is_same<BinaryOperation, ONEAPI::bit_or<void>>::value>;
58+
using IsReduBitOR =
59+
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
60+
std::is_same<BinaryOperation, ONEAPI::bit_or<void>>::value>;
6461

6562
template <typename T, class BinaryOperation>
66-
using IsReduBitXOR = detail::bool_constant<
67-
std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
68-
std::is_same<BinaryOperation, ONEAPI::bit_xor<void>>::value>;
63+
using IsReduBitXOR =
64+
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
65+
std::is_same<BinaryOperation, ONEAPI::bit_xor<void>>::value>;
6966

7067
template <typename T, class BinaryOperation>
71-
using IsReduBitAND = detail::bool_constant<
72-
std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
73-
std::is_same<BinaryOperation, ONEAPI::bit_and<void>>::value>;
68+
using IsReduBitAND =
69+
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
70+
std::is_same<BinaryOperation, ONEAPI::bit_and<void>>::value>;
7471

7572
template <typename T, class BinaryOperation>
7673
using IsReduOptForFastAtomicFetch =
77-
detail::bool_constant<(is_geninteger32bit<T>::value ||
78-
is_geninteger64bit<T>::value) &&
79-
(IsReduPlus<T, BinaryOperation>::value ||
80-
IsReduMinimum<T, BinaryOperation>::value ||
81-
IsReduMaximum<T, BinaryOperation>::value ||
82-
IsReduBitOR<T, BinaryOperation>::value ||
83-
IsReduBitXOR<T, BinaryOperation>::value ||
84-
IsReduBitAND<T, BinaryOperation>::value)>;
74+
bool_constant<is_sgeninteger<T>::value && IsValidAtomicType<T>::value &&
75+
(IsReduPlus<T, BinaryOperation>::value ||
76+
IsReduMinimum<T, BinaryOperation>::value ||
77+
IsReduMaximum<T, BinaryOperation>::value ||
78+
IsReduBitOR<T, BinaryOperation>::value ||
79+
IsReduBitXOR<T, BinaryOperation>::value ||
80+
IsReduBitAND<T, BinaryOperation>::value)>;
8581

8682
template <typename T, class BinaryOperation>
87-
using IsReduOptForFastReduce = detail::bool_constant<
88-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value ||
89-
std::is_same<T, half>::value || std::is_same<T, float>::value ||
90-
std::is_same<T, double>::value) &&
91-
(IsReduPlus<T, BinaryOperation>::value ||
92-
IsReduMinimum<T, BinaryOperation>::value ||
93-
IsReduMaximum<T, BinaryOperation>::value)>;
83+
using IsReduOptForFastReduce =
84+
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
85+
(IsReduPlus<T, BinaryOperation>::value ||
86+
IsReduMinimum<T, BinaryOperation>::value ||
87+
IsReduMaximum<T, BinaryOperation>::value)>;
9488

9589
// Identity = 0
9690
template <typename T, class BinaryOperation>
9791
using IsZeroIdentityOp = bool_constant<
98-
((is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
99-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
100-
(IsReduPlus<T, BinaryOperation>::value ||
101-
IsReduBitOR<T, BinaryOperation>::value ||
102-
IsReduBitXOR<T, BinaryOperation>::value)) ||
103-
((std::is_same<T, half>::value || std::is_same<T, float>::value ||
104-
std::is_same<T, double>::value) &&
105-
IsReduPlus<T, BinaryOperation>::value)>;
92+
(is_sgeninteger<T>::value && (IsReduPlus<T, BinaryOperation>::value ||
93+
IsReduBitOR<T, BinaryOperation>::value ||
94+
IsReduBitXOR<T, BinaryOperation>::value)) ||
95+
(is_sgenfloat<T>::value && IsReduPlus<T, BinaryOperation>::value)>;
10696

10797
// Identity = 1
10898
template <typename T, class BinaryOperation>
109-
using IsOneIdentityOp = bool_constant<
110-
(is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
111-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value ||
112-
std::is_same<T, half>::value || std::is_same<T, float>::value ||
113-
std::is_same<T, double>::value) &&
114-
IsReduMultiplies<T, BinaryOperation>::value>;
99+
using IsOneIdentityOp =
100+
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
101+
IsReduMultiplies<T, BinaryOperation>::value>;
115102

116103
// Identity = ~0
117104
template <typename T, class BinaryOperation>
118-
using IsOnesIdentityOp = bool_constant<
119-
(is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
120-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
121-
IsReduBitAND<T, BinaryOperation>::value>;
105+
using IsOnesIdentityOp = bool_constant<is_sgeninteger<T>::value &&
106+
IsReduBitAND<T, BinaryOperation>::value>;
122107

123108
// Identity = <max possible value>
124109
template <typename T, class BinaryOperation>
125-
using IsMinimumIdentityOp = bool_constant<
126-
(is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
127-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value ||
128-
std::is_same<T, half>::value || std::is_same<T, float>::value ||
129-
std::is_same<T, double>::value) &&
130-
IsReduMinimum<T, BinaryOperation>::value>;
110+
using IsMinimumIdentityOp =
111+
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
112+
IsReduMinimum<T, BinaryOperation>::value>;
131113

132114
// Identity = <min possible value>
133115
template <typename T, class BinaryOperation>
134-
using IsMaximumIdentityOp = bool_constant<
135-
(is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
136-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value ||
137-
std::is_same<T, half>::value || std::is_same<T, float>::value ||
138-
std::is_same<T, double>::value) &&
139-
IsReduMaximum<T, BinaryOperation>::value>;
116+
using IsMaximumIdentityOp =
117+
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
118+
IsReduMaximum<T, BinaryOperation>::value>;
140119

141120
template <typename T, class BinaryOperation>
142121
using IsKnownIdentityOp =
@@ -1604,7 +1583,7 @@ reduction(accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
16041583
/// The identity value is not passed to this version as it is statically known.
16051584
template <typename T, class BinaryOperation, int Dims, access::mode AccMode,
16061585
access::placeholder IsPH>
1607-
detail::enable_if_t<
1586+
std::enable_if_t<
16081587
detail::IsKnownIdentityOp<T, BinaryOperation>::value,
16091588
detail::reduction_impl<T, BinaryOperation, Dims, false, AccMode, IsPH>>
16101589
reduction(accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
@@ -1632,16 +1611,18 @@ reduction(T *VarPtr, const T &Identity, BinaryOperation BOp) {
16321611
/// operation used in the reduction.
16331612
/// The identity value is not passed to this version as it is statically known.
16341613
template <typename T, class BinaryOperation>
1635-
detail::enable_if_t<detail::IsKnownIdentityOp<T, BinaryOperation>::value,
1636-
detail::reduction_impl<T, BinaryOperation, 0, true,
1637-
access::mode::read_write>>
1614+
std::enable_if_t<detail::IsKnownIdentityOp<T, BinaryOperation>::value,
1615+
detail::reduction_impl<T, BinaryOperation, 0, true,
1616+
access::mode::read_write>>
16381617
reduction(T *VarPtr, BinaryOperation) {
16391618
return detail::reduction_impl<T, BinaryOperation, 0, true,
16401619
access::mode::read_write>(VarPtr);
16411620
}
16421621

1622+
} // namespace ONEAPI
1623+
16431624
template <typename BinaryOperation, typename AccumulatorT>
1644-
struct has_known_identity : detail::has_known_identity_impl<
1625+
struct has_known_identity : ONEAPI::detail::has_known_identity_impl<
16451626
typename std::decay<BinaryOperation>::type,
16461627
typename std::decay<AccumulatorT>::type> {};
16471628
#if __cplusplus >= 201703L
@@ -1651,15 +1632,14 @@ inline constexpr bool has_known_identity_v =
16511632
#endif
16521633

16531634
template <typename BinaryOperation, typename AccumulatorT>
1654-
struct known_identity
1655-
: detail::known_identity_impl<typename std::decay<BinaryOperation>::type,
1656-
typename std::decay<AccumulatorT>::type> {};
1635+
struct known_identity : ONEAPI::detail::known_identity_impl<
1636+
typename std::decay<BinaryOperation>::type,
1637+
typename std::decay<AccumulatorT>::type> {};
16571638
#if __cplusplus >= 201703L
16581639
template <typename BinaryOperation, typename AccumulatorT>
16591640
inline constexpr AccumulatorT known_identity_v =
16601641
known_identity<BinaryOperation, AccumulatorT>::value;
16611642
#endif
16621643

1663-
} // namespace ONEAPI
16641644
} // namespace sycl
16651645
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/test/basic_tests/reduction_ctor.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
using namespace cl::sycl;
1212

13+
bool toBool(bool V) { return V; }
14+
bool toBool(vec<int, 2> V) { return V.x() && V.y(); }
15+
bool toBool(vec<int, 4> V) { return V.x() && V.y() && V.z() && V.w(); }
16+
1317
template <typename T, typename Reduction>
1418
void test_reducer(Reduction &Redu, T A, T B) {
1519
typename Reduction::reducer_type Reducer;
@@ -29,7 +33,7 @@ void test_reducer(Reduction &Redu, T Identity, BinaryOperation BOp, T A, T B) {
2933
Reducer.combine(B);
3034

3135
T ExpectedValue = BOp(A, B);
32-
assert(ExpectedValue == Reducer.MValue &&
36+
assert(toBool(ExpectedValue == Reducer.MValue) &&
3337
"Wrong result of binary operation.");
3438
}
3539

@@ -40,14 +44,17 @@ template <typename SpecializationKernelName, typename T, int Dim,
4044
void testKnown(T Identity, BinaryOperation BOp, T A, T B) {
4145
buffer<T, 1> ReduBuf(1);
4246

47+
static_assert(has_known_identity<BinaryOperation, T>::value);
4348
queue Q;
4449
Q.submit([&](handler &CGH) {
4550
// Reduction needs a global_buffer accessor as a parameter.
4651
// This accessor is not really used in this test.
4752
accessor<T, Dim, access::mode::discard_write, access::target::global_buffer>
4853
ReduAcc(ReduBuf, CGH);
4954
auto Redu = ONEAPI::reduction(ReduAcc, BOp);
50-
assert(Redu.getIdentity() == Identity && "Failed getIdentity() check().");
55+
assert(toBool(Redu.getIdentity() == Identity) &&
56+
toBool(known_identity<BinaryOperation, T>::value == Identity) &&
57+
"Failed getIdentity() check().");
5158
test_reducer(Redu, A, B);
5259
test_reducer(Redu, Identity, BOp, A, B);
5360

@@ -67,7 +74,8 @@ void testUnknown(T Identity, BinaryOperation BOp, T A, T B) {
6774
accessor<T, Dim, access::mode::discard_write, access::target::global_buffer>
6875
ReduAcc(ReduBuf, CGH);
6976
auto Redu = ONEAPI::reduction(ReduAcc, Identity, BOp);
70-
assert(Redu.getIdentity() == Identity && "Failed getIdentity() check().");
77+
bool IsCorrectVal = toBool(Redu.getIdentity() == Identity);
78+
assert(IsCorrectVal && "Failed getIdentity() check().");
7179
test_reducer(Redu, Identity, BOp, A, B);
7280

7381
// Command group must have at least one task in it. Use an empty one.
@@ -124,6 +132,16 @@ int main() {
124132
testUnknown<class KernelName_zhF, int, 0>(
125133
0, [](auto a, auto b) { return a | b; }, 1, 8);
126134

135+
int2 IdentityI2 = {0, 0};
136+
int2 AI2 = {1, 2};
137+
int2 BI2 = {7, 13};
138+
testUnknown<class KNI2, int2, 0>(IdentityI2, ONEAPI::plus<int2>(), AI2, BI2);
139+
140+
float4 IdentityF4 = {0, 0, 0, 0};
141+
float4 AF4 = {1, 2, -1, -34};
142+
float4 BF4 = {7, 13, 0, 35};
143+
testUnknown<class KNF4, float4, 0>(IdentityF4, ONEAPI::plus<>(), AF4, BF4);
144+
127145
std::cout << "Test passed\n";
128146
return 0;
129147
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// RUN: %clangxx -fsyntax-only -Xclang -verify %s -I %sycl_include -Xclang -verify-ignore-unexpected=note,warning
2+
// expected-no-diagnostics
3+
4+
// This test performs basic checks of has_known_identity and known_identity
5+
// type traits.
6+
7+
#include <CL/sycl.hpp>
8+
#include <cassert>
9+
10+
using namespace cl::sycl;
11+
12+
template <typename T> void checkCommonBasicKnownIdentity() {
13+
static_assert(has_known_identity<ONEAPI::maximum<>, T>::value);
14+
static_assert(has_known_identity<ONEAPI::maximum<T>, T>::value);
15+
static_assert(has_known_identity<ONEAPI::minimum<>, T>::value);
16+
static_assert(has_known_identity<ONEAPI::minimum<T>, T>::value);
17+
}
18+
19+
template <typename T> void checkCommonKnownIdentity() {
20+
checkCommonBasicKnownIdentity<T>();
21+
22+
static_assert(has_known_identity<std::plus<>, T>::value);
23+
static_assert(has_known_identity<std::plus<T>, T>::value);
24+
static_assert(known_identity<std::plus<>, T>::value == 0);
25+
static_assert(known_identity<std::plus<T>, T>::value == 0);
26+
27+
static_assert(has_known_identity<std::multiplies<>, T>::value);
28+
static_assert(has_known_identity<std::multiplies<T>, T>::value);
29+
static_assert(known_identity<std::multiplies<>, T>::value == 1);
30+
static_assert(known_identity<std::multiplies<T>, T>::value == 1);
31+
}
32+
33+
template <typename T> void checkIntKnownIdentity() {
34+
checkCommonKnownIdentity<T>();
35+
36+
constexpr T Ones = ~static_cast<T>(0);
37+
static_assert(has_known_identity<std::bit_and<>, T>::value);
38+
static_assert(has_known_identity<std::bit_and<T>, T>::value);
39+
static_assert(known_identity<std::bit_and<>, T>::value == Ones);
40+
static_assert(known_identity<std::bit_and<T>, T>::value == Ones);
41+
42+
static_assert(has_known_identity<std::bit_or<>, T>::value);
43+
static_assert(has_known_identity<std::bit_or<T>, T>::value);
44+
static_assert(known_identity<std::bit_or<>, T>::value == 0);
45+
static_assert(known_identity<std::bit_or<T>, T>::value == 0);
46+
47+
static_assert(has_known_identity<std::bit_xor<>, T>::value);
48+
static_assert(has_known_identity<std::bit_xor<T>, T>::value);
49+
static_assert(known_identity<std::bit_xor<>, T>::value == 0);
50+
static_assert(known_identity<std::bit_xor<T>, T>::value == 0);
51+
}
52+
53+
int main() {
54+
checkIntKnownIdentity<int8_t>();
55+
checkIntKnownIdentity<char>();
56+
checkIntKnownIdentity<cl_char>();
57+
58+
checkIntKnownIdentity<int16_t>();
59+
checkIntKnownIdentity<short>();
60+
checkIntKnownIdentity<cl_short>();
61+
62+
checkIntKnownIdentity<int32_t>();
63+
checkIntKnownIdentity<int>();
64+
checkIntKnownIdentity<cl_int>();
65+
66+
checkIntKnownIdentity<long>();
67+
68+
checkIntKnownIdentity<int64_t>();
69+
checkIntKnownIdentity<long long>();
70+
checkIntKnownIdentity<cl_long>();
71+
72+
checkIntKnownIdentity<uint8_t>();
73+
checkIntKnownIdentity<unsigned char>();
74+
checkIntKnownIdentity<cl_uchar>();
75+
76+
checkIntKnownIdentity<uint16_t>();
77+
checkIntKnownIdentity<unsigned short>();
78+
checkIntKnownIdentity<cl_ushort>();
79+
80+
checkIntKnownIdentity<uint32_t>();
81+
checkIntKnownIdentity<unsigned int>();
82+
checkIntKnownIdentity<unsigned>();
83+
checkIntKnownIdentity<cl_uint>();
84+
85+
checkIntKnownIdentity<unsigned long>();
86+
87+
checkIntKnownIdentity<uint64_t>();
88+
checkIntKnownIdentity<unsigned long long>();
89+
checkIntKnownIdentity<cl_ulong>();
90+
checkIntKnownIdentity<std::size_t>();
91+
92+
checkCommonKnownIdentity<float>();
93+
checkCommonKnownIdentity<cl_float>();
94+
checkCommonKnownIdentity<double>();
95+
checkCommonKnownIdentity<cl_double>();
96+
97+
checkCommonBasicKnownIdentity<half>();
98+
checkCommonBasicKnownIdentity<sycl::cl_half>();
99+
checkCommonBasicKnownIdentity<::cl_half>();
100+
101+
// Few negative tests just to check that it does not always return true.
102+
static_assert(!has_known_identity<std::minus<>, int>::value);
103+
static_assert(!has_known_identity<ONEAPI::bit_or<>, float>::value);
104+
105+
return 0;
106+
}

0 commit comments

Comments
 (0)