Skip to content

Commit 9a7767a

Browse files
authored
[SYCL][NFC] Remove functors code duplication (#4361)
The functors ext::oneapi::minimum/maximum were defined identically to sycl::minimum/maximum. Turn them into aliases of functors defined in sycl. Remove duplicated checks for ext::oneapi functors. Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent 2fe7dd3 commit 9a7767a

File tree

6 files changed

+54
-117
lines changed

6 files changed

+54
-117
lines changed

sycl/include/CL/sycl/group_algorithm.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,9 @@ get_local_linear_id<ext::oneapi::sub_group>(ext::oneapi::sub_group g) {
8686
// ---- is_native_op
8787
template <typename T>
8888
using native_op_list =
89-
type_list<ext::oneapi::plus<T>, ext::oneapi::bit_or<T>,
90-
ext::oneapi::bit_xor<T>, ext::oneapi::bit_and<T>,
91-
ext::oneapi::maximum<T>, ext::oneapi::minimum<T>,
92-
ext::oneapi::multiplies<T>, sycl::plus<T>, sycl::bit_or<T>,
93-
sycl::bit_xor<T>, sycl::bit_and<T>, sycl::maximum<T>,
94-
sycl::minimum<T>, sycl::multiplies<T>>;
89+
type_list<sycl::plus<T>, sycl::bit_or<T>, sycl::bit_xor<T>,
90+
sycl::bit_and<T>, sycl::maximum<T>, sycl::minimum<T>,
91+
sycl::multiplies<T>>;
9592

9693
template <typename T, typename BinaryOperation> struct is_native_op {
9794
static constexpr bool value =

sycl/include/CL/sycl/known_identity.hpp

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,53 +18,39 @@ namespace sycl {
1818
namespace detail {
1919

2020
template <typename T, class BinaryOperation>
21-
using IsPlus = bool_constant<
22-
std::is_same<BinaryOperation, sycl::plus<T>>::value ||
23-
std::is_same<BinaryOperation, sycl::plus<void>>::value ||
24-
std::is_same<BinaryOperation, ext::oneapi::plus<T>>::value ||
25-
std::is_same<BinaryOperation, ext::oneapi::plus<void>>::value>;
21+
using IsPlus =
22+
bool_constant<std::is_same<BinaryOperation, sycl::plus<T>>::value ||
23+
std::is_same<BinaryOperation, sycl::plus<void>>::value>;
2624

2725
template <typename T, class BinaryOperation>
28-
using IsMultiplies = bool_constant<
29-
std::is_same<BinaryOperation, sycl::multiplies<T>>::value ||
30-
std::is_same<BinaryOperation, sycl::multiplies<void>>::value ||
31-
std::is_same<BinaryOperation, ext::oneapi::multiplies<T>>::value ||
32-
std::is_same<BinaryOperation, ext::oneapi::multiplies<void>>::value>;
26+
using IsMultiplies =
27+
bool_constant<std::is_same<BinaryOperation, sycl::multiplies<T>>::value ||
28+
std::is_same<BinaryOperation, sycl::multiplies<void>>::value>;
3329

3430
template <typename T, class BinaryOperation>
35-
using IsMinimum = bool_constant<
36-
std::is_same<BinaryOperation, sycl::minimum<T>>::value ||
37-
std::is_same<BinaryOperation, sycl::minimum<void>>::value ||
38-
std::is_same<BinaryOperation, ext::oneapi::minimum<T>>::value ||
39-
std::is_same<BinaryOperation, ext::oneapi::minimum<void>>::value>;
31+
using IsMinimum =
32+
bool_constant<std::is_same<BinaryOperation, sycl::minimum<T>>::value ||
33+
std::is_same<BinaryOperation, sycl::minimum<void>>::value>;
4034

4135
template <typename T, class BinaryOperation>
42-
using IsMaximum = bool_constant<
43-
std::is_same<BinaryOperation, sycl::maximum<T>>::value ||
44-
std::is_same<BinaryOperation, sycl::maximum<void>>::value ||
45-
std::is_same<BinaryOperation, ext::oneapi::maximum<T>>::value ||
46-
std::is_same<BinaryOperation, ext::oneapi::maximum<void>>::value>;
36+
using IsMaximum =
37+
bool_constant<std::is_same<BinaryOperation, sycl::maximum<T>>::value ||
38+
std::is_same<BinaryOperation, sycl::maximum<void>>::value>;
4739

4840
template <typename T, class BinaryOperation>
49-
using IsBitOR = bool_constant<
50-
std::is_same<BinaryOperation, sycl::bit_or<T>>::value ||
51-
std::is_same<BinaryOperation, sycl::bit_or<void>>::value ||
52-
std::is_same<BinaryOperation, ext::oneapi::bit_or<T>>::value ||
53-
std::is_same<BinaryOperation, ext::oneapi::bit_or<void>>::value>;
41+
using IsBitOR =
42+
bool_constant<std::is_same<BinaryOperation, sycl::bit_or<T>>::value ||
43+
std::is_same<BinaryOperation, sycl::bit_or<void>>::value>;
5444

5545
template <typename T, class BinaryOperation>
56-
using IsBitXOR = bool_constant<
57-
std::is_same<BinaryOperation, sycl::bit_xor<T>>::value ||
58-
std::is_same<BinaryOperation, sycl::bit_xor<void>>::value ||
59-
std::is_same<BinaryOperation, ext::oneapi::bit_xor<T>>::value ||
60-
std::is_same<BinaryOperation, ext::oneapi::bit_xor<void>>::value>;
46+
using IsBitXOR =
47+
bool_constant<std::is_same<BinaryOperation, sycl::bit_xor<T>>::value ||
48+
std::is_same<BinaryOperation, sycl::bit_xor<void>>::value>;
6149

6250
template <typename T, class BinaryOperation>
63-
using IsBitAND = bool_constant<
64-
std::is_same<BinaryOperation, sycl::bit_and<T>>::value ||
65-
std::is_same<BinaryOperation, sycl::bit_and<void>>::value ||
66-
std::is_same<BinaryOperation, ext::oneapi::bit_and<T>>::value ||
67-
std::is_same<BinaryOperation, ext::oneapi::bit_and<void>>::value>;
51+
using IsBitAND =
52+
bool_constant<std::is_same<BinaryOperation, sycl::bit_and<T>>::value ||
53+
std::is_same<BinaryOperation, sycl::bit_and<void>>::value>;
6854

6955
// Identity = 0
7056
template <typename T, class BinaryOperation>

sycl/include/sycl/ext/oneapi/functional.hpp

Lines changed: 18 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,13 @@ namespace sycl {
1616
namespace ext {
1717
namespace oneapi {
1818

19-
template <typename T = void> struct minimum {
20-
T operator()(const T &lhs, const T &rhs) const {
21-
return std::less<T>()(lhs, rhs) ? lhs : rhs;
22-
}
23-
};
24-
25-
template <> struct minimum<void> {
26-
struct is_transparent {};
27-
template <typename T, typename U>
28-
auto operator()(T &&lhs, U &&rhs) const ->
29-
typename std::common_type<T &&, U &&>::type {
30-
return std::less<>()(std::forward<const T>(lhs), std::forward<const U>(rhs))
31-
? std::forward<T>(lhs)
32-
: std::forward<U>(rhs);
33-
}
34-
};
35-
36-
template <typename T = void> struct maximum {
37-
T operator()(const T &lhs, const T &rhs) const {
38-
return std::greater<T>()(lhs, rhs) ? lhs : rhs;
39-
}
40-
};
41-
42-
template <> struct maximum<void> {
43-
struct is_transparent {};
44-
template <typename T, typename U>
45-
auto operator()(T &&lhs, U &&rhs) const ->
46-
typename std::common_type<T &&, U &&>::type {
47-
return std::greater<>()(std::forward<const T>(lhs),
48-
std::forward<const U>(rhs))
49-
? std::forward<T>(lhs)
50-
: std::forward<U>(rhs);
51-
}
52-
};
53-
5419
template <typename T = void> using plus = std::plus<T>;
5520
template <typename T = void> using multiplies = std::multiplies<T>;
5621
template <typename T = void> using bit_or = std::bit_or<T>;
5722
template <typename T = void> using bit_xor = std::bit_xor<T>;
5823
template <typename T = void> using bit_and = std::bit_and<T>;
24+
template <typename T = void> using maximum = sycl::maximum<T>;
25+
template <typename T = void> using minimum = sycl::minimum<T>;
5926

6027
} // namespace oneapi
6128
} // namespace ext
@@ -106,41 +73,29 @@ struct GroupOpTag<T, detail::enable_if_t<detail::is_sgenfloat<T>::value>> {
10673
return Ret; \
10774
}
10875

109-
// calc for sycl minimum/maximum function objects
76+
// calc for sycl function objects
11077
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, sycl::minimum<T>)
11178
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, sycl::minimum<T>)
11279
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, sycl::minimum<T>)
80+
11381
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, sycl::maximum<T>)
11482
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, sycl::maximum<T>)
11583
__SYCL_CALC_OVERLOAD(GroupOpFP, FMax, sycl::maximum<T>)
11684

117-
// calc for oneapi function objects
118-
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, ext::oneapi::minimum<T>)
119-
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, ext::oneapi::minimum<T>)
120-
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, ext::oneapi::minimum<T>)
121-
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, ext::oneapi::maximum<T>)
122-
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, ext::oneapi::maximum<T>)
123-
__SYCL_CALC_OVERLOAD(GroupOpFP, FMax, ext::oneapi::maximum<T>)
124-
__SYCL_CALC_OVERLOAD(GroupOpISigned, IAdd, ext::oneapi::plus<T>)
125-
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, IAdd, ext::oneapi::plus<T>)
126-
__SYCL_CALC_OVERLOAD(GroupOpFP, FAdd, ext::oneapi::plus<T>)
127-
128-
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformIMul, ext::oneapi::multiplies<T>)
129-
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformIMul,
130-
ext::oneapi::multiplies<T>)
131-
__SYCL_CALC_OVERLOAD(GroupOpFP, NonUniformFMul, ext::oneapi::multiplies<T>)
132-
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseOr,
133-
ext::oneapi::bit_or<T>)
134-
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseOr,
135-
ext::oneapi::bit_or<T>)
136-
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseXor,
137-
ext::oneapi::bit_xor<T>)
138-
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseXor,
139-
ext::oneapi::bit_xor<T>)
140-
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseAnd,
141-
ext::oneapi::bit_and<T>)
142-
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseAnd,
143-
ext::oneapi::bit_and<T>)
85+
__SYCL_CALC_OVERLOAD(GroupOpISigned, IAdd, sycl::plus<T>)
86+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, IAdd, sycl::plus<T>)
87+
__SYCL_CALC_OVERLOAD(GroupOpFP, FAdd, sycl::plus<T>)
88+
89+
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformIMul, sycl::multiplies<T>)
90+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformIMul, sycl::multiplies<T>)
91+
__SYCL_CALC_OVERLOAD(GroupOpFP, NonUniformFMul, sycl::multiplies<T>)
92+
93+
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseOr, sycl::bit_or<T>)
94+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseOr, sycl::bit_or<T>)
95+
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseXor, sycl::bit_xor<T>)
96+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseXor, sycl::bit_xor<T>)
97+
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseAnd, sycl::bit_and<T>)
98+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseAnd, sycl::bit_and<T>)
14499

145100
#undef __SYCL_CALC_OVERLOAD
146101

sycl/include/sycl/ext/oneapi/reduction.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ class reducer {
183183
/// using those operations, which are based on functionality provided by
184184
/// sycl::atomic class.
185185
///
186-
/// For example, it is known that 0 is identity for ext::oneapi::plus operations
186+
/// For example, it is known that 0 is identity for sycl::plus operations
187187
/// accepting native scalar types to which scalar 0 is convertible.
188188
/// Also, for int32/64 types the atomic_combine() is lowered to
189189
/// sycl::atomic::fetch_add().
@@ -317,8 +317,7 @@ class reducer<T, BinaryOperation,
317317
.fetch_and(MValue);
318318
}
319319

320-
/// Atomic MIN operation: *ReduVarPtr = ext::oneapi::minimum(*ReduVarPtr,
321-
/// MValue);
320+
/// Atomic MIN operation: *ReduVarPtr = sycl::minimum(*ReduVarPtr, MValue);
322321
template <access::address_space Space = access::address_space::global_space,
323322
typename _T = T, class _BinaryOperation = BinaryOperation>
324323
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
@@ -332,8 +331,7 @@ class reducer<T, BinaryOperation,
332331
.fetch_min(MValue);
333332
}
334333

335-
/// Atomic MAX operation: *ReduVarPtr = ext::oneapi::maximum(*ReduVarPtr,
336-
/// MValue);
334+
/// Atomic MAX operation: *ReduVarPtr = sycl::maximum(*ReduVarPtr, MValue);
337335
template <access::address_space Space = access::address_space::global_space,
338336
typename _T = T, class _BinaryOperation = BinaryOperation>
339337
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&

sycl/source/detail/reduction.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ __SYCL_EXPORT uint32_t reduGetMaxNumConcurrentWorkGroups(
5555
std::shared_ptr<sycl::detail::queue_impl> Queue) {
5656
device Dev = Queue->get_device();
5757
uint32_t NumThreads = Dev.get_info<info::device::max_compute_units>();
58-
// The heuristics require additional tuning for various devices and vendors.
59-
// For now assuming that each of execution units have about 8 working threads
60-
// gives good results on some known/supported GPU devices.
58+
// TODO: The heuristics here require additional tuning for various devices
59+
// and vendors. For now this code assumes that execution units have about
60+
// 8 working threads, which gives good results on some known/supported
61+
// GPU devices.
6162
if (Dev.is_gpu())
6263
NumThreads *= 8;
6364
return NumThreads;

sycl/test/basic_tests/reduction_known_identity.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
using namespace cl::sycl;
1111

1212
template <typename T> void checkCommonBasicKnownIdentity() {
13-
static_assert(has_known_identity<ext::oneapi::maximum<>, T>::value);
14-
static_assert(has_known_identity<ext::oneapi::maximum<T>, T>::value);
15-
static_assert(has_known_identity<ext::oneapi::minimum<>, T>::value);
16-
static_assert(has_known_identity<ext::oneapi::minimum<T>, T>::value);
13+
static_assert(has_known_identity<sycl::maximum<>, T>::value);
14+
static_assert(has_known_identity<sycl::maximum<T>, T>::value);
15+
static_assert(has_known_identity<sycl::minimum<>, T>::value);
16+
static_assert(has_known_identity<sycl::minimum<T>, T>::value);
1717
}
1818

1919
template <typename T> void checkCommonKnownIdentity() {
@@ -100,7 +100,7 @@ int main() {
100100

101101
// Few negative tests just to check that it does not always return true.
102102
static_assert(!has_known_identity<std::minus<>, int>::value);
103-
static_assert(!has_known_identity<ext::oneapi::bit_or<>, float>::value);
103+
static_assert(!has_known_identity<sycl::bit_or<>, float>::value);
104104

105105
return 0;
106106
}

0 commit comments

Comments
 (0)