Skip to content

Commit 060fd50

Browse files
authored
[SYCL] Fix the type trait 'known_identity' (#3227)
* [SYCL] Fix the type trait 'known_identity': (1) copy 'known_identity' and 'has_known_identity' from sycl::ONEAPI to sycl; (2) fix the compilation errors reported on attempts to use vector types like sycl::int4 as a reduction variable. Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent ee39a68 commit 060fd50

File tree

3 files changed

+218
-88
lines changed

3 files changed

+218
-88
lines changed

sycl/include/CL/sycl/ONEAPI/reduction.hpp

Lines changed: 91 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "CL/sycl/ONEAPI/accessor_property_list.hpp"
1212
#include <CL/sycl/ONEAPI/group_algorithm.hpp>
1313
#include <CL/sycl/accessor.hpp>
14+
#include <CL/sycl/atomic.hpp>
1415
#include <CL/sycl/handler.hpp>
1516
#include <CL/sycl/kernel.hpp>
1617

@@ -22,121 +23,103 @@ namespace ONEAPI {
2223

2324
namespace detail {
2425

26+
using cl::sycl::detail::bool_constant;
27+
using cl::sycl::detail::enable_if_t;
28+
using cl::sycl::detail::is_sgenfloat;
29+
using cl::sycl::detail::is_sgeninteger;
2530
using cl::sycl::detail::queue_impl;
31+
using cl::sycl::detail::remove_AS;
2632

2733
__SYCL_EXPORT size_t reduGetMaxWGSize(shared_ptr_class<queue_impl> Queue,
2834
size_t LocalMemBytesPerWorkItem);
2935
__SYCL_EXPORT size_t reduComputeWGSize(size_t NWorkItems, size_t MaxWGSize,
3036
size_t &NWorkGroups);
3137

32-
using cl::sycl::detail::bool_constant;
33-
using cl::sycl::detail::enable_if_t;
34-
using cl::sycl::detail::is_geninteger16bit;
35-
using cl::sycl::detail::is_geninteger32bit;
36-
using cl::sycl::detail::is_geninteger64bit;
37-
using cl::sycl::detail::is_geninteger8bit;
38-
using cl::sycl::detail::remove_AS;
39-
4038
template <typename T, class BinaryOperation>
41-
using IsReduPlus = detail::bool_constant<
42-
std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
43-
std::is_same<BinaryOperation, ONEAPI::plus<void>>::value>;
39+
using IsReduPlus =
40+
bool_constant<std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
41+
std::is_same<BinaryOperation, ONEAPI::plus<void>>::value>;
4442

4543
template <typename T, class BinaryOperation>
46-
using IsReduMultiplies = detail::bool_constant<
47-
std::is_same<BinaryOperation, std::multiplies<T>>::value ||
48-
std::is_same<BinaryOperation, std::multiplies<void>>::value>;
44+
using IsReduMultiplies =
45+
bool_constant<std::is_same<BinaryOperation, std::multiplies<T>>::value ||
46+
std::is_same<BinaryOperation, std::multiplies<void>>::value>;
4947

5048
template <typename T, class BinaryOperation>
51-
using IsReduMinimum = detail::bool_constant<
52-
std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
53-
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value>;
49+
using IsReduMinimum =
50+
bool_constant<std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
51+
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value>;
5452

5553
template <typename T, class BinaryOperation>
56-
using IsReduMaximum = detail::bool_constant<
57-
std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
58-
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value>;
54+
using IsReduMaximum =
55+
bool_constant<std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
56+
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value>;
5957

6058
template <typename T, class BinaryOperation>
61-
using IsReduBitOR = detail::bool_constant<
62-
std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
63-
std::is_same<BinaryOperation, ONEAPI::bit_or<void>>::value>;
59+
using IsReduBitOR =
60+
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
61+
std::is_same<BinaryOperation, ONEAPI::bit_or<void>>::value>;
6462

6563
template <typename T, class BinaryOperation>
66-
using IsReduBitXOR = detail::bool_constant<
67-
std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
68-
std::is_same<BinaryOperation, ONEAPI::bit_xor<void>>::value>;
64+
using IsReduBitXOR =
65+
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
66+
std::is_same<BinaryOperation, ONEAPI::bit_xor<void>>::value>;
6967

7068
template <typename T, class BinaryOperation>
71-
using IsReduBitAND = detail::bool_constant<
72-
std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
73-
std::is_same<BinaryOperation, ONEAPI::bit_and<void>>::value>;
69+
using IsReduBitAND =
70+
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
71+
std::is_same<BinaryOperation, ONEAPI::bit_and<void>>::value>;
7472

7573
template <typename T, class BinaryOperation>
7674
using IsReduOptForFastAtomicFetch =
77-
detail::bool_constant<(is_geninteger32bit<T>::value ||
78-
is_geninteger64bit<T>::value) &&
79-
(IsReduPlus<T, BinaryOperation>::value ||
80-
IsReduMinimum<T, BinaryOperation>::value ||
81-
IsReduMaximum<T, BinaryOperation>::value ||
82-
IsReduBitOR<T, BinaryOperation>::value ||
83-
IsReduBitXOR<T, BinaryOperation>::value ||
84-
IsReduBitAND<T, BinaryOperation>::value)>;
75+
bool_constant<is_sgeninteger<T>::value &&
76+
sycl::detail::IsValidAtomicType<T>::value &&
77+
(IsReduPlus<T, BinaryOperation>::value ||
78+
IsReduMinimum<T, BinaryOperation>::value ||
79+
IsReduMaximum<T, BinaryOperation>::value ||
80+
IsReduBitOR<T, BinaryOperation>::value ||
81+
IsReduBitXOR<T, BinaryOperation>::value ||
82+
IsReduBitAND<T, BinaryOperation>::value)>;
8583

8684
template <typename T, class BinaryOperation>
87-
using IsReduOptForFastReduce = detail::bool_constant<
88-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value ||
89-
std::is_same<T, half>::value || std::is_same<T, float>::value ||
90-
std::is_same<T, double>::value) &&
91-
(IsReduPlus<T, BinaryOperation>::value ||
92-
IsReduMinimum<T, BinaryOperation>::value ||
93-
IsReduMaximum<T, BinaryOperation>::value)>;
85+
using IsReduOptForFastReduce =
86+
bool_constant<((is_sgeninteger<T>::value &&
87+
(sizeof(T) == 32 || sizeof(T) == 64)) ||
88+
is_sgenfloat<T>::value) &&
89+
(IsReduPlus<T, BinaryOperation>::value ||
90+
IsReduMinimum<T, BinaryOperation>::value ||
91+
IsReduMaximum<T, BinaryOperation>::value)>;
9492

9593
// Identity = 0
9694
template <typename T, class BinaryOperation>
9795
using IsZeroIdentityOp = bool_constant<
98-
((is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
99-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
100-
(IsReduPlus<T, BinaryOperation>::value ||
101-
IsReduBitOR<T, BinaryOperation>::value ||
102-
IsReduBitXOR<T, BinaryOperation>::value)) ||
103-
((std::is_same<T, half>::value || std::is_same<T, float>::value ||
104-
std::is_same<T, double>::value) &&
105-
IsReduPlus<T, BinaryOperation>::value)>;
96+
(is_sgeninteger<T>::value && (IsReduPlus<T, BinaryOperation>::value ||
97+
IsReduBitOR<T, BinaryOperation>::value ||
98+
IsReduBitXOR<T, BinaryOperation>::value)) ||
99+
(is_sgenfloat<T>::value && IsReduPlus<T, BinaryOperation>::value)>;
106100

107101
// Identity = 1
108102
template <typename T, class BinaryOperation>
109-
using IsOneIdentityOp = bool_constant<
110-
(is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
111-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value ||
112-
std::is_same<T, half>::value || std::is_same<T, float>::value ||
113-
std::is_same<T, double>::value) &&
114-
IsReduMultiplies<T, BinaryOperation>::value>;
103+
using IsOneIdentityOp =
104+
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
105+
IsReduMultiplies<T, BinaryOperation>::value>;
115106

116107
// Identity = ~0
117108
template <typename T, class BinaryOperation>
118-
using IsOnesIdentityOp = bool_constant<
119-
(is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
120-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
121-
IsReduBitAND<T, BinaryOperation>::value>;
109+
using IsOnesIdentityOp = bool_constant<is_sgeninteger<T>::value &&
110+
IsReduBitAND<T, BinaryOperation>::value>;
122111

123112
// Identity = <max possible value>
124113
template <typename T, class BinaryOperation>
125-
using IsMinimumIdentityOp = bool_constant<
126-
(is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
127-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value ||
128-
std::is_same<T, half>::value || std::is_same<T, float>::value ||
129-
std::is_same<T, double>::value) &&
130-
IsReduMinimum<T, BinaryOperation>::value>;
114+
using IsMinimumIdentityOp =
115+
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
116+
IsReduMinimum<T, BinaryOperation>::value>;
131117

132118
// Identity = <min possible value>
133119
template <typename T, class BinaryOperation>
134-
using IsMaximumIdentityOp = bool_constant<
135-
(is_geninteger8bit<T>::value || is_geninteger16bit<T>::value ||
136-
is_geninteger32bit<T>::value || is_geninteger64bit<T>::value ||
137-
std::is_same<T, half>::value || std::is_same<T, float>::value ||
138-
std::is_same<T, double>::value) &&
139-
IsReduMaximum<T, BinaryOperation>::value>;
120+
using IsMaximumIdentityOp =
121+
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
122+
IsReduMaximum<T, BinaryOperation>::value>;
140123

141124
template <typename T, class BinaryOperation>
142125
using IsKnownIdentityOp =
@@ -343,7 +326,7 @@ class reducer<T, BinaryOperation,
343326
/// Atomic ADD operation: *ReduVarPtr += MValue;
344327
template <typename _T = T, class _BinaryOperation = BinaryOperation>
345328
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
346-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
329+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
347330
IsReduPlus<T, _BinaryOperation>::value>
348331
atomic_combine(_T *ReduVarPtr) const {
349332
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -353,7 +336,7 @@ class reducer<T, BinaryOperation,
353336
/// Atomic BITWISE OR operation: *ReduVarPtr |= MValue;
354337
template <typename _T = T, class _BinaryOperation = BinaryOperation>
355338
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
356-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
339+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
357340
IsReduBitOR<T, _BinaryOperation>::value>
358341
atomic_combine(_T *ReduVarPtr) const {
359342
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -363,7 +346,7 @@ class reducer<T, BinaryOperation,
363346
/// Atomic BITWISE XOR operation: *ReduVarPtr ^= MValue;
364347
template <typename _T = T, class _BinaryOperation = BinaryOperation>
365348
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
366-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
349+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
367350
IsReduBitXOR<T, _BinaryOperation>::value>
368351
atomic_combine(_T *ReduVarPtr) const {
369352
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -373,7 +356,7 @@ class reducer<T, BinaryOperation,
373356
/// Atomic BITWISE AND operation: *ReduVarPtr &= MValue;
374357
template <typename _T = T, class _BinaryOperation = BinaryOperation>
375358
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
376-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
359+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
377360
IsReduBitAND<T, _BinaryOperation>::value>
378361
atomic_combine(_T *ReduVarPtr) const {
379362
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -383,7 +366,7 @@ class reducer<T, BinaryOperation,
383366
/// Atomic MIN operation: *ReduVarPtr = ONEAPI::minimum(*ReduVarPtr, MValue);
384367
template <typename _T = T, class _BinaryOperation = BinaryOperation>
385368
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
386-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
369+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
387370
IsReduMinimum<T, _BinaryOperation>::value>
388371
atomic_combine(_T *ReduVarPtr) const {
389372
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -393,7 +376,7 @@ class reducer<T, BinaryOperation,
393376
/// Atomic MAX operation: *ReduVarPtr = ONEAPI::maximum(*ReduVarPtr, MValue);
394377
template <typename _T = T, class _BinaryOperation = BinaryOperation>
395378
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
396-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
379+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
397380
IsReduMaximum<T, _BinaryOperation>::value>
398381
atomic_combine(_T *ReduVarPtr) const {
399382
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -1604,7 +1587,7 @@ reduction(accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
16041587
/// The identity value is not passed to this version as it is statically known.
16051588
template <typename T, class BinaryOperation, int Dims, access::mode AccMode,
16061589
access::placeholder IsPH>
1607-
detail::enable_if_t<
1590+
std::enable_if_t<
16081591
detail::IsKnownIdentityOp<T, BinaryOperation>::value,
16091592
detail::reduction_impl<T, BinaryOperation, Dims, false, AccMode, IsPH>>
16101593
reduction(accessor<T, Dims, AccMode, access::target::global_buffer, IsPH> &Acc,
@@ -1632,9 +1615,9 @@ reduction(T *VarPtr, const T &Identity, BinaryOperation BOp) {
16321615
/// operation used in the reduction.
16331616
/// The identity value is not passed to this version as it is statically known.
16341617
template <typename T, class BinaryOperation>
1635-
detail::enable_if_t<detail::IsKnownIdentityOp<T, BinaryOperation>::value,
1636-
detail::reduction_impl<T, BinaryOperation, 0, true,
1637-
access::mode::read_write>>
1618+
std::enable_if_t<detail::IsKnownIdentityOp<T, BinaryOperation>::value,
1619+
detail::reduction_impl<T, BinaryOperation, 0, true,
1620+
access::mode::read_write>>
16381621
reduction(T *VarPtr, BinaryOperation) {
16391622
return detail::reduction_impl<T, BinaryOperation, 0, true,
16401623
access::mode::read_write>(VarPtr);
@@ -1659,7 +1642,30 @@ template <typename BinaryOperation, typename AccumulatorT>
16591642
inline constexpr AccumulatorT known_identity_v =
16601643
known_identity<BinaryOperation, AccumulatorT>::value;
16611644
#endif
1662-
16631645
} // namespace ONEAPI
1646+
1647+
// Currently, the type traits defined below correspond to SYCL 1.2.1 ONEAPI
1648+
// reduction extension. That may be changed later when SYCL 2020 reductions
1649+
// are implemented.
1650+
template <typename BinaryOperation, typename AccumulatorT>
1651+
struct has_known_identity
1652+
: ONEAPI::has_known_identity<BinaryOperation, AccumulatorT> {};
1653+
1654+
#if __cplusplus >= 201703L
1655+
template <typename BinaryOperation, typename AccumulatorT>
1656+
inline constexpr bool has_known_identity_v =
1657+
has_known_identity<BinaryOperation, AccumulatorT>::value;
1658+
#endif
1659+
1660+
template <typename BinaryOperation, typename AccumulatorT>
1661+
struct known_identity : ONEAPI::known_identity<BinaryOperation, AccumulatorT> {
1662+
};
1663+
1664+
#if __cplusplus >= 201703L
1665+
template <typename BinaryOperation, typename AccumulatorT>
1666+
inline constexpr AccumulatorT known_identity_v =
1667+
known_identity<BinaryOperation, AccumulatorT>::value;
1668+
#endif
1669+
16641670
} // namespace sycl
16651671
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/test/basic_tests/reduction_ctor.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
using namespace cl::sycl;
1212

13+
bool toBool(bool V) { return V; }
14+
bool toBool(vec<int, 2> V) { return V.x() && V.y(); }
15+
bool toBool(vec<int, 4> V) { return V.x() && V.y() && V.z() && V.w(); }
16+
1317
template <typename T, typename Reduction>
1418
void test_reducer(Reduction &Redu, T A, T B) {
1519
typename Reduction::reducer_type Reducer;
@@ -29,7 +33,7 @@ void test_reducer(Reduction &Redu, T Identity, BinaryOperation BOp, T A, T B) {
2933
Reducer.combine(B);
3034

3135
T ExpectedValue = BOp(A, B);
32-
assert(ExpectedValue == Reducer.MValue &&
36+
assert(toBool(ExpectedValue == Reducer.MValue) &&
3337
"Wrong result of binary operation.");
3438
}
3539

@@ -40,14 +44,17 @@ template <typename SpecializationKernelName, typename T, int Dim,
4044
void testKnown(T Identity, BinaryOperation BOp, T A, T B) {
4145
buffer<T, 1> ReduBuf(1);
4246

47+
static_assert(has_known_identity<BinaryOperation, T>::value);
4348
queue Q;
4449
Q.submit([&](handler &CGH) {
4550
// Reduction needs a global_buffer accessor as a parameter.
4651
// This accessor is not really used in this test.
4752
accessor<T, Dim, access::mode::discard_write, access::target::global_buffer>
4853
ReduAcc(ReduBuf, CGH);
4954
auto Redu = ONEAPI::reduction(ReduAcc, BOp);
50-
assert(Redu.getIdentity() == Identity && "Failed getIdentity() check().");
55+
assert(toBool(Redu.getIdentity() == Identity) &&
56+
toBool(known_identity<BinaryOperation, T>::value == Identity) &&
57+
"Failed getIdentity() check().");
5158
test_reducer(Redu, A, B);
5259
test_reducer(Redu, Identity, BOp, A, B);
5360

@@ -67,7 +74,8 @@ void testUnknown(T Identity, BinaryOperation BOp, T A, T B) {
6774
accessor<T, Dim, access::mode::discard_write, access::target::global_buffer>
6875
ReduAcc(ReduBuf, CGH);
6976
auto Redu = ONEAPI::reduction(ReduAcc, Identity, BOp);
70-
assert(Redu.getIdentity() == Identity && "Failed getIdentity() check().");
77+
bool IsCorrectVal = toBool(Redu.getIdentity() == Identity);
78+
assert(IsCorrectVal && "Failed getIdentity() check().");
7179
test_reducer(Redu, Identity, BOp, A, B);
7280

7381
// Command group must have at least one task in it. Use an empty one.
@@ -124,6 +132,16 @@ int main() {
124132
testUnknown<class KernelName_zhF, int, 0>(
125133
0, [](auto a, auto b) { return a | b; }, 1, 8);
126134

135+
int2 IdentityI2 = {0, 0};
136+
int2 AI2 = {1, 2};
137+
int2 BI2 = {7, 13};
138+
testUnknown<class KNI2, int2, 0>(IdentityI2, ONEAPI::plus<int2>(), AI2, BI2);
139+
140+
float4 IdentityF4 = {0, 0, 0, 0};
141+
float4 AF4 = {1, 2, -1, -34};
142+
float4 BF4 = {7, 13, 0, 35};
143+
testUnknown<class KNF4, float4, 0>(IdentityF4, ONEAPI::plus<>(), AF4, BF4);
144+
127145
std::cout << "Test passed\n";
128146
return 0;
129147
}

0 commit comments

Comments
 (0)