Skip to content

Commit 8b9f826

Browse files
committed
[SYCL] Use macro SYCL_REDUCTION_DETERMINISTIC for stable reduction results
The macro SYCL_REDUCTION_DETERMINISTIC is not defined by default. This patch also has an NFC (no functional change) portion, which removes reduction-specific code from the known_identity.hpp file. Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent 9216b49 commit 8b9f826

File tree

2 files changed

+69
-63
lines changed

2 files changed

+69
-63
lines changed

sycl/include/CL/sycl/ONEAPI/reduction.hpp

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,38 @@ using cl::sycl::detail::is_sgeninteger;
3232
using cl::sycl::detail::queue_impl;
3333
using cl::sycl::detail::remove_AS;
3434

35+
// This type trait is used to detect if the atomic operation BinaryOperation
36+
// used with operands of the type T is available for use in reduction.
37+
// The order in which the atomic operations are performed may be arbitrary and
38+
// thus may cause different results from run to run even on the same elements
39+
// and on the same device. The macro SYCL_REDUCTION_DETERMINISTIC prohibits using
40+
// atomic operations for reduction and helps to produce stable results.
41+
// SYCL_REDUCTION_DETERMINISTIC is a short-term solution, which may eventually
42+
// be deprecated and replaced by a SYCL property passed to the reduction.
43+
template <typename T, class BinaryOperation>
44+
using IsReduOptForFastAtomicFetch =
45+
#ifdef SYCL_REDUCTION_DETERMINISTIC
46+
bool_constant<false>;
47+
#else
48+
bool_constant<is_sgeninteger<T>::value &&
49+
sycl::detail::IsValidAtomicType<T>::value &&
50+
(sycl::detail::IsPlus<T, BinaryOperation>::value ||
51+
sycl::detail::IsMinimum<T, BinaryOperation>::value ||
52+
sycl::detail::IsMaximum<T, BinaryOperation>::value ||
53+
sycl::detail::IsBitOR<T, BinaryOperation>::value ||
54+
sycl::detail::IsBitXOR<T, BinaryOperation>::value ||
55+
sycl::detail::IsBitAND<T, BinaryOperation>::value)>;
56+
#endif
57+
58+
template <typename T, class BinaryOperation>
59+
using IsReduOptForFastReduce =
60+
bool_constant<((is_sgeninteger<T>::value &&
61+
(sizeof(T) == 4 || sizeof(T) == 8)) ||
62+
is_sgenfloat<T>::value) &&
63+
(sycl::detail::IsPlus<T, BinaryOperation>::value ||
64+
sycl::detail::IsMinimum<T, BinaryOperation>::value ||
65+
sycl::detail::IsMaximum<T, BinaryOperation>::value)>;
66+
3567
// std::tuple seems to be a) too heavy and b) not copyable to device now
3668
// Thus sycl::detail::tuple is used instead.
3769
// Switching from sycl::detail::tuple to std::tuple can be done by re-defining
@@ -46,10 +78,6 @@ __SYCL_EXPORT size_t reduGetMaxWGSize(shared_ptr_class<queue_impl> Queue,
4678
__SYCL_EXPORT size_t reduComputeWGSize(size_t NWorkItems, size_t MaxWGSize,
4779
size_t &NWorkGroups);
4880

49-
50-
51-
52-
5381
/// Class that is used to represent objects that are passed to user's lambda
5482
/// functions and that represent users' reduction variables.
5583
/// The generic version of the class represents those reductions of those
@@ -64,45 +92,45 @@ class reducer {
6492
T getIdentity() const { return MIdentity; }
6593

6694
template <typename _T = T>
67-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value &&
95+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value &&
6896
sycl::detail::is_geninteger<_T>::value>
6997
operator++() {
7098
combine(static_cast<T>(1));
7199
}
72100

73101
template <typename _T = T>
74-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value &&
102+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value &&
75103
sycl::detail::is_geninteger<_T>::value>
76104
operator++(int) {
77105
combine(static_cast<T>(1));
78106
}
79107

80108
template <typename _T = T>
81-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value>
109+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value>
82110
operator+=(const _T &Partial) {
83111
combine(Partial);
84112
}
85113

86114
template <typename _T = T>
87-
enable_if_t<IsReduMultiplies<_T, BinaryOperation>::value>
115+
enable_if_t<sycl::detail::IsMultiplies<_T, BinaryOperation>::value>
88116
operator*=(const _T &Partial) {
89117
combine(Partial);
90118
}
91119

92120
template <typename _T = T>
93-
enable_if_t<IsReduBitOR<_T, BinaryOperation>::value>
121+
enable_if_t<sycl::detail::IsBitOR<_T, BinaryOperation>::value>
94122
operator|=(const _T &Partial) {
95123
combine(Partial);
96124
}
97125

98126
template <typename _T = T>
99-
enable_if_t<IsReduBitXOR<_T, BinaryOperation>::value>
127+
enable_if_t<sycl::detail::IsBitXOR<_T, BinaryOperation>::value>
100128
operator^=(const _T &Partial) {
101129
combine(Partial);
102130
}
103131

104132
template <typename _T = T>
105-
enable_if_t<IsReduBitAND<_T, BinaryOperation>::value>
133+
enable_if_t<sycl::detail::IsBitAND<_T, BinaryOperation>::value>
106134
operator&=(const _T &Partial) {
107135
combine(Partial);
108136
}
@@ -150,45 +178,45 @@ class reducer<T, BinaryOperation,
150178
}
151179

152180
template <typename _T = T>
153-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value &&
181+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value &&
154182
sycl::detail::is_geninteger<_T>::value>
155183
operator++() {
156184
combine(static_cast<T>(1));
157185
}
158186

159187
template <typename _T = T>
160-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value &&
188+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value &&
161189
sycl::detail::is_geninteger<_T>::value>
162190
operator++(int) {
163191
combine(static_cast<T>(1));
164192
}
165193

166194
template <typename _T = T>
167-
enable_if_t<IsReduPlus<_T, BinaryOperation>::value>
195+
enable_if_t<sycl::detail::IsPlus<_T, BinaryOperation>::value>
168196
operator+=(const _T &Partial) {
169197
combine(Partial);
170198
}
171199

172200
template <typename _T = T>
173-
enable_if_t<IsReduMultiplies<_T, BinaryOperation>::value>
201+
enable_if_t<sycl::detail::IsMultiplies<_T, BinaryOperation>::value>
174202
operator*=(const _T &Partial) {
175203
combine(Partial);
176204
}
177205

178206
template <typename _T = T>
179-
enable_if_t<IsReduBitOR<_T, BinaryOperation>::value>
207+
enable_if_t<sycl::detail::IsBitOR<_T, BinaryOperation>::value>
180208
operator|=(const _T &Partial) {
181209
combine(Partial);
182210
}
183211

184212
template <typename _T = T>
185-
enable_if_t<IsReduBitXOR<_T, BinaryOperation>::value>
213+
enable_if_t<sycl::detail::IsBitXOR<_T, BinaryOperation>::value>
186214
operator^=(const _T &Partial) {
187215
combine(Partial);
188216
}
189217

190218
template <typename _T = T>
191-
enable_if_t<IsReduBitAND<_T, BinaryOperation>::value>
219+
enable_if_t<sycl::detail::IsBitAND<_T, BinaryOperation>::value>
192220
operator&=(const _T &Partial) {
193221
combine(Partial);
194222
}
@@ -197,7 +225,7 @@ class reducer<T, BinaryOperation,
197225
template <typename _T = T, class _BinaryOperation = BinaryOperation>
198226
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
199227
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
200-
IsReduPlus<T, _BinaryOperation>::value>
228+
sycl::detail::IsPlus<T, _BinaryOperation>::value>
201229
atomic_combine(_T *ReduVarPtr) const {
202230
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
203231
.fetch_add(MValue);
@@ -207,7 +235,7 @@ class reducer<T, BinaryOperation,
207235
template <typename _T = T, class _BinaryOperation = BinaryOperation>
208236
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
209237
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
210-
IsReduBitOR<T, _BinaryOperation>::value>
238+
sycl::detail::IsBitOR<T, _BinaryOperation>::value>
211239
atomic_combine(_T *ReduVarPtr) const {
212240
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
213241
.fetch_or(MValue);
@@ -217,7 +245,7 @@ class reducer<T, BinaryOperation,
217245
template <typename _T = T, class _BinaryOperation = BinaryOperation>
218246
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
219247
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
220-
IsReduBitXOR<T, _BinaryOperation>::value>
248+
sycl::detail::IsBitXOR<T, _BinaryOperation>::value>
221249
atomic_combine(_T *ReduVarPtr) const {
222250
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
223251
.fetch_xor(MValue);
@@ -227,7 +255,7 @@ class reducer<T, BinaryOperation,
227255
template <typename _T = T, class _BinaryOperation = BinaryOperation>
228256
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
229257
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
230-
IsReduBitAND<T, _BinaryOperation>::value>
258+
sycl::detail::IsBitAND<T, _BinaryOperation>::value>
231259
atomic_combine(_T *ReduVarPtr) const {
232260
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
233261
.fetch_and(MValue);
@@ -237,7 +265,7 @@ class reducer<T, BinaryOperation,
237265
template <typename _T = T, class _BinaryOperation = BinaryOperation>
238266
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
239267
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
240-
IsReduMinimum<T, _BinaryOperation>::value>
268+
sycl::detail::IsMinimum<T, _BinaryOperation>::value>
241269
atomic_combine(_T *ReduVarPtr) const {
242270
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
243271
.fetch_min(MValue);
@@ -247,7 +275,7 @@ class reducer<T, BinaryOperation,
247275
template <typename _T = T, class _BinaryOperation = BinaryOperation>
248276
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
249277
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
250-
IsReduMaximum<T, _BinaryOperation>::value>
278+
sycl::detail::IsMaximum<T, _BinaryOperation>::value>
251279
atomic_combine(_T *ReduVarPtr) const {
252280
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
253281
.fetch_max(MValue);

sycl/include/CL/sycl/known_identity.hpp

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,93 +17,71 @@ __SYCL_INLINE_NAMESPACE(cl) {
1717
namespace sycl {
1818
namespace detail {
1919

20-
using cl::sycl::detail::is_sgeninteger;
21-
2220
template <typename T, class BinaryOperation>
23-
using IsReduPlus =
21+
using IsPlus =
2422
bool_constant<std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
2523
std::is_same<BinaryOperation, ONEAPI::plus<void>>::value>;
2624

2725
template <typename T, class BinaryOperation>
28-
using IsReduMultiplies =
29-
bool_constant<std::is_same<BinaryOperation, std::multiplies<T>>::value ||
30-
std::is_same<BinaryOperation, std::multiplies<void>>::value>;
26+
using IsMultiplies = bool_constant<
27+
std::is_same<BinaryOperation, ONEAPI::multiplies<T>>::value ||
28+
std::is_same<BinaryOperation, ONEAPI::multiplies<void>>::value>;
3129

3230
template <typename T, class BinaryOperation>
33-
using IsReduMinimum =
31+
using IsMinimum =
3432
bool_constant<std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
3533
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value>;
3634

3735
template <typename T, class BinaryOperation>
38-
using IsReduMaximum =
36+
using IsMaximum =
3937
bool_constant<std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
4038
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value>;
4139

4240
template <typename T, class BinaryOperation>
43-
using IsReduBitOR =
41+
using IsBitOR =
4442
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
4543
std::is_same<BinaryOperation, ONEAPI::bit_or<void>>::value>;
4644

4745
template <typename T, class BinaryOperation>
48-
using IsReduBitXOR =
46+
using IsBitXOR =
4947
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
5048
std::is_same<BinaryOperation, ONEAPI::bit_xor<void>>::value>;
5149

5250
template <typename T, class BinaryOperation>
53-
using IsReduBitAND =
51+
using IsBitAND =
5452
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
5553
std::is_same<BinaryOperation, ONEAPI::bit_and<void>>::value>;
5654

57-
template <typename T, class BinaryOperation>
58-
using IsReduOptForFastAtomicFetch =
59-
bool_constant<is_sgeninteger<T>::value &&
60-
sycl::detail::IsValidAtomicType<T>::value &&
61-
(IsReduPlus<T, BinaryOperation>::value ||
62-
IsReduMinimum<T, BinaryOperation>::value ||
63-
IsReduMaximum<T, BinaryOperation>::value ||
64-
IsReduBitOR<T, BinaryOperation>::value ||
65-
IsReduBitXOR<T, BinaryOperation>::value ||
66-
IsReduBitAND<T, BinaryOperation>::value)>;
67-
68-
template <typename T, class BinaryOperation>
69-
using IsReduOptForFastReduce =
70-
bool_constant<((is_sgeninteger<T>::value &&
71-
(sizeof(T) == 4 || sizeof(T) == 8)) ||
72-
is_sgenfloat<T>::value) &&
73-
(IsReduPlus<T, BinaryOperation>::value ||
74-
IsReduMinimum<T, BinaryOperation>::value ||
75-
IsReduMaximum<T, BinaryOperation>::value)>;
76-
7755
// Identity = 0
7856
template <typename T, class BinaryOperation>
7957
using IsZeroIdentityOp = bool_constant<
80-
(is_sgeninteger<T>::value && (IsReduPlus<T, BinaryOperation>::value ||
81-
IsReduBitOR<T, BinaryOperation>::value ||
82-
IsReduBitXOR<T, BinaryOperation>::value)) ||
83-
(is_sgenfloat<T>::value && IsReduPlus<T, BinaryOperation>::value)>;
58+
(is_sgeninteger<T>::value &&
59+
(IsPlus<T, BinaryOperation>::value || IsBitOR<T, BinaryOperation>::value ||
60+
IsBitXOR<T, BinaryOperation>::value)) ||
61+
(is_sgenfloat<T>::value && IsPlus<T, BinaryOperation>::value)>;
8462

8563
// Identity = 1
8664
template <typename T, class BinaryOperation>
8765
using IsOneIdentityOp =
8866
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
89-
IsReduMultiplies<T, BinaryOperation>::value>;
67+
IsMultiplies<T, BinaryOperation>::value>;
9068

9169
// Identity = ~0
9270
template <typename T, class BinaryOperation>
9371
using IsOnesIdentityOp = bool_constant<is_sgeninteger<T>::value &&
94-
IsReduBitAND<T, BinaryOperation>::value>;
72+
IsBitAND<T, BinaryOperation>::value>;
9573

9674
// Identity = <max possible value>
9775
template <typename T, class BinaryOperation>
9876
using IsMinimumIdentityOp =
9977
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
100-
IsReduMinimum<T, BinaryOperation>::value>;
78+
IsMinimum<T, BinaryOperation>::value>;
10179

10280
// Identity = <min possible value>
10381
template <typename T, class BinaryOperation>
10482
using IsMaximumIdentityOp =
10583
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
106-
IsReduMaximum<T, BinaryOperation>::value>;
84+
IsMaximum<T, BinaryOperation>::value>;
10785

10886
template <typename T, class BinaryOperation>
10987
using IsKnownIdentityOp =

0 commit comments

Comments
 (0)