Skip to content

Commit 8fdc68d

Browse files
[SYCL] Add reducer class member aliases and constexpr value (#8211)
This commit adds the `value_type` and `binary_operation` member aliases and the `dimensions` value to the reducer class. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 752e4d3 commit 8fdc68d

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed

sycl/include/sycl/reduction.hpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,14 @@ template <class Reducer> class combiner {
320320
ReduVarPtr, [](auto &&Ref, auto Val) { return Ref.fetch_max(Val); });
321321
}
322322
};
323+
324+
template <typename T, class BinaryOperation, int Dims> class reducer_common {
325+
public:
326+
using value_type = T;
327+
using binary_operation = BinaryOperation;
328+
static constexpr int dimensions = Dims;
329+
};
330+
323331
} // namespace detail
324332

325333
/// Specialization of the generic class 'reducer'. It is used for reductions
@@ -336,7 +344,8 @@ class reducer<
336344
reducer<T, BinaryOperation, Dims, Extent, View,
337345
std::enable_if_t<
338346
Dims == 0 && Extent == 1 && View == false &&
339-
!detail::IsKnownIdentityOp<T, BinaryOperation>::value>>> {
347+
!detail::IsKnownIdentityOp<T, BinaryOperation>::value>>>,
348+
public detail::reducer_common<T, BinaryOperation, Dims> {
340349
public:
341350
reducer(const T &Identity, BinaryOperation BOp)
342351
: MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}
@@ -371,7 +380,8 @@ class reducer<
371380
reducer<T, BinaryOperation, Dims, Extent, View,
372381
std::enable_if_t<
373382
Dims == 0 && Extent == 1 && View == false &&
374-
detail::IsKnownIdentityOp<T, BinaryOperation>::value>>> {
383+
detail::IsKnownIdentityOp<T, BinaryOperation>::value>>>,
384+
public detail::reducer_common<T, BinaryOperation, Dims> {
375385
public:
376386
reducer() : MValue(getIdentity()) {}
377387
reducer(const T & /* Identity */, BinaryOperation) : MValue(getIdentity()) {}
@@ -398,7 +408,8 @@ class reducer<T, BinaryOperation, Dims, Extent, View,
398408
std::enable_if_t<Dims == 0 && View == true>>
399409
: public detail::combiner<
400410
reducer<T, BinaryOperation, Dims, Extent, View,
401-
std::enable_if_t<Dims == 0 && View == true>>> {
411+
std::enable_if_t<Dims == 0 && View == true>>>,
412+
public detail::reducer_common<T, BinaryOperation, Dims> {
402413
public:
403414
reducer(T &Ref, BinaryOperation BOp) : MElement(Ref), MBinaryOp(BOp) {}
404415

@@ -423,7 +434,8 @@ class reducer<
423434
reducer<T, BinaryOperation, Dims, Extent, View,
424435
std::enable_if_t<
425436
Dims == 1 && View == false &&
426-
!detail::IsKnownIdentityOp<T, BinaryOperation>::value>>> {
437+
!detail::IsKnownIdentityOp<T, BinaryOperation>::value>>>,
438+
public detail::reducer_common<T, BinaryOperation, Dims> {
427439
public:
428440
reducer(const T &Identity, BinaryOperation BOp)
429441
: MValue(Identity), MIdentity(Identity), MBinaryOp(BOp) {}
@@ -453,7 +465,8 @@ class reducer<
453465
reducer<T, BinaryOperation, Dims, Extent, View,
454466
std::enable_if_t<
455467
Dims == 1 && View == false &&
456-
detail::IsKnownIdentityOp<T, BinaryOperation>::value>>> {
468+
detail::IsKnownIdentityOp<T, BinaryOperation>::value>>>,
469+
public detail::reducer_common<T, BinaryOperation, Dims> {
457470
public:
458471
reducer() : MValue(getIdentity()) {}
459472
reducer(const T & /* Identity */, BinaryOperation) : MValue(getIdentity()) {}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// RUN: %clangxx -fsycl -fsyntax-only -sycl-std=2020 %s
2+
3+
// Tests the member aliases on the reducer class.
4+
5+
#include <sycl/sycl.hpp>
6+
7+
#include <type_traits>
8+
9+
template <typename T> class Kernel;
10+
11+
template <typename ReducerT, typename ValT, typename BinOp, int Dims>
12+
void CheckReducerAliases() {
13+
static_assert(std::is_same_v<typename ReducerT::value_type, ValT>);
14+
static_assert(std::is_same_v<typename ReducerT::binary_operation, BinOp>);
15+
static_assert(ReducerT::dimensions == Dims);
16+
}
17+
18+
template <typename T> void CheckAllReducers(sycl::queue &Q) {
19+
T *Vals = sycl::malloc_device<T>(4, Q);
20+
sycl::span<T, 4> SpanVal(Vals, 4);
21+
22+
auto CustomOp = [](const T &LHS, const T &RHS) { return LHS + RHS; };
23+
24+
auto ValReduction1 = sycl::reduction(Vals, sycl::plus<>());
25+
auto ValReduction2 = sycl::reduction(Vals, T{}, sycl::plus<>());
26+
auto ValReduction3 = sycl::reduction(Vals, T{}, CustomOp);
27+
auto SpanReduction1 = sycl::reduction(SpanVal, sycl::plus<>());
28+
auto SpanReduction2 = sycl::reduction(SpanVal, T{}, sycl::plus<>());
29+
auto SpanReduction3 = sycl::reduction(SpanVal, T{}, CustomOp);
30+
// TODO: Add cases with identityless reductions when supported.
31+
Q.parallel_for<Kernel<T>>(
32+
sycl::range<1>{10}, ValReduction1, ValReduction2, ValReduction3,
33+
SpanReduction1, SpanReduction2, SpanReduction3,
34+
[=](sycl::id<1>, auto &ValRedu1, auto &ValRedu2, auto &ValRedu3,
35+
auto &SpanRedu1, auto &SpanRedu2, auto &SpanRedu3) {
36+
CheckReducerAliases<std::remove_reference_t<decltype(ValRedu1)>, T,
37+
sycl::plus<>, 0>();
38+
CheckReducerAliases<std::remove_reference_t<decltype(ValRedu2)>, T,
39+
sycl::plus<>, 0>();
40+
CheckReducerAliases<std::remove_reference_t<decltype(ValRedu3)>, T,
41+
decltype(CustomOp), 0>();
42+
CheckReducerAliases<std::remove_reference_t<decltype(SpanRedu1)>, T,
43+
sycl::plus<>, 1>();
44+
CheckReducerAliases<std::remove_reference_t<decltype(SpanRedu2)>, T,
45+
sycl::plus<>, 1>();
46+
CheckReducerAliases<std::remove_reference_t<decltype(SpanRedu3)>, T,
47+
decltype(CustomOp), 1>();
48+
});
49+
}
50+
51+
int main() {
52+
sycl::queue Q;
53+
CheckAllReducers<char>(Q);
54+
CheckAllReducers<short>(Q);
55+
CheckAllReducers<int>(Q);
56+
CheckAllReducers<long>(Q);
57+
CheckAllReducers<float>(Q);
58+
CheckAllReducers<double>(Q);
59+
CheckAllReducers<sycl::half>(Q);
60+
return 0;
61+
}

0 commit comments

Comments
 (0)