Skip to content

Commit 24a2ad8

Browse files
authored
[SYCL] Provide SYCL 2020 function objects (#3868)
Provided SYCL function objects accordingly to SYCL 2020 specification Signed-off-by: mdimakov <[email protected]>
1 parent d9493cb commit 24a2ad8

File tree

10 files changed

+129
-43
lines changed

10 files changed

+129
-43
lines changed

sycl/include/CL/sycl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <CL/sycl/event.hpp>
2929
#include <CL/sycl/exception.hpp>
3030
#include <CL/sycl/feature_test.hpp>
31+
#include <CL/sycl/functional.hpp>
3132
#include <CL/sycl/group.hpp>
3233
#include <CL/sycl/group_algorithm.hpp>
3334
#include <CL/sycl/group_local_memory.hpp>

sycl/include/CL/sycl/ONEAPI/functional.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#pragma once
10+
#include <CL/sycl/functional.hpp>
11+
1012
#include <functional>
1113

1214
__SYCL_INLINE_NAMESPACE(cl) {
@@ -90,6 +92,15 @@ struct GroupOpTag<T, detail::enable_if_t<detail::is_sgenfloat<T>::value>> {
9092
return Ret; \
9193
}
9294

95+
// calc for sycl minimum/maximum function objects
96+
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, sycl::minimum<T>)
97+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, sycl::minimum<T>)
98+
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, sycl::minimum<T>)
99+
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, sycl::maximum<T>)
100+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, sycl::maximum<T>)
101+
__SYCL_CALC_OVERLOAD(GroupOpFP, FMax, sycl::maximum<T>)
102+
103+
// calc for ONEAPI function objects
93104
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, ONEAPI::minimum<T>)
94105
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, ONEAPI::minimum<T>)
95106
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, ONEAPI::minimum<T>)

sycl/include/CL/sycl/functional.hpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//==----------- functional.hpp --- SYCL functional -------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
#include <functional>
11+
12+
__SYCL_INLINE_NAMESPACE(cl) {
13+
namespace sycl {
14+
15+
template <typename T = void> using plus = std::plus<T>;
16+
template <typename T = void> using multiplies = std::multiplies<T>;
17+
template <typename T = void> using bit_or = std::bit_or<T>;
18+
template <typename T = void> using bit_xor = std::bit_xor<T>;
19+
template <typename T = void> using bit_and = std::bit_and<T>;
20+
21+
template <typename T = void> struct minimum {
22+
T operator()(const T &lhs, const T &rhs) const {
23+
return std::less<T>()(lhs, rhs) ? lhs : rhs;
24+
}
25+
};
26+
27+
template <> struct minimum<void> {
28+
struct is_transparent {};
29+
template <typename T, typename U>
30+
auto operator()(T &&lhs, U &&rhs) const ->
31+
typename std::common_type<T &&, U &&>::type {
32+
return std::less<>()(std::forward<const T>(lhs), std::forward<const U>(rhs))
33+
? std::forward<T>(lhs)
34+
: std::forward<U>(rhs);
35+
}
36+
};
37+
38+
template <typename T = void> struct maximum {
39+
T operator()(const T &lhs, const T &rhs) const {
40+
return std::greater<T>()(lhs, rhs) ? lhs : rhs;
41+
}
42+
};
43+
44+
template <> struct maximum<void> {
45+
struct is_transparent {};
46+
template <typename T, typename U>
47+
auto operator()(T &&lhs, U &&rhs) const ->
48+
typename std::common_type<T &&, U &&>::type {
49+
return std::greater<>()(std::forward<const T>(lhs),
50+
std::forward<const U>(rhs))
51+
? std::forward<T>(lhs)
52+
: std::forward<U>(rhs);
53+
}
54+
};
55+
56+
} // namespace sycl
57+
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/CL/sycl/group_algorithm.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <CL/sycl/ONEAPI/functional.hpp>
1414
#include <CL/sycl/detail/spirv.hpp>
1515
#include <CL/sycl/detail/type_traits.hpp>
16+
#include <CL/sycl/functional.hpp>
1617
#include <CL/sycl/group.hpp>
1718
#include <CL/sycl/known_identity.hpp>
1819
#include <CL/sycl/nd_item.hpp>
@@ -86,7 +87,9 @@ template <typename T>
8687
using native_op_list =
8788
type_list<ONEAPI::plus<T>, ONEAPI::bit_or<T>, ONEAPI::bit_xor<T>,
8889
ONEAPI::bit_and<T>, ONEAPI::maximum<T>, ONEAPI::minimum<T>,
89-
ONEAPI::multiplies<T>>;
90+
ONEAPI::multiplies<T>, sycl::plus<T>, sycl::bit_or<T>,
91+
sycl::bit_xor<T>, sycl::bit_and<T>, sycl::maximum<T>,
92+
sycl::minimum<T>, sycl::multiplies<T>>;
9093

9194
template <typename T, typename BinaryOperation> struct is_native_op {
9295
static constexpr bool value =

sycl/include/CL/sycl/known_identity.hpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,37 +19,51 @@ namespace detail {
1919

2020
template <typename T, class BinaryOperation>
2121
using IsPlus =
22-
bool_constant<std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
22+
bool_constant<std::is_same<BinaryOperation, sycl::plus<T>>::value ||
23+
std::is_same<BinaryOperation, sycl::plus<void>>::value ||
24+
std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
2325
std::is_same<BinaryOperation, ONEAPI::plus<void>>::value>;
2426

2527
template <typename T, class BinaryOperation>
2628
using IsMultiplies = bool_constant<
29+
std::is_same<BinaryOperation, sycl::multiplies<T>>::value ||
30+
std::is_same<BinaryOperation, sycl::multiplies<void>>::value ||
2731
std::is_same<BinaryOperation, ONEAPI::multiplies<T>>::value ||
2832
std::is_same<BinaryOperation, ONEAPI::multiplies<void>>::value>;
2933

3034
template <typename T, class BinaryOperation>
3135
using IsMinimum =
32-
bool_constant<std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
36+
bool_constant<std::is_same<BinaryOperation, sycl::minimum<T>>::value ||
37+
std::is_same<BinaryOperation, sycl::minimum<void>>::value ||
38+
std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
3339
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value>;
3440

3541
template <typename T, class BinaryOperation>
3642
using IsMaximum =
37-
bool_constant<std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
43+
bool_constant<std::is_same<BinaryOperation, sycl::maximum<T>>::value ||
44+
std::is_same<BinaryOperation, sycl::maximum<void>>::value ||
45+
std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
3846
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value>;
3947

4048
template <typename T, class BinaryOperation>
4149
using IsBitOR =
42-
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
50+
bool_constant<std::is_same<BinaryOperation, sycl::bit_or<T>>::value ||
51+
std::is_same<BinaryOperation, sycl::bit_or<void>>::value ||
52+
std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
4353
std::is_same<BinaryOperation, ONEAPI::bit_or<void>>::value>;
4454

4555
template <typename T, class BinaryOperation>
4656
using IsBitXOR =
47-
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
57+
bool_constant<std::is_same<BinaryOperation, sycl::bit_xor<T>>::value ||
58+
std::is_same<BinaryOperation, sycl::bit_xor<void>>::value ||
59+
std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
4860
std::is_same<BinaryOperation, ONEAPI::bit_xor<void>>::value>;
4961

5062
template <typename T, class BinaryOperation>
5163
using IsBitAND =
52-
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
64+
bool_constant<std::is_same<BinaryOperation, sycl::bit_and<T>>::value ||
65+
std::is_same<BinaryOperation, sycl::bit_and<void>>::value ||
66+
std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
5367
std::is_same<BinaryOperation, ONEAPI::bit_and<void>>::value>;
5468

5569
// Identity = 0

sycl/test/extensions/group-algorithm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ int main() {
7676
std::iota(input.begin(), input.end(), 0);
7777
std::fill(output.begin(), output.end(), 0);
7878

79-
test<class KernelNamePlusV>(q, input, output, plus<>(), 0, GeZero());
80-
test<class KernelNameMinimumV>(q, input, output, minimum<>(),
79+
test<class KernelNamePlusV>(q, input, output, ONEAPI::plus<>(), 0, GeZero());
80+
test<class KernelNameMinimumV>(q, input, output, ONEAPI::minimum<>(),
8181
std::numeric_limits<int>::max(), IsEven());
8282

8383
#ifdef SPIRV_1_3
8484
test<class KernelName_WonwuUVPUPOTKRKIBtT>(q, input, output,
85-
multiplies<int>(), 1, LtZero());
85+
ONEAPI::multiplies<int>(), 1, LtZero());
8686
#endif // SPIRV_1_3
8787

8888
std::cout << "Test passed." << std::endl;

sycl/test/on-device/back_to_back_collectives.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ int main() {
4747
auto g = it.get_group();
4848
// Loop to increase number of back-to-back calls
4949
for (int r = 0; r < 10; ++r) {
50-
Sum[i] = reduce(g, Input[i], plus<>());
51-
EScan[i] = exclusive_scan(g, Input[i], plus<>());
52-
IScan[i] = inclusive_scan(g, Input[i], plus<>());
50+
Sum[i] = reduce(g, Input[i], sycl::plus<>());
51+
EScan[i] = exclusive_scan(g, Input[i], sycl::plus<>());
52+
IScan[i] = inclusive_scan(g, Input[i], sycl::plus<>());
5353
}
5454
});
5555
});

sycl/test/on-device/group_algorithms_sycl2020/exclusive_scan.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,24 +138,24 @@ int main() {
138138
std::iota(input.begin(), input.end(), 0);
139139
std::fill(output.begin(), output.end(), 0);
140140

141-
test<class KernelNamePlusV>(q, input, output, std::plus<>(), 0);
142-
test<class KernelNameMinimumV>(q, input, output, sycl::ONEAPI::minimum<>(),
141+
test<class KernelNamePlusV>(q, input, output, sycl::plus<>(), 0);
142+
test<class KernelNameMinimumV>(q, input, output, sycl::minimum<>(),
143143
std::numeric_limits<int>::max());
144-
test<class KernelNameMaximumV>(q, input, output, sycl::ONEAPI::maximum<>(),
144+
test<class KernelNameMaximumV>(q, input, output, sycl::maximum<>(),
145145
std::numeric_limits<int>::lowest());
146146

147-
test<class KernelNamePlusI>(q, input, output, std::plus<int>(), 0);
148-
test<class KernelNameMinimumI>(q, input, output, sycl::ONEAPI::minimum<int>(),
147+
test<class KernelNamePlusI>(q, input, output, sycl::plus<int>(), 0);
148+
test<class KernelNameMinimumI>(q, input, output, sycl::minimum<int>(),
149149
std::numeric_limits<int>::max());
150-
test<class KernelNameMaximumI>(q, input, output, sycl::ONEAPI::maximum<int>(),
150+
test<class KernelNameMaximumI>(q, input, output, sycl::maximum<int>(),
151151
std::numeric_limits<int>::lowest());
152152

153153
#ifdef SPIRV_1_3
154-
test<class KernelName_VzAPutpBRRJrQPB>(q, input, output, multiplies<int>(),
154+
test<class KernelName_VzAPutpBRRJrQPB>(q, input, output, sycl::multiplies<int>(),
155155
1);
156-
test<class KernelName_UXdGbr>(q, input, output, bit_or<int>(), 0);
157-
test<class KernelName_saYaodNyJknrPW>(q, input, output, bit_xor<int>(), 0);
158-
test<class KernelName_GPcuAlvAOjrDyP>(q, input, output, bit_and<int>(), ~0);
156+
test<class KernelName_UXdGbr>(q, input, output, sycl::bit_or<int>(), 0);
157+
test<class KernelName_saYaodNyJknrPW>(q, input, output, sycl::bit_xor<int>(), 0);
158+
test<class KernelName_GPcuAlvAOjrDyP>(q, input, output, sycl::bit_and<int>(), ~0);
159159
#endif // SPIRV_1_3
160160

161161
std::cout << "Test passed." << std::endl;

sycl/test/on-device/group_algorithms_sycl2020/inclusive_scan.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,25 +138,25 @@ int main() {
138138
std::iota(input.begin(), input.end(), 0);
139139
std::fill(output.begin(), output.end(), 0);
140140

141-
test<class KernelNamePlusV>(q, input, output, std::plus<>(), 0);
142-
test<class KernelNameMinimumV>(q, input, output, sycl::ONEAPI::minimum<>(),
141+
test<class KernelNamePlusV>(q, input, output, sycl::plus<>(), 0);
142+
test<class KernelNameMinimumV>(q, input, output, sycl::minimum<>(),
143143
std::numeric_limits<int>::max());
144-
test<class KernelNameMaximumV>(q, input, output, sycl::ONEAPI::maximum<>(),
144+
test<class KernelNameMaximumV>(q, input, output, sycl::maximum<>(),
145145
std::numeric_limits<int>::lowest());
146146

147-
test<class KernelNamePlusI>(q, input, output, std::plus<int>(), 0);
148-
test<class KernelNameMinimumI>(q, input, output, sycl::ONEAPI::minimum<int>(),
147+
test<class KernelNamePlusI>(q, input, output, sycl::plus<int>(), 0);
148+
test<class KernelNameMinimumI>(q, input, output, sycl::minimum<int>(),
149149
std::numeric_limits<int>::max());
150-
test<class KernelNameMaximumI>(q, input, output, sycl::ONEAPI::maximum<int>(),
150+
test<class KernelNameMaximumI>(q, input, output, sycl::maximum<int>(),
151151
std::numeric_limits<int>::lowest());
152152

153153
#ifdef SPIRV_1_3
154154
test<class KernelName_zMyjxUrBgeUGoxmDwhvJ>(q, input, output,
155-
multiplies<int>(), 1);
156-
test<class KernelName_SljjtroxNRaAXoVnT>(q, input, output, bit_or<int>(), 0);
157-
test<class KernelName_yXIZfjwjxQGiPeQAnc>(q, input, output, bit_xor<int>(),
155+
sycl::multiplies<int>(), 1);
156+
test<class KernelName_SljjtroxNRaAXoVnT>(q, input, output, sycl::bit_or<int>(), 0);
157+
test<class KernelName_yXIZfjwjxQGiPeQAnc>(q, input, output, sycl::bit_xor<int>(),
158158
0);
159-
test<class KernelName_xGnAnMYHvqekCk>(q, input, output, bit_and<int>(), ~0);
159+
test<class KernelName_xGnAnMYHvqekCk>(q, input, output, sycl::bit_and<int>(), ~0);
160160
#endif // SPIRV_1_3
161161

162162
std::cout << "Test passed." << std::endl;

sycl/test/on-device/group_algorithms_sycl2020/reduce.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,24 @@ int main() {
7474
std::iota(input.begin(), input.end(), 0);
7575
std::fill(output.begin(), output.end(), 0);
7676

77-
test<class KernelNamePlusV>(q, input, output, std::plus<>(), 0);
78-
test<class KernelNameMinimumV>(q, input, output, sycl::ONEAPI::minimum<>(),
77+
test<class KernelNamePlusV>(q, input, output, sycl::plus<>(), 0);
78+
test<class KernelNameMinimumV>(q, input, output, sycl::minimum<>(),
7979
std::numeric_limits<int>::max());
80-
test<class KernelNameMaximumV>(q, input, output, sycl::ONEAPI::maximum<>(),
80+
test<class KernelNameMaximumV>(q, input, output, sycl::maximum<>(),
8181
std::numeric_limits<int>::lowest());
8282

83-
test<class KernelNamePlusI>(q, input, output, std::plus<int>(), 0);
84-
test<class KernelNameMinimumI>(q, input, output, sycl::ONEAPI::minimum<int>(),
83+
test<class KernelNamePlusI>(q, input, output, sycl::plus<int>(), 0);
84+
test<class KernelNameMinimumI>(q, input, output, sycl::minimum<int>(),
8585
std::numeric_limits<int>::max());
86-
test<class KernelNameMaximumI>(q, input, output, sycl::ONEAPI::maximum<int>(),
86+
test<class KernelNameMaximumI>(q, input, output, sycl::maximum<int>(),
8787
std::numeric_limits<int>::lowest());
8888

8989
#ifdef SPIRV_1_3
9090
test<class KernelName_WonwuUVPUPOTKRKIBtT>(q, input, output,
91-
multiplies<int>(), 1);
92-
test<class KernelName_qYBaJDZTMGkdIwD>(q, input, output, bit_or<int>(), 0);
93-
test<class KernelName_eLSFt>(q, input, output, bit_xor<int>(), 0);
94-
test<class KernelName_uFhJnxSVhNAiFPTG>(q, input, output, bit_and<int>(), ~0);
91+
sycl::multiplies<int>(), 1);
92+
test<class KernelName_qYBaJDZTMGkdIwD>(q, input, output, sycl::bit_or<int>(), 0);
93+
test<class KernelName_eLSFt>(q, input, output, sycl::bit_xor<int>(), 0);
94+
test<class KernelName_uFhJnxSVhNAiFPTG>(q, input, output, sycl::bit_and<int>(), ~0);
9595
#endif // SPIRV_1_3
9696

9797
std::cout << "Test passed." << std::endl;

0 commit comments

Comments
 (0)