Skip to content

Commit a3fc51a

Browse files
authored
[SYCL] Use macro SYCL_REDUCTION_DETERMINISTIC for stable reduction (#3876)
The macro SYCL_REDUCTION_DETERMINISTIC is not defined by default. This patch also has an NFC portion that removes reduction-specific code from the known_identity.hpp file. Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent d581178 commit a3fc51a

File tree

2 files changed

+78
-65
lines changed

2 files changed

+78
-65
lines changed

sycl/include/CL/sycl/ONEAPI/reduction.hpp

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,50 @@ namespace detail {
2727

2828
using cl::sycl::detail::bool_constant;
2929
using cl::sycl::detail::enable_if_t;
30-
using cl::sycl::detail::is_sgenfloat;
31-
using cl::sycl::detail::is_sgeninteger;
3230
using cl::sycl::detail::queue_impl;
3331
using cl::sycl::detail::remove_AS;
3432

33+
// This type trait is used to detect if the atomic operation BinaryOperation
34+
// used with operands of the type T is available for use in reduction.
35+
// The order in which the atomic operations are performed may be arbitrary and
36+
// thus may cause different results from run to run even on the same elements
37+
// and on the same device. The macro SYCL_REDUCTION_DETERMINISTIC prohibits
38+
// using atomic operations for reduction and helps to produce stable results.
39+
// SYCL_REDUCTION_DETERMINISTIC is a short-term solution, which may eventually
40+
// be deprecated and replaced by a SYCL property passed to the reduction.
41+
template <typename T, class BinaryOperation>
42+
using IsReduOptForFastAtomicFetch =
43+
#ifdef SYCL_REDUCTION_DETERMINISTIC
44+
bool_constant<false>;
45+
#else
46+
bool_constant<sycl::detail::is_sgeninteger<T>::value &&
47+
sycl::detail::IsValidAtomicType<T>::value &&
48+
(sycl::detail::IsPlus<T, BinaryOperation>::value ||
49+
sycl::detail::IsMinimum<T, BinaryOperation>::value ||
50+
sycl::detail::IsMaximum<T, BinaryOperation>::value ||
51+
sycl::detail::IsBitOR<T, BinaryOperation>::value ||
52+
sycl::detail::IsBitXOR<T, BinaryOperation>::value ||
53+
sycl::detail::IsBitAND<T, BinaryOperation>::value)>;
54+
#endif
55+
56+
// This type trait is used to detect if the group algorithm reduce() used with
57+
// operands of the type T and the operation BinaryOperation is available
58+
// for use in reduction.
59+
// The macro SYCL_REDUCTION_DETERMINISTIC prohibits using the reduce() algorithm
60+
// in order to produce stable results across devices of the same type.
61+
template <typename T, class BinaryOperation>
62+
using IsReduOptForFastReduce =
63+
#ifdef SYCL_REDUCTION_DETERMINISTIC
64+
bool_constant<false>;
65+
#else
66+
bool_constant<((sycl::detail::is_sgeninteger<T>::value &&
67+
(sizeof(T) == 4 || sizeof(T) == 8)) ||
68+
sycl::detail::is_sgenfloat<T>::value) &&
69+
(sycl::detail::IsPlus<T, BinaryOperation>::value ||
70+
sycl::detail::IsMinimum<T, BinaryOperation>::value ||
71+
sycl::detail::IsMaximum<T, BinaryOperation>::value)>;
72+
#endif
73+
3574
// std::tuple seems to be a) too heavy and b) not copyable to device now
3675
// Thus sycl::detail::tuple is used instead.
3776
// Switching from sycl::detail::tuple to std::tuple can be done by re-defining
@@ -46,10 +85,6 @@ __SYCL_EXPORT size_t reduGetMaxWGSize(shared_ptr_class<queue_impl> Queue,
4685
__SYCL_EXPORT size_t reduComputeWGSize(size_t NWorkItems, size_t MaxWGSize,
4786
size_t &NWorkGroups);
4887

49-
50-
51-
52-
5388
/// Class that is used to represent objects that are passed to user's lambda
5489
/// functions and that represent the user's reduction variable.
5590
/// The generic version of the class represents those reductions of those
@@ -64,45 +99,45 @@ class reducer {
6499
T getIdentity() const { return MIdentity; }
65100

66101
template <typename _T = T>
67-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value &&
102+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value &&
68103
sycl::detail::is_geninteger<_T>::value>
69104
operator++() {
70105
combine(static_cast<T>(1));
71106
}
72107

73108
template <typename _T = T>
74-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value &&
109+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value &&
75110
sycl::detail::is_geninteger<_T>::value>
76111
operator++(int) {
77112
combine(static_cast<T>(1));
78113
}
79114

80115
template <typename _T = T>
81-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value>
116+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value>
82117
operator+=(const _T &Partial) {
83118
combine(Partial);
84119
}
85120

86121
template <typename _T = T>
87-
enable_if_t<IsReduMultiplies<_T, BinaryOperation>::value>
122+
enable_if_t<sycl::detail::IsMultiplies<_T, BinaryOperation>::value>
88123
operator*=(const _T &Partial) {
89124
combine(Partial);
90125
}
91126

92127
template <typename _T = T>
93-
enable_if_t<IsReduBitOR<_T, BinaryOperation>::value>
128+
enable_if_t<sycl::detail::IsBitOR<_T, BinaryOperation>::value>
94129
operator|=(const _T &Partial) {
95130
combine(Partial);
96131
}
97132

98133
template <typename _T = T>
99-
enable_if_t<IsReduBitXOR<_T, BinaryOperation>::value>
134+
enable_if_t<sycl::detail::IsBitXOR<_T, BinaryOperation>::value>
100135
operator^=(const _T &Partial) {
101136
combine(Partial);
102137
}
103138

104139
template <typename _T = T>
105-
enable_if_t<IsReduBitAND<_T, BinaryOperation>::value>
140+
enable_if_t<sycl::detail::IsBitAND<_T, BinaryOperation>::value>
106141
operator&=(const _T &Partial) {
107142
combine(Partial);
108143
}
@@ -150,45 +185,45 @@ class reducer<T, BinaryOperation,
150185
}
151186

152187
template <typename _T = T>
153-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value &&
188+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value &&
154189
sycl::detail::is_geninteger<_T>::value>
155190
operator++() {
156191
combine(static_cast<T>(1));
157192
}
158193

159194
template <typename _T = T>
160-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value &&
195+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value &&
161196
sycl::detail::is_geninteger<_T>::value>
162197
operator++(int) {
163198
combine(static_cast<T>(1));
164199
}
165200

166201
template <typename _T = T>
167-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value>
202+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value>
168203
operator+=(const _T &Partial) {
169204
combine(Partial);
170205
}
171206

172207
template <typename _T = T>
173-
enable_if_t<IsReduMultiplies<_T, BinaryOperation>::value>
208+
enable_if_t<sycl::detail::IsMultiplies<_T, BinaryOperation>::value>
174209
operator*=(const _T &Partial) {
175210
combine(Partial);
176211
}
177212

178213
template <typename _T = T>
179-
enable_if_t<IsReduBitOR<_T, BinaryOperation>::value>
214+
enable_if_t<sycl::detail::IsBitOR<_T, BinaryOperation>::value>
180215
operator|=(const _T &Partial) {
181216
combine(Partial);
182217
}
183218

184219
template <typename _T = T>
185-
enable_if_t<IsReduBitXOR<_T, BinaryOperation>::value>
220+
enable_if_t<sycl::detail::IsBitXOR<_T, BinaryOperation>::value>
186221
operator^=(const _T &Partial) {
187222
combine(Partial);
188223
}
189224

190225
template <typename _T = T>
191-
enable_if_t<IsReduBitAND<_T, BinaryOperation>::value>
226+
enable_if_t<sycl::detail::IsBitAND<_T, BinaryOperation>::value>
192227
operator&=(const _T &Partial) {
193228
combine(Partial);
194229
}
@@ -197,7 +232,7 @@ class reducer<T, BinaryOperation,
197232
template <typename _T = T, class _BinaryOperation = BinaryOperation>
198233
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
199234
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
200-
IsReduPlus<T, _BinaryOperation>::value>
235+
sycl::detail::IsPlus<T, _BinaryOperation>::value>
201236
atomic_combine(_T *ReduVarPtr) const {
202237
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
203238
.fetch_add(MValue);
@@ -207,7 +242,7 @@ class reducer<T, BinaryOperation,
207242
template <typename _T = T, class _BinaryOperation = BinaryOperation>
208243
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
209244
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
210-
IsReduBitOR<T, _BinaryOperation>::value>
245+
sycl::detail::IsBitOR<T, _BinaryOperation>::value>
211246
atomic_combine(_T *ReduVarPtr) const {
212247
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
213248
.fetch_or(MValue);
@@ -217,7 +252,7 @@ class reducer<T, BinaryOperation,
217252
template <typename _T = T, class _BinaryOperation = BinaryOperation>
218253
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
219254
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
220-
IsReduBitXOR<T, _BinaryOperation>::value>
255+
sycl::detail::IsBitXOR<T, _BinaryOperation>::value>
221256
atomic_combine(_T *ReduVarPtr) const {
222257
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
223258
.fetch_xor(MValue);
@@ -227,7 +262,7 @@ class reducer<T, BinaryOperation,
227262
template <typename _T = T, class _BinaryOperation = BinaryOperation>
228263
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
229264
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
230-
IsReduBitAND<T, _BinaryOperation>::value>
265+
sycl::detail::IsBitAND<T, _BinaryOperation>::value>
231266
atomic_combine(_T *ReduVarPtr) const {
232267
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
233268
.fetch_and(MValue);
@@ -237,7 +272,7 @@ class reducer<T, BinaryOperation,
237272
template <typename _T = T, class _BinaryOperation = BinaryOperation>
238273
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
239274
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
240-
IsReduMinimum<T, _BinaryOperation>::value>
275+
sycl::detail::IsMinimum<T, _BinaryOperation>::value>
241276
atomic_combine(_T *ReduVarPtr) const {
242277
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
243278
.fetch_min(MValue);
@@ -247,7 +282,7 @@ class reducer<T, BinaryOperation,
247282
template <typename _T = T, class _BinaryOperation = BinaryOperation>
248283
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
249284
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
250-
IsReduMaximum<T, _BinaryOperation>::value>
285+
sycl::detail::IsMaximum<T, _BinaryOperation>::value>
251286
atomic_combine(_T *ReduVarPtr) const {
252287
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
253288
.fetch_max(MValue);

sycl/include/CL/sycl/known_identity.hpp

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,93 +17,71 @@ __SYCL_INLINE_NAMESPACE(cl) {
1717
namespace sycl {
1818
namespace detail {
1919

20-
using cl::sycl::detail::is_sgeninteger;
21-
2220
template <typename T, class BinaryOperation>
23-
using IsReduPlus =
21+
using IsPlus =
2422
bool_constant<std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
2523
std::is_same<BinaryOperation, ONEAPI::plus<void>>::value>;
2624

2725
template <typename T, class BinaryOperation>
28-
using IsReduMultiplies =
29-
bool_constant<std::is_same<BinaryOperation, std::multiplies<T>>::value ||
30-
std::is_same<BinaryOperation, std::multiplies<void>>::value>;
26+
using IsMultiplies = bool_constant<
27+
std::is_same<BinaryOperation, ONEAPI::multiplies<T>>::value ||
28+
std::is_same<BinaryOperation, ONEAPI::multiplies<void>>::value>;
3129

3230
template <typename T, class BinaryOperation>
33-
using IsReduMinimum =
31+
using IsMinimum =
3432
bool_constant<std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
3533
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value>;
3634

3735
template <typename T, class BinaryOperation>
38-
using IsReduMaximum =
36+
using IsMaximum =
3937
bool_constant<std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
4038
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value>;
4139

4240
template <typename T, class BinaryOperation>
43-
using IsReduBitOR =
41+
using IsBitOR =
4442
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
4543
std::is_same<BinaryOperation, ONEAPI::bit_or<void>>::value>;
4644

4745
template <typename T, class BinaryOperation>
48-
using IsReduBitXOR =
46+
using IsBitXOR =
4947
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
5048
std::is_same<BinaryOperation, ONEAPI::bit_xor<void>>::value>;
5149

5250
template <typename T, class BinaryOperation>
53-
using IsReduBitAND =
51+
using IsBitAND =
5452
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
5553
std::is_same<BinaryOperation, ONEAPI::bit_and<void>>::value>;
5654

57-
template <typename T, class BinaryOperation>
58-
using IsReduOptForFastAtomicFetch =
59-
bool_constant<is_sgeninteger<T>::value &&
60-
sycl::detail::IsValidAtomicType<T>::value &&
61-
(IsReduPlus<T, BinaryOperation>::value ||
62-
IsReduMinimum<T, BinaryOperation>::value ||
63-
IsReduMaximum<T, BinaryOperation>::value ||
64-
IsReduBitOR<T, BinaryOperation>::value ||
65-
IsReduBitXOR<T, BinaryOperation>::value ||
66-
IsReduBitAND<T, BinaryOperation>::value)>;
67-
68-
template <typename T, class BinaryOperation>
69-
using IsReduOptForFastReduce =
70-
bool_constant<((is_sgeninteger<T>::value &&
71-
(sizeof(T) == 4 || sizeof(T) == 8)) ||
72-
is_sgenfloat<T>::value) &&
73-
(IsReduPlus<T, BinaryOperation>::value ||
74-
IsReduMinimum<T, BinaryOperation>::value ||
75-
IsReduMaximum<T, BinaryOperation>::value)>;
76-
7755
// Identity = 0
7856
template <typename T, class BinaryOperation>
7957
using IsZeroIdentityOp = bool_constant<
80-
(is_sgeninteger<T>::value && (IsReduPlus<T, BinaryOperation>::value ||
81-
IsReduBitOR<T, BinaryOperation>::value ||
82-
IsReduBitXOR<T, BinaryOperation>::value)) ||
83-
(is_sgenfloat<T>::value && IsReduPlus<T, BinaryOperation>::value)>;
58+
(is_sgeninteger<T>::value &&
59+
(IsPlus<T, BinaryOperation>::value || IsBitOR<T, BinaryOperation>::value ||
60+
IsBitXOR<T, BinaryOperation>::value)) ||
61+
(is_sgenfloat<T>::value && IsPlus<T, BinaryOperation>::value)>;
8462

8563
// Identity = 1
8664
template <typename T, class BinaryOperation>
8765
using IsOneIdentityOp =
8866
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
89-
IsReduMultiplies<T, BinaryOperation>::value>;
67+
IsMultiplies<T, BinaryOperation>::value>;
9068

9169
// Identity = ~0
9270
template <typename T, class BinaryOperation>
9371
using IsOnesIdentityOp = bool_constant<is_sgeninteger<T>::value &&
94-
IsReduBitAND<T, BinaryOperation>::value>;
72+
IsBitAND<T, BinaryOperation>::value>;
9573

9674
// Identity = <max possible value>
9775
template <typename T, class BinaryOperation>
9876
using IsMinimumIdentityOp =
9977
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
100-
IsReduMinimum<T, BinaryOperation>::value>;
78+
IsMinimum<T, BinaryOperation>::value>;
10179

10280
// Identity = <min possible value>
10381
template <typename T, class BinaryOperation>
10482
using IsMaximumIdentityOp =
10583
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
106-
IsReduMaximum<T, BinaryOperation>::value>;
84+
IsMaximum<T, BinaryOperation>::value>;
10785

10886
template <typename T, class BinaryOperation>
10987
using IsKnownIdentityOp =

0 commit comments

Comments
 (0)