Skip to content

Commit 9b355ef

Browse files
committed
Fix LIT test fails + address reviewer's comment
Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent 701377b commit 9b355ef

File tree

1 file changed

+38
-14
lines changed

1 file changed

+38
-14
lines changed

sycl/include/CL/sycl/ONEAPI/reduction.hpp

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "CL/sycl/ONEAPI/accessor_property_list.hpp"
1212
#include <CL/sycl/ONEAPI/group_algorithm.hpp>
1313
#include <CL/sycl/accessor.hpp>
14+
#include <CL/sycl/atomic.hpp>
1415
#include <CL/sycl/handler.hpp>
1516
#include <CL/sycl/kernel.hpp>
1617

@@ -71,7 +72,8 @@ using IsReduBitAND =
7172

7273
template <typename T, class BinaryOperation>
7374
using IsReduOptForFastAtomicFetch =
74-
bool_constant<is_sgeninteger<T>::value && IsValidAtomicType<T>::value &&
75+
bool_constant<is_sgeninteger<T>::value &&
76+
sycl::detail::IsValidAtomicType<T>::value &&
7577
(IsReduPlus<T, BinaryOperation>::value ||
7678
IsReduMinimum<T, BinaryOperation>::value ||
7779
IsReduMaximum<T, BinaryOperation>::value ||
@@ -81,7 +83,9 @@ using IsReduOptForFastAtomicFetch =
8183

8284
template <typename T, class BinaryOperation>
8385
using IsReduOptForFastReduce =
84-
bool_constant<(is_sgeninteger<T>::value || is_sgenfloat<T>::value) &&
86+
bool_constant<((is_sgeninteger<T>::value &&
87+
(sizeof(T) == 32 || sizeof(T) == 64)) ||
88+
is_sgenfloat<T>::value) &&
8589
(IsReduPlus<T, BinaryOperation>::value ||
8690
IsReduMinimum<T, BinaryOperation>::value ||
8791
IsReduMaximum<T, BinaryOperation>::value)>;
@@ -322,7 +326,7 @@ class reducer<T, BinaryOperation,
322326
/// Atomic ADD operation: *ReduVarPtr += MValue;
323327
template <typename _T = T, class _BinaryOperation = BinaryOperation>
324328
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
325-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
329+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
326330
IsReduPlus<T, _BinaryOperation>::value>
327331
atomic_combine(_T *ReduVarPtr) const {
328332
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -332,7 +336,7 @@ class reducer<T, BinaryOperation,
332336
/// Atomic BITWISE OR operation: *ReduVarPtr |= MValue;
333337
template <typename _T = T, class _BinaryOperation = BinaryOperation>
334338
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
335-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
339+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
336340
IsReduBitOR<T, _BinaryOperation>::value>
337341
atomic_combine(_T *ReduVarPtr) const {
338342
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -342,7 +346,7 @@ class reducer<T, BinaryOperation,
342346
/// Atomic BITWISE XOR operation: *ReduVarPtr ^= MValue;
343347
template <typename _T = T, class _BinaryOperation = BinaryOperation>
344348
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
345-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
349+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
346350
IsReduBitXOR<T, _BinaryOperation>::value>
347351
atomic_combine(_T *ReduVarPtr) const {
348352
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -352,7 +356,7 @@ class reducer<T, BinaryOperation,
352356
/// Atomic BITWISE AND operation: *ReduVarPtr &= MValue;
353357
template <typename _T = T, class _BinaryOperation = BinaryOperation>
354358
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
355-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
359+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
356360
IsReduBitAND<T, _BinaryOperation>::value>
357361
atomic_combine(_T *ReduVarPtr) const {
358362
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -362,7 +366,7 @@ class reducer<T, BinaryOperation,
362366
/// Atomic MIN operation: *ReduVarPtr = ONEAPI::minimum(*ReduVarPtr, MValue);
363367
template <typename _T = T, class _BinaryOperation = BinaryOperation>
364368
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
365-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
369+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
366370
IsReduMinimum<T, _BinaryOperation>::value>
367371
atomic_combine(_T *ReduVarPtr) const {
368372
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -372,7 +376,7 @@ class reducer<T, BinaryOperation,
372376
/// Atomic MAX operation: *ReduVarPtr = ONEAPI::maximum(*ReduVarPtr, MValue);
373377
template <typename _T = T, class _BinaryOperation = BinaryOperation>
374378
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
375-
(is_geninteger32bit<T>::value || is_geninteger64bit<T>::value) &&
379+
IsReduOptForFastAtomicFetch<T, _BinaryOperation>::value &&
376380
IsReduMaximum<T, _BinaryOperation>::value>
377381
atomic_combine(_T *ReduVarPtr) const {
378382
atomic<T, access::address_space::global_space>(global_ptr<T>(ReduVarPtr))
@@ -1619,10 +1623,8 @@ reduction(T *VarPtr, BinaryOperation) {
16191623
access::mode::read_write>(VarPtr);
16201624
}
16211625

1622-
} // namespace ONEAPI
1623-
16241626
template <typename BinaryOperation, typename AccumulatorT>
1625-
struct has_known_identity : ONEAPI::detail::has_known_identity_impl<
1627+
struct has_known_identity : detail::has_known_identity_impl<
16261628
typename std::decay<BinaryOperation>::type,
16271629
typename std::decay<AccumulatorT>::type> {};
16281630
#if __cplusplus >= 201703L
@@ -1632,14 +1634,36 @@ inline constexpr bool has_known_identity_v =
16321634
#endif
16331635

16341636
template <typename BinaryOperation, typename AccumulatorT>
1635-
struct known_identity : ONEAPI::detail::known_identity_impl<
1636-
typename std::decay<BinaryOperation>::type,
1637-
typename std::decay<AccumulatorT>::type> {};
1637+
struct known_identity
1638+
: detail::known_identity_impl<typename std::decay<BinaryOperation>::type,
1639+
typename std::decay<AccumulatorT>::type> {};
16381640
#if __cplusplus >= 201703L
16391641
template <typename BinaryOperation, typename AccumulatorT>
16401642
inline constexpr AccumulatorT known_identity_v =
16411643
known_identity<BinaryOperation, AccumulatorT>::value;
16421644
#endif
1645+
} // namespace ONEAPI
1646+
1647+
// Currently, the type traits defined below correspond to SYCL 1.2.1 ONEAPI
1648+
// reduction extension. That may be changed later when SYCL 2020 reductions
1649+
// are implemented.
1650+
#if SYCL_LANGUAGE_VERSION >= 202001
1651+
template <typename BinaryOperation, typename AccumulatorT>
1652+
struct has_known_identity
1653+
: ONEAPI::has_known_identity<BinaryOperation, AccumulatorT> {};
1654+
1655+
template <typename BinaryOperation, typename AccumulatorT>
1656+
inline constexpr bool has_known_identity_v =
1657+
has_known_identity<BinaryOperation, AccumulatorT>::value;
1658+
1659+
template <typename BinaryOperation, typename AccumulatorT>
1660+
struct known_identity : ONEAPI::known_identity<BinaryOperation, AccumulatorT> {
1661+
};
1662+
1663+
template <typename BinaryOperation, typename AccumulatorT>
1664+
inline constexpr AccumulatorT known_identity_v =
1665+
known_identity<BinaryOperation, AccumulatorT>::value;
1666+
#endif // SYCL_LANGUAGE_VERSION >= 202001
16431667

16441668
} // namespace sycl
16451669
} // __SYCL_INLINE_NAMESPACE(cl)

0 commit comments

Comments
 (0)