Skip to content

Commit cab4e5c

Browse files
committed
Providing function objects for sycl
1 parent df9c728 commit cab4e5c

File tree

6 files changed

+90
-54
lines changed

6 files changed

+90
-54
lines changed

sycl/include/CL/sycl.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <CL/sycl/event.hpp>
2929
#include <CL/sycl/exception.hpp>
3030
#include <CL/sycl/feature_test.hpp>
31+
#include <CL/sycl/functional.hpp>
3132
#include <CL/sycl/group.hpp>
3233
#include <CL/sycl/group_algorithm.hpp>
3334
#include <CL/sycl/group_local_memory.hpp>

sycl/include/CL/sycl/ONEAPI/functional.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#pragma once
10+
#include <CL/sycl/functional.hpp>
11+
1012
#include <functional>
1113

1214
__SYCL_INLINE_NAMESPACE(cl) {
@@ -90,6 +92,15 @@ struct GroupOpTag<T, detail::enable_if_t<detail::is_sgenfloat<T>::value>> {
9092
return Ret; \
9193
}
9294

95+
// calc for sycl minimum/maximum function objects
96+
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, sycl::minimum<T>)
97+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, sycl::minimum<T>)
98+
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, sycl::minimum<T>)
99+
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, sycl::maximum<T>)
100+
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, sycl::maximum<T>)
101+
__SYCL_CALC_OVERLOAD(GroupOpFP, FMax, sycl::maximum<T>)
102+
103+
// calc for ONEAPI function objects
93104
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, ONEAPI::minimum<T>)
94105
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, ONEAPI::minimum<T>)
95106
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, ONEAPI::minimum<T>)

sycl/include/CL/sycl/functional.hpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//==----------- functional.hpp --- SYCL functional -------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
#include <functional>
11+
12+
__SYCL_INLINE_NAMESPACE(cl) {
13+
namespace sycl {
14+
15+
template <typename T = void> using plus = std::plus<T>;
16+
template <typename T = void> using multiplies = std::multiplies<T>;
17+
template <typename T = void> using bit_or = std::bit_or<T>;
18+
template <typename T = void> using bit_xor = std::bit_xor<T>;
19+
template <typename T = void> using bit_and = std::bit_and<T>;
20+
template <typename T = void> using logical_and = std::logical_and<T>;
21+
template <typename T = void> using logical_or = std::logical_or<T>;
22+
23+
template <typename T = void> struct minimum {
24+
T operator()(const T &lhs, const T &rhs) const {
25+
return std::less<T>()(lhs, rhs) ? lhs : rhs;
26+
}
27+
};
28+
29+
template <> struct minimum<void> {
30+
struct is_transparent {};
31+
template <typename T, typename U>
32+
auto operator()(T &&lhs, U &&rhs) const ->
33+
typename std::common_type<T &&, U &&>::type {
34+
return std::less<>()(std::forward<const T>(lhs), std::forward<const U>(rhs))
35+
? std::forward<T>(lhs)
36+
: std::forward<U>(rhs);
37+
}
38+
};
39+
40+
template <typename T = void> struct maximum {
41+
T operator()(const T &lhs, const T &rhs) const {
42+
return std::greater<T>()(lhs, rhs) ? lhs : rhs;
43+
}
44+
};
45+
46+
template <> struct maximum<void> {
47+
struct is_transparent {};
48+
template <typename T, typename U>
49+
auto operator()(T &&lhs, U &&rhs) const ->
50+
typename std::common_type<T &&, U &&>::type {
51+
return std::greater<>()(std::forward<const T>(lhs),
52+
std::forward<const U>(rhs))
53+
? std::forward<T>(lhs)
54+
: std::forward<U>(rhs);
55+
}
56+
};
57+
58+
} // namespace sycl
59+
} // __SYCL_INLINE_NAMESPACE(cl)

sycl/include/CL/sycl/group_algorithm.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <CL/sycl/ONEAPI/functional.hpp>
1414
#include <CL/sycl/detail/spirv.hpp>
1515
#include <CL/sycl/detail/type_traits.hpp>
16+
#include <CL/sycl/functional.hpp>
1617
#include <CL/sycl/group.hpp>
1718
#include <CL/sycl/known_identity.hpp>
1819
#include <CL/sycl/nd_item.hpp>
@@ -86,7 +87,10 @@ template <typename T>
8687
using native_op_list =
8788
type_list<ONEAPI::plus<T>, ONEAPI::bit_or<T>, ONEAPI::bit_xor<T>,
8889
ONEAPI::bit_and<T>, ONEAPI::maximum<T>, ONEAPI::minimum<T>,
89-
ONEAPI::multiplies<T>>;
90+
ONEAPI::multiplies<T>, sycl::plus<T>, sycl::bit_or<T>,
91+
sycl::bit_xor<T>, sycl::bit_and<T>, sycl::maximum<T>,
92+
sycl::minimum<T>, sycl::multiplies<T>, sycl::logical_or<T>,
93+
sycl::logical_and<T>>;
9094

9195
template <typename T, typename BinaryOperation> struct is_native_op {
9296
static constexpr bool value =

sycl/include/CL/sycl/known_identity.hpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ using cl::sycl::detail::is_sgeninteger;
2121

2222
template <typename T, class BinaryOperation>
2323
using IsReduPlus =
24-
bool_constant<std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
25-
std::is_same<BinaryOperation, ONEAPI::plus<void>>::value>;
24+
bool_constant<std::is_same<BinaryOperation, std::plus<T>>::value ||
25+
std::is_same<BinaryOperation, std::plus<void>>::value>;
2626

2727
template <typename T, class BinaryOperation>
2828
using IsReduMultiplies =
@@ -32,27 +32,31 @@ using IsReduMultiplies =
3232
template <typename T, class BinaryOperation>
3333
using IsReduMinimum =
3434
bool_constant<std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
35-
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value>;
35+
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value ||
36+
std::is_same<BinaryOperation, sycl::minimum<T>>::value ||
37+
std::is_same<BinaryOperation, sycl::minimum<void>>::value>;
3638

3739
template <typename T, class BinaryOperation>
3840
using IsReduMaximum =
3941
bool_constant<std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
40-
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value>;
42+
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value ||
43+
std::is_same<BinaryOperation, sycl::maximum<T>>::value ||
44+
std::is_same<BinaryOperation, sycl::maximum<void>>::value>;
4145

4246
template <typename T, class BinaryOperation>
4347
using IsReduBitOR =
44-
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
45-
std::is_same<BinaryOperation, ONEAPI::bit_or<void>>::value>;
48+
bool_constant<std::is_same<BinaryOperation, std::bit_or<T>>::value ||
49+
std::is_same<BinaryOperation, std::bit_or<void>>::value>;
4650

4751
template <typename T, class BinaryOperation>
4852
using IsReduBitXOR =
49-
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
50-
std::is_same<BinaryOperation, ONEAPI::bit_xor<void>>::value>;
53+
bool_constant<std::is_same<BinaryOperation, std::bit_xor<T>>::value ||
54+
std::is_same<BinaryOperation, std::bit_xor<void>>::value>;
5155

5256
template <typename T, class BinaryOperation>
5357
using IsReduBitAND =
54-
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
55-
std::is_same<BinaryOperation, ONEAPI::bit_and<void>>::value>;
58+
bool_constant<std::is_same<BinaryOperation, std::bit_and<T>>::value ||
59+
std::is_same<BinaryOperation, std::bit_and<void>>::value>;
5660

5761
template <typename T, class BinaryOperation>
5862
using IsReduOptForFastAtomicFetch =

sycl/include/CL/sycl/stl.hpp

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -72,48 +72,5 @@ unique_ptr_class<T> make_unique_ptr(ArgsT &&... Args) {
7272
return unique_ptr_class<T>(new T(std::forward<ArgsT>(Args)...));
7373
}
7474

75-
template <typename T = void> using plus = std::plus<T>;
76-
template <typename T = void> using multiplies = std::multiplies<T>;
77-
template <typename T = void> using bit_or = std::bit_or<T>;
78-
template <typename T = void> using bit_xor = std::bit_xor<T>;
79-
template <typename T = void> using bit_and = std::bit_and<T>;
80-
template <typename T = void> using logical_and = std::logical_and<T>;
81-
template <typename T = void> using logical_or = std::logical_or<T>;
82-
83-
template <typename T = void> struct minimum {
84-
T operator()(const T &lhs, const T &rhs) const {
85-
return std::less<T>()(lhs, rhs) ? lhs : rhs;
86-
}
87-
};
88-
89-
template <> struct minimum<void> {
90-
struct is_transparent {};
91-
template <typename T, typename U>
92-
auto operator()(T &&lhs, U &&rhs) const ->
93-
typename std::common_type<T &&, U &&>::type {
94-
return std::less<>()(std::forward<const T>(lhs), std::forward<const U>(rhs))
95-
? std::forward<T>(lhs)
96-
: std::forward<U>(rhs);
97-
}
98-
};
99-
100-
template <typename T = void> struct maximum {
101-
T operator()(const T &lhs, const T &rhs) const {
102-
return std::greater<T>()(lhs, rhs) ? lhs : rhs;
103-
}
104-
};
105-
106-
template <> struct maximum<void> {
107-
struct is_transparent {};
108-
template <typename T, typename U>
109-
auto operator()(T &&lhs, U &&rhs) const ->
110-
typename std::common_type<T &&, U &&>::type {
111-
return std::greater<>()(std::forward<const T>(lhs),
112-
std::forward<const U>(rhs))
113-
? std::forward<T>(lhs)
114-
: std::forward<U>(rhs);
115-
}
116-
};
117-
11875
} // namespace sycl
11976
} // __SYCL_INLINE_NAMESPACE(cl)

0 commit comments

Comments
 (0)