@@ -16,46 +16,13 @@ namespace sycl {
16
16
namespace ext {
17
17
namespace oneapi {
18
18
19
- template <typename T = void > struct minimum {
20
- T operator ()(const T &lhs, const T &rhs) const {
21
- return std::less<T>()(lhs, rhs) ? lhs : rhs;
22
- }
23
- };
24
-
25
- template <> struct minimum <void > {
26
- struct is_transparent {};
27
- template <typename T, typename U>
28
- auto operator ()(T &&lhs, U &&rhs) const ->
29
- typename std::common_type<T &&, U &&>::type {
30
- return std::less<>()(std::forward<const T>(lhs), std::forward<const U>(rhs))
31
- ? std::forward<T>(lhs)
32
- : std::forward<U>(rhs);
33
- }
34
- };
35
-
36
- template <typename T = void > struct maximum {
37
- T operator ()(const T &lhs, const T &rhs) const {
38
- return std::greater<T>()(lhs, rhs) ? lhs : rhs;
39
- }
40
- };
41
-
42
- template <> struct maximum <void > {
43
- struct is_transparent {};
44
- template <typename T, typename U>
45
- auto operator ()(T &&lhs, U &&rhs) const ->
46
- typename std::common_type<T &&, U &&>::type {
47
- return std::greater<>()(std::forward<const T>(lhs),
48
- std::forward<const U>(rhs))
49
- ? std::forward<T>(lhs)
50
- : std::forward<U>(rhs);
51
- }
52
- };
53
-
54
19
template <typename T = void > using plus = std::plus<T>;
55
20
template <typename T = void > using multiplies = std::multiplies<T>;
56
21
template <typename T = void > using bit_or = std::bit_or<T>;
57
22
template <typename T = void > using bit_xor = std::bit_xor<T>;
58
23
template <typename T = void > using bit_and = std::bit_and<T>;
24
+ template <typename T = void > using maximum = sycl::maximum<T>;
25
+ template <typename T = void > using minimum = sycl::minimum<T>;
59
26
60
27
} // namespace oneapi
61
28
} // namespace ext
@@ -106,41 +73,29 @@ struct GroupOpTag<T, detail::enable_if_t<detail::is_sgenfloat<T>::value>> {
106
73
return Ret; \
107
74
}
108
75
109
- // calc for sycl minimum/maximum function objects
76
+ // calc for sycl function objects
110
77
__SYCL_CALC_OVERLOAD (GroupOpISigned, SMin, sycl::minimum<T>)
111
78
__SYCL_CALC_OVERLOAD (GroupOpIUnsigned, UMin, sycl::minimum<T>)
112
79
__SYCL_CALC_OVERLOAD (GroupOpFP, FMin, sycl::minimum<T>)
80
+
113
81
__SYCL_CALC_OVERLOAD (GroupOpISigned, SMax, sycl::maximum<T>)
114
82
__SYCL_CALC_OVERLOAD (GroupOpIUnsigned, UMax, sycl::maximum<T>)
115
83
__SYCL_CALC_OVERLOAD (GroupOpFP, FMax, sycl::maximum<T>)
116
84
117
- // calc for oneapi function objects
118
- __SYCL_CALC_OVERLOAD (GroupOpISigned, SMin, ext::oneapi::minimum<T>)
119
- __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, UMin, ext::oneapi::minimum<T>)
120
- __SYCL_CALC_OVERLOAD (GroupOpFP, FMin, ext::oneapi::minimum<T>)
121
- __SYCL_CALC_OVERLOAD (GroupOpISigned, SMax, ext::oneapi::maximum<T>)
122
- __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, UMax, ext::oneapi::maximum<T>)
123
- __SYCL_CALC_OVERLOAD (GroupOpFP, FMax, ext::oneapi::maximum<T>)
124
- __SYCL_CALC_OVERLOAD (GroupOpISigned, IAdd, ext::oneapi::plus<T>)
125
- __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, IAdd, ext::oneapi::plus<T>)
126
- __SYCL_CALC_OVERLOAD (GroupOpFP, FAdd, ext::oneapi::plus<T>)
127
-
128
- __SYCL_CALC_OVERLOAD (GroupOpISigned, NonUniformIMul, ext::oneapi::multiplies<T>)
129
- __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, NonUniformIMul,
130
- ext::oneapi::multiplies<T>)
131
- __SYCL_CALC_OVERLOAD (GroupOpFP, NonUniformFMul, ext::oneapi::multiplies<T>)
132
- __SYCL_CALC_OVERLOAD (GroupOpISigned, NonUniformBitwiseOr,
133
- ext::oneapi::bit_or<T>)
134
- __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, NonUniformBitwiseOr,
135
- ext::oneapi::bit_or<T>)
136
- __SYCL_CALC_OVERLOAD (GroupOpISigned, NonUniformBitwiseXor,
137
- ext::oneapi::bit_xor<T>)
138
- __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, NonUniformBitwiseXor,
139
- ext::oneapi::bit_xor<T>)
140
- __SYCL_CALC_OVERLOAD (GroupOpISigned, NonUniformBitwiseAnd,
141
- ext::oneapi::bit_and<T>)
142
- __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, NonUniformBitwiseAnd,
143
- ext::oneapi::bit_and<T>)
85
+ __SYCL_CALC_OVERLOAD (GroupOpISigned, IAdd, sycl::plus<T>)
86
+ __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, IAdd, sycl::plus<T>)
87
+ __SYCL_CALC_OVERLOAD (GroupOpFP, FAdd, sycl::plus<T>)
88
+
89
+ __SYCL_CALC_OVERLOAD (GroupOpISigned, NonUniformIMul, sycl::multiplies<T>)
90
+ __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, NonUniformIMul, sycl::multiplies<T>)
91
+ __SYCL_CALC_OVERLOAD (GroupOpFP, NonUniformFMul, sycl::multiplies<T>)
92
+
93
+ __SYCL_CALC_OVERLOAD (GroupOpISigned, NonUniformBitwiseOr, sycl::bit_or<T>)
94
+ __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, NonUniformBitwiseOr, sycl::bit_or<T>)
95
+ __SYCL_CALC_OVERLOAD (GroupOpISigned, NonUniformBitwiseXor, sycl::bit_xor<T>)
96
+ __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, NonUniformBitwiseXor, sycl::bit_xor<T>)
97
+ __SYCL_CALC_OVERLOAD (GroupOpISigned, NonUniformBitwiseAnd, sycl::bit_and<T>)
98
+ __SYCL_CALC_OVERLOAD (GroupOpIUnsigned, NonUniformBitwiseAnd, sycl::bit_and<T>)
144
99
145
100
#undef __SYCL_CALC_OVERLOAD
146
101
0 commit comments