Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 249e381

Browse files
[SYCL] Add tests for user-defined reductions extension (#1395)
Spec: intel/llvm#7202 Implementation: intel/llvm#7587
1 parent 1dd04ee commit 249e381

File tree

1 file changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
4+
#include <algorithm>
#include <array>
#include <cassert>
#include <complex>
#include <cstddef>
#include <limits>
#include <numeric>
#include <ostream>
#include <utility>

#include <sycl/ext/oneapi/experimental/user_defined_reductions.hpp>
#include <sycl/sycl.hpp>
8+
9+
/// User-defined binary functor returning `a + b`.
///
/// The call operator is const-qualified (as in UserDefinedBitAnd and
/// UserDefinedMultiplies) so the functor can also be invoked through a const
/// object or a const reference.
template <typename T> struct UserDefinedSum {
  constexpr T operator()(T a, T b) const { return a + b; }
};
12+
13+
/// User-defined binary functor returning the larger of its two arguments.
///
/// Uses only `operator<` on T; when the arguments compare equal the first one
/// is returned. The call operator is const-qualified so the functor can be
/// invoked through a const object or a const reference.
template <typename T> struct UserDefinedMax {
  constexpr T operator()(T a, T b) const { return (a < b) ? b : a; }
};
16+
17+
/// User-defined binary functor returning the bitwise AND `a & b`.
template <typename T> struct UserDefinedBitAnd {
  constexpr T operator()(const T &a, const T &b) const { return a & b; }
};
20+
21+
/// User-defined binary functor returning the product `a * b`.
template <typename T> struct UserDefinedMultiplies {
  constexpr T operator()(const T &a, const T &b) const { return a * b; }
};
24+
25+
/// Small aggregate-like payload (one int, one float) that is embedded in the
/// custom reduction types below.
struct custom_type_nested {
  // Values the members take when the object is default-constructed.
  static constexpr int default_i_value = 6;
  static constexpr float default_f_value = 1.5f;

  constexpr custom_type_nested() = default;
  constexpr custom_type_nested(int i, float f) : i(i), f(f) {}

  int i = default_i_value;
  float f = default_f_value;
};
35+
36+
inline bool operator==(const custom_type_nested &lhs,
37+
const custom_type_nested &rhs) {
38+
return lhs.i == rhs.i && lhs.f == rhs.f;
39+
}
40+
41+
inline std::ostream &operator<<(std::ostream &out,
42+
const custom_type_nested &v) {
43+
return out << "custom_type_nested { .i = " << v.i << ", .f = " << v.f << "}";
44+
}
45+
46+
struct custom_type {
47+
static constexpr unsigned long long default_ull_value = 42;
48+
49+
constexpr custom_type() = default;
50+
constexpr custom_type(int i, float f, unsigned long long ull)
51+
: n(i, f), ull(ull) {}
52+
53+
custom_type_nested n;
54+
unsigned long long ull = default_ull_value;
55+
};
56+
57+
/// Member-wise equality for custom_type: both the nested part and `ull` must
/// match.
inline bool operator==(const custom_type &lhs, const custom_type &rhs) {
  return lhs.n == rhs.n && lhs.ull == rhs.ull;
}
60+
61+
/// Member-wise addition for custom_type: each of `n.i`, `n.f` and `ull` is
/// summed independently and a new value is returned.
inline custom_type operator+(const custom_type &lhs, const custom_type &rhs) {
  return custom_type(lhs.n.i + rhs.n.i, lhs.n.f + rhs.n.f, lhs.ull + rhs.ull);
}
64+
65+
struct custom_type_wo_default_ctor {
66+
static constexpr unsigned long long default_ull_value = 42;
67+
68+
constexpr custom_type_wo_default_ctor() = delete;
69+
constexpr custom_type_wo_default_ctor(int i, float f, unsigned long long ull)
70+
: n(i, f), ull(ull) {}
71+
72+
custom_type_nested n;
73+
unsigned long long ull = default_ull_value;
74+
};
75+
76+
inline bool operator==(const custom_type_wo_default_ctor &lhs,
77+
const custom_type_wo_default_ctor &rhs) {
78+
return lhs.n == rhs.n && lhs.ull == rhs.ull;
79+
}
80+
81+
/// Member-wise addition for custom_type_wo_default_ctor: each of `n.i`, `n.f`
/// and `ull` is summed independently and a new value is returned.
inline custom_type_wo_default_ctor
operator+(const custom_type_wo_default_ctor &lhs,
          const custom_type_wo_default_ctor &rhs) {
  return custom_type_wo_default_ctor(lhs.n.i + rhs.n.i, lhs.n.f + rhs.n.f,
                                     lhs.ull + rhs.ull);
}
87+
88+
/// Builds a std::array whose every element is a copy of `value`.
///
/// Unlike `std::array<T, N>{}` followed by `fill`, this never
/// default-constructs T, so it also works for types without a default
/// constructor. The array length is the length of the index pack.
///
/// \param value element copied into every slot
/// \returns std::array<T, sizeof...(Is)> filled with copies of `value`
template <typename T, std::size_t... Is>
constexpr std::array<T, sizeof...(Is)> init_array(T value,
                                                  std::index_sequence<Is...>) {
  // Each index is evaluated and discarded; it exists only to repeat `value`
  // once per pack element.
  return {{(static_cast<void>(Is), value)...}};
}
93+
94+
using namespace sycl;
95+
96+
template <typename InputContainer, typename OutputContainer,
97+
class BinaryOperation>
98+
void test(queue q, InputContainer input, OutputContainer output,
99+
BinaryOperation binary_op, size_t workgroup_size,
100+
typename OutputContainer::value_type identity,
101+
typename OutputContainer::value_type init) {
102+
using InputT = typename InputContainer::value_type;
103+
using OutputT = typename OutputContainer::value_type;
104+
constexpr size_t N = input.size();
105+
{
106+
buffer<InputT> in_buf(input.data(), input.size());
107+
buffer<OutputT> out_buf(output.data(), output.size());
108+
109+
q.submit([&](handler &cgh) {
110+
accessor in{in_buf, cgh, sycl::read_only};
111+
accessor out{out_buf, cgh, sycl::write_only, sycl::no_init};
112+
113+
size_t temp_memory_size = workgroup_size * sizeof(InputT);
114+
auto scratch = sycl::local_accessor<std::byte, 1>(temp_memory_size, cgh);
115+
cgh.parallel_for(
116+
nd_range<1>(workgroup_size, workgroup_size), [=](nd_item<1> it) {
117+
// Create a handle that associates the group with an allocation it
118+
// can use
119+
auto handle =
120+
sycl::ext::oneapi::experimental::group_with_scratchpad(
121+
it.get_group(), sycl::span(&scratch[0], temp_memory_size));
122+
123+
InputT *first = in.get_pointer();
124+
InputT *last = first + N;
125+
// check reduce_over_group w/o init
126+
out[0] = sycl::ext::oneapi::experimental::reduce_over_group(
127+
handle, in[it.get_global_id(0)], binary_op);
128+
129+
// check reduce_over_group with init
130+
out[1] = sycl::ext::oneapi::experimental::reduce_over_group(
131+
handle, in[it.get_global_id(0)], init, binary_op);
132+
133+
// check joint_reduce w/o init
134+
out[2] = sycl::ext::oneapi::experimental::joint_reduce(
135+
handle, first, last, binary_op);
136+
137+
// check joint_reduce with init
138+
out[3] = sycl::ext::oneapi::experimental::joint_reduce(
139+
handle, first, last, init, binary_op);
140+
});
141+
});
142+
q.wait();
143+
}
144+
assert(output[0] == std::reduce(input.begin(), input.begin() + workgroup_size,
145+
identity, binary_op));
146+
assert(output[1] == std::reduce(input.begin(), input.begin() + workgroup_size,
147+
init, binary_op));
148+
assert(output[2] ==
149+
std::reduce(input.begin(), input.end(), identity, binary_op));
150+
assert(output[3] == std::reduce(input.begin(), input.end(), init, binary_op));
151+
}
152+
153+
int main() {
154+
queue q;
155+
156+
constexpr int N = 128;
157+
std::array<int, N> input;
158+
std::array<int, 4> output;
159+
std::iota(input.begin(), input.end(), 0);
160+
std::fill(output.begin(), output.end(), 0);
161+
162+
// queue, input array, output array, binary_op, WG size, identity, init
163+
test(q, input, output, UserDefinedSum<int>{}, 64, 0, 42);
164+
test(q, input, output, UserDefinedSum<int>{}, 32, 0, 42);
165+
test(q, input, output, UserDefinedSum<int>{}, 5, 0, 42);
166+
test(q, input, output, UserDefinedMax<int>{}, 64,
167+
std::numeric_limits<int>::lowest(), 42);
168+
test(q, input, output, UserDefinedMultiplies<int>(), 64, 1, 42);
169+
test(q, input, output, UserDefinedBitAnd<int>{}, 64, ~0, 42);
170+
171+
test(q, input, output, sycl::plus<int>(), 64, 0, 42);
172+
test(q, input, output, sycl::maximum<>(), 64,
173+
std::numeric_limits<int>::lowest(), 42);
174+
test(q, input, output, sycl::minimum<int>(), 64,
175+
std::numeric_limits<int>::max(), 42);
176+
177+
test(q, input, output, sycl::multiplies<int>(), 64, 1, 42);
178+
test(q, input, output, sycl::bit_or<int>(), 64, 0, 42);
179+
test(q, input, output, sycl::bit_xor<int>(), 64, 0, 42);
180+
test(q, input, output, sycl::bit_and<int>(), 64, ~0, 42);
181+
182+
std::array<custom_type, N> input_custom;
183+
std::array<custom_type, 4> output_custom;
184+
test(q, input_custom, output_custom, UserDefinedSum<custom_type>{}, 64,
185+
custom_type(0, 0., 0), custom_type(10, 0., 5));
186+
187+
custom_type_wo_default_ctor value(1, 2.5, 3);
188+
std::array<custom_type_wo_default_ctor, N> input_custom_wo_default_ctor =
189+
init_array(value, std::make_index_sequence<N>());
190+
std::array<custom_type_wo_default_ctor, 4> output_custom_wo_default_ctor =
191+
init_array(value, std::make_index_sequence<4>());
192+
test(q, input_custom_wo_default_ctor, output_custom_wo_default_ctor,
193+
UserDefinedSum<custom_type_wo_default_ctor>{}, 64,
194+
custom_type_wo_default_ctor(0, 0., 0),
195+
custom_type_wo_default_ctor(10, 0., 5));
196+
197+
#ifdef SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS
198+
std::array<std::complex<float>, N> input_cf;
199+
std::array<std::complex<float>, 4> output_cf;
200+
std::iota(input_cf.begin(), input_cf.end(), 0);
201+
std::fill(output_cf.begin(), output_cf.end(), 0);
202+
test(q, input_cf, output_cf, sycl::plus<std::complex<float>>(), 64, 0, 42);
203+
test(q, input_cf, output_cf, sycl::plus<>(), 64, 0, 42);
204+
#else
205+
static_assert(false, "SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS not defined");
206+
#endif
207+
208+
return 0;
209+
}

0 commit comments

Comments
 (0)