Skip to content

Commit 7d6a560

Browse files
committed
Implements necessary sycl utilities for custom reductions
1 parent 51d994a commit 7d6a560

File tree

1 file changed

+265
-0
lines changed

1 file changed

+265
-0
lines changed

dpctl/tensor/libtensor/include/utils/sycl_utils.hpp

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,111 @@
2828
#include <cstddef>
2929
#include <vector>
3030

31+
#include "math_utils.hpp"
32+
3133
namespace dpctl
3234
{
3335
namespace tensor
3436
{
3537
namespace sycl_utils
3638
{
39+
namespace detail
40+
{
41+
42+
template <typename...> struct TypeList;
43+
44+
template <typename Head, typename... Tail> struct TypeList<Head, Tail...>
45+
{
46+
using head = Head;
47+
using tail = TypeList<Tail...>;
48+
};
49+
50+
using NullTypeList = TypeList<>;
51+
template <typename T>
52+
struct IsNullTypeList : std::conditional_t<std::is_same_v<T, NullTypeList>,
53+
std::true_type,
54+
std::false_type>
55+
{
56+
};
57+
58+
// recursively check if type is contained in given TypeList
59+
template <typename T, typename TList>
60+
struct IsContained
61+
: std::conditional_t<
62+
std::is_same_v<typename TList::head, std::remove_cv_t<T>>,
63+
std::true_type,
64+
IsContained<T, typename TList::tail>>
65+
{
66+
};
67+
68+
template <> struct TypeList<>
69+
{
70+
};
71+
72+
// std::false_type when last case has been checked for membership
73+
template <typename T> struct IsContained<T, NullTypeList> : std::false_type
74+
{
75+
};
76+
77+
template <class T> struct IsComplex : std::false_type
78+
{
79+
};
80+
template <class T> struct IsComplex<std::complex<T>> : std::true_type
81+
{
82+
};
83+
84+
} // namespace detail
85+
86+
template <typename T>
87+
using sycl_ops = detail::TypeList<sycl::plus<T>,
88+
sycl::bit_or<T>,
89+
sycl::bit_xor<T>,
90+
sycl::bit_and<T>,
91+
sycl::maximum<T>,
92+
sycl::minimum<T>,
93+
sycl::multiplies<T>>;
94+
95+
template <typename T, typename Op> struct IsSyclOp
96+
{
97+
static constexpr bool value =
98+
detail::IsContained<Op, sycl_ops<std::remove_const_t<T>>>::value ||
99+
detail::IsContained<Op, sycl_ops<std::add_const_t<T>>>::value ||
100+
detail::IsContained<Op, sycl_ops<void>>::value;
101+
};
102+
103+
struct AtomicSupport
104+
{
105+
bool operator()(const sycl::queue &exec_q,
106+
sycl::usm::alloc usm_alloc_type,
107+
bool require_atomic64 = false) const
108+
{
109+
bool supports_atomics = false;
110+
111+
const sycl::device &dev = exec_q.get_device();
112+
if (require_atomic64) {
113+
if (!dev.has(sycl::aspect::atomic64))
114+
return false;
115+
}
116+
117+
switch (usm_alloc_type) {
118+
case sycl::usm::alloc::shared:
119+
supports_atomics =
120+
dev.has(sycl::aspect::usm_atomic_shared_allocations);
121+
break;
122+
case sycl::usm::alloc::host:
123+
supports_atomics =
124+
dev.has(sycl::aspect::usm_atomic_host_allocations);
125+
break;
126+
case sycl::usm::alloc::device:
127+
supports_atomics = true;
128+
break;
129+
default:
130+
supports_atomics = false;
131+
}
132+
133+
return supports_atomics;
134+
}
135+
};
37136

38137
/*! @brief Find the smallest multiple of supported sub-group size larger than
39138
* nelems */
@@ -66,6 +165,172 @@ size_t choose_workgroup_size(const size_t nelems,
66165
return wg;
67166
}
68167

168+
template <typename T, typename GroupT, typename LocAccT, typename OpT>
169+
T custom_reduce_over_group(GroupT wg,
170+
LocAccT local_mem_acc,
171+
T local_val,
172+
OpT op)
173+
{
174+
size_t wgs = wg.get_local_linear_range();
175+
local_mem_acc[wg.get_local_linear_id()] = local_val;
176+
177+
sycl::group_barrier(wg, sycl::memory_scope::work_group);
178+
179+
T red_val_over_wg = local_mem_acc[0];
180+
if (wg.leader()) {
181+
for (size_t i = 1; i < wgs; ++i) {
182+
red_val_over_wg = op(red_val_over_wg, local_mem_acc[i]);
183+
}
184+
}
185+
186+
sycl::group_barrier(wg, sycl::memory_scope::work_group);
187+
188+
return sycl::group_broadcast(wg, red_val_over_wg);
189+
}
190+
191+
// Reduction functors
192+
193+
// Maximum
194+
195+
template <typename T> struct Maximum
196+
{
197+
T operator()(const T &x, const T &y) const
198+
{
199+
if constexpr (detail::IsComplex<T>::value) {
200+
using dpctl::tensor::math_utils::max_complex;
201+
return max_complex<T>(x, y);
202+
}
203+
else if constexpr (std::is_floating_point_v<T> ||
204+
std::is_same_v<T, sycl::half>) {
205+
return (std::isnan(x) || x > y) ? x : y;
206+
}
207+
else if constexpr (std::is_same_v<T, bool>) {
208+
return x || y;
209+
}
210+
else {
211+
return (x > y) ? x : y;
212+
}
213+
}
214+
};
215+
216+
// Minimum
217+
218+
template <typename T> struct Minimum
219+
{
220+
T operator()(const T &x, const T &y) const
221+
{
222+
if constexpr (detail::IsComplex<T>::value) {
223+
using dpctl::tensor::math_utils::min_complex;
224+
return min_complex<T>(x, y);
225+
}
226+
else if constexpr (std::is_floating_point_v<T> ||
227+
std::is_same_v<T, sycl::half>) {
228+
return (std::isnan(x) || x < y) ? x : y;
229+
}
230+
else if constexpr (std::is_same_v<T, bool>) {
231+
return x && y;
232+
}
233+
else {
234+
return (x < y) ? x : y;
235+
}
236+
}
237+
};
238+
239+
// Define identities and operator checking structs
240+
241+
template <typename Op, typename T, typename = void> struct GetIdentity
242+
{
243+
};
244+
245+
// Maximum
246+
247+
template <typename T, class Op>
248+
using IsMaximum = std::bool_constant<std::is_same_v<Op, sycl::maximum<T>> ||
249+
std::is_same_v<Op, sycl::maximum<void>> ||
250+
std::is_same_v<Op, Maximum<T>> ||
251+
std::is_same_v<Op, Maximum<void>>>;
252+
253+
template <typename Op, typename T>
254+
struct GetIdentity<Op, T, std::enable_if_t<IsMaximum<T, Op>::value>>
255+
{
256+
static constexpr T value =
257+
static_cast<T>(std::numeric_limits<T>::has_infinity
258+
? static_cast<T>(-std::numeric_limits<T>::infinity())
259+
: std::numeric_limits<T>::lowest());
260+
};
261+
262+
template <typename Op>
263+
struct GetIdentity<Op, bool, std::enable_if_t<IsMaximum<bool, Op>::value>>
264+
{
265+
static constexpr bool value = false;
266+
};
267+
268+
template <typename Op, typename T>
269+
struct GetIdentity<Op,
270+
std::complex<T>,
271+
std::enable_if_t<IsMaximum<std::complex<T>, Op>::value>>
272+
{
273+
static constexpr std::complex<T> value{-std::numeric_limits<T>::infinity(),
274+
-std::numeric_limits<T>::infinity()};
275+
};
276+
277+
// Minimum
278+
279+
template <typename T, class Op>
280+
using IsMinimum = std::bool_constant<std::is_same_v<Op, sycl::minimum<T>> ||
281+
std::is_same_v<Op, sycl::minimum<void>> ||
282+
std::is_same_v<Op, Minimum<T>> ||
283+
std::is_same_v<Op, Minimum<void>>>;
284+
285+
template <typename Op, typename T>
286+
struct GetIdentity<Op, T, std::enable_if_t<IsMinimum<T, Op>::value>>
287+
{
288+
static constexpr T value =
289+
static_cast<T>(std::numeric_limits<T>::has_infinity
290+
? static_cast<T>(std::numeric_limits<T>::infinity())
291+
: std::numeric_limits<T>::max());
292+
};
293+
294+
template <typename Op>
295+
struct GetIdentity<Op, bool, std::enable_if_t<IsMinimum<bool, Op>::value>>
296+
{
297+
static constexpr bool value = true;
298+
};
299+
300+
template <typename Op, typename T>
301+
struct GetIdentity<Op,
302+
std::complex<T>,
303+
std::enable_if_t<IsMinimum<std::complex<T>, Op>::value>>
304+
{
305+
static constexpr std::complex<T> value{std::numeric_limits<T>::infinity(),
306+
std::numeric_limits<T>::infinity()};
307+
};
308+
309+
// Plus
310+
311+
template <typename T, class Op>
312+
using IsPlus = std::bool_constant<
313+
std::is_same_v<Op, sycl::plus<T>> || std::is_same_v<Op, sycl::plus<void>> ||
314+
std::is_same_v<Op, std::plus<T>> || std::is_same_v<Op, std::plus<T>>>;
315+
316+
// Identity
317+
318+
template <typename Op, typename T, typename = void> struct Identity
319+
{
320+
};
321+
322+
template <typename Op, typename T>
323+
struct Identity<Op, T, std::enable_if_t<!IsSyclOp<T, Op>::value>>
324+
{
325+
static constexpr T value = GetIdentity<Op, T>::value;
326+
};
327+
328+
template <typename Op, typename T>
329+
struct Identity<Op, T, std::enable_if_t<IsSyclOp<T, Op>::value>>
330+
{
331+
static constexpr T value = sycl::known_identity<Op, T>::value;
332+
};
333+
69334
} // namespace sycl_utils
70335
} // namespace tensor
71336
} // namespace dpctl

0 commit comments

Comments
 (0)