Skip to content

Commit ce546bd

Browse files
Merge pull request intel#1460 from dm-vodopyanov/private/dvodopya/add-user-defined-reductions-test
[SYCL] Add tests for user-defined reductions extension
2 parents ede4924 + 262e3c6 commit ce546bd

File tree

1 file changed

+211
-0
lines changed

1 file changed

+211
-0
lines changed
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
//
4+
// UNSUPPORTED: cuda || hip
5+
6+
#include <numeric>
7+
8+
#include <sycl/ext/oneapi/experimental/user_defined_reductions.hpp>
9+
#include <sycl/sycl.hpp>
10+
11+
template <typename T> struct UserDefinedSum {
12+
T operator()(T a, T b) { return a + b; }
13+
};
14+
15+
template <typename T> struct UserDefinedMax {
16+
T operator()(T a, T b) { return (a < b) ? b : a; }
17+
};
18+
19+
template <typename T> struct UserDefinedBitAnd {
20+
T operator()(const T &a, const T &b) const { return a & b; }
21+
};
22+
23+
template <typename T> struct UserDefinedMultiplies {
24+
T operator()(const T &a, const T &b) const { return a * b; }
25+
};
26+
27+
struct custom_type_nested {
28+
static constexpr int default_i_value = 6;
29+
static constexpr float default_f_value = 1.5;
30+
31+
constexpr custom_type_nested() = default;
32+
constexpr custom_type_nested(int i, float f) : i(i), f(f) {}
33+
34+
int i = default_i_value;
35+
float f = default_f_value;
36+
};
37+
38+
inline bool operator==(const custom_type_nested &lhs,
39+
const custom_type_nested &rhs) {
40+
return lhs.i == rhs.i && lhs.f == rhs.f;
41+
}
42+
43+
inline std::ostream &operator<<(std::ostream &out,
44+
const custom_type_nested &v) {
45+
return out << "custom_type_nested { .i = " << v.i << ", .f = " << v.f << "}";
46+
}
47+
48+
struct custom_type {
49+
static constexpr unsigned long long default_ull_value = 42;
50+
51+
constexpr custom_type() = default;
52+
constexpr custom_type(int i, float f, unsigned long long ull)
53+
: n(i, f), ull(ull) {}
54+
55+
custom_type_nested n;
56+
unsigned long long ull = default_ull_value;
57+
};
58+
59+
inline bool operator==(const custom_type &lhs, const custom_type &rhs) {
60+
return lhs.n == rhs.n && lhs.ull == rhs.ull;
61+
}
62+
63+
inline custom_type operator+(const custom_type &lhs, const custom_type &rhs) {
64+
return custom_type(lhs.n.i + rhs.n.i, lhs.n.f + rhs.n.f, lhs.ull + rhs.ull);
65+
}
66+
67+
struct custom_type_wo_default_ctor {
68+
static constexpr unsigned long long default_ull_value = 42;
69+
70+
constexpr custom_type_wo_default_ctor() = delete;
71+
constexpr custom_type_wo_default_ctor(int i, float f, unsigned long long ull)
72+
: n(i, f), ull(ull) {}
73+
74+
custom_type_nested n;
75+
unsigned long long ull = default_ull_value;
76+
};
77+
78+
inline bool operator==(const custom_type_wo_default_ctor &lhs,
79+
const custom_type_wo_default_ctor &rhs) {
80+
return lhs.n == rhs.n && lhs.ull == rhs.ull;
81+
}
82+
83+
inline custom_type_wo_default_ctor
84+
operator+(const custom_type_wo_default_ctor &lhs,
85+
const custom_type_wo_default_ctor &rhs) {
86+
return custom_type_wo_default_ctor(lhs.n.i + rhs.n.i, lhs.n.f + rhs.n.f,
87+
lhs.ull + rhs.ull);
88+
}
89+
90+
template <typename T, std::size_t... Is>
91+
constexpr std::array<T, sizeof...(Is)> init_array(T value,
92+
std::index_sequence<Is...>) {
93+
return {{(static_cast<void>(Is), value)...}};
94+
}
95+
96+
using namespace sycl;
97+
98+
template <typename InputContainer, typename OutputContainer,
99+
class BinaryOperation>
100+
void test(queue q, InputContainer input, OutputContainer output,
101+
BinaryOperation binary_op, size_t workgroup_size,
102+
typename OutputContainer::value_type identity,
103+
typename OutputContainer::value_type init) {
104+
using InputT = typename InputContainer::value_type;
105+
using OutputT = typename OutputContainer::value_type;
106+
constexpr size_t N = input.size();
107+
{
108+
buffer<InputT> in_buf(input.data(), input.size());
109+
buffer<OutputT> out_buf(output.data(), output.size());
110+
111+
q.submit([&](handler &cgh) {
112+
accessor in{in_buf, cgh, sycl::read_only};
113+
accessor out{out_buf, cgh, sycl::write_only, sycl::no_init};
114+
115+
size_t temp_memory_size = workgroup_size * sizeof(InputT);
116+
auto scratch = sycl::local_accessor<std::byte, 1>(temp_memory_size, cgh);
117+
cgh.parallel_for(
118+
nd_range<1>(workgroup_size, workgroup_size), [=](nd_item<1> it) {
119+
// Create a handle that associates the group with an allocation it
120+
// can use
121+
auto handle =
122+
sycl::ext::oneapi::experimental::group_with_scratchpad(
123+
it.get_group(), sycl::span(&scratch[0], temp_memory_size));
124+
125+
InputT *first = in.get_pointer();
126+
InputT *last = first + N;
127+
// check reduce_over_group w/o init
128+
out[0] = sycl::ext::oneapi::experimental::reduce_over_group(
129+
handle, in[it.get_global_id(0)], binary_op);
130+
131+
// check reduce_over_group with init
132+
out[1] = sycl::ext::oneapi::experimental::reduce_over_group(
133+
handle, in[it.get_global_id(0)], init, binary_op);
134+
135+
// check joint_reduce w/o init
136+
out[2] = sycl::ext::oneapi::experimental::joint_reduce(
137+
handle, first, last, binary_op);
138+
139+
// check joint_reduce with init
140+
out[3] = sycl::ext::oneapi::experimental::joint_reduce(
141+
handle, first, last, init, binary_op);
142+
});
143+
});
144+
q.wait();
145+
}
146+
assert(output[0] == std::reduce(input.begin(), input.begin() + workgroup_size,
147+
identity, binary_op));
148+
assert(output[1] == std::reduce(input.begin(), input.begin() + workgroup_size,
149+
init, binary_op));
150+
assert(output[2] ==
151+
std::reduce(input.begin(), input.end(), identity, binary_op));
152+
assert(output[3] == std::reduce(input.begin(), input.end(), init, binary_op));
153+
}
154+
155+
int main() {
156+
queue q;
157+
158+
constexpr int N = 128;
159+
std::array<int, N> input;
160+
std::array<int, 4> output;
161+
std::iota(input.begin(), input.end(), 0);
162+
std::fill(output.begin(), output.end(), 0);
163+
164+
// queue, input array, output array, binary_op, WG size, identity, init
165+
test(q, input, output, UserDefinedSum<int>{}, 64, 0, 42);
166+
test(q, input, output, UserDefinedSum<int>{}, 32, 0, 42);
167+
test(q, input, output, UserDefinedSum<int>{}, 5, 0, 42);
168+
test(q, input, output, UserDefinedMax<int>{}, 64,
169+
std::numeric_limits<int>::lowest(), 42);
170+
test(q, input, output, UserDefinedMultiplies<int>(), 64, 1, 42);
171+
test(q, input, output, UserDefinedBitAnd<int>{}, 64, ~0, 42);
172+
173+
test(q, input, output, sycl::plus<int>(), 64, 0, 42);
174+
test(q, input, output, sycl::maximum<>(), 64,
175+
std::numeric_limits<int>::lowest(), 42);
176+
test(q, input, output, sycl::minimum<int>(), 64,
177+
std::numeric_limits<int>::max(), 42);
178+
179+
test(q, input, output, sycl::multiplies<int>(), 64, 1, 42);
180+
test(q, input, output, sycl::bit_or<int>(), 64, 0, 42);
181+
test(q, input, output, sycl::bit_xor<int>(), 64, 0, 42);
182+
test(q, input, output, sycl::bit_and<int>(), 64, ~0, 42);
183+
184+
std::array<custom_type, N> input_custom;
185+
std::array<custom_type, 4> output_custom;
186+
test(q, input_custom, output_custom, UserDefinedSum<custom_type>{}, 64,
187+
custom_type(0, 0., 0), custom_type(10, 0., 5));
188+
189+
custom_type_wo_default_ctor value(1, 2.5, 3);
190+
std::array<custom_type_wo_default_ctor, N> input_custom_wo_default_ctor =
191+
init_array(value, std::make_index_sequence<N>());
192+
std::array<custom_type_wo_default_ctor, 4> output_custom_wo_default_ctor =
193+
init_array(value, std::make_index_sequence<4>());
194+
test(q, input_custom_wo_default_ctor, output_custom_wo_default_ctor,
195+
UserDefinedSum<custom_type_wo_default_ctor>{}, 64,
196+
custom_type_wo_default_ctor(0, 0., 0),
197+
custom_type_wo_default_ctor(10, 0., 5));
198+
199+
#ifdef SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS
200+
std::array<std::complex<float>, N> input_cf;
201+
std::array<std::complex<float>, 4> output_cf;
202+
std::iota(input_cf.begin(), input_cf.end(), 0);
203+
std::fill(output_cf.begin(), output_cf.end(), 0);
204+
test(q, input_cf, output_cf, sycl::plus<std::complex<float>>(), 64, 0, 42);
205+
test(q, input_cf, output_cf, sycl::plus<>(), 64, 0, 42);
206+
#else
207+
static_assert(false, "SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS not defined");
208+
#endif
209+
210+
return 0;
211+
}

0 commit comments

Comments
 (0)