|
| 1 | +// RUN: %clangxx -fsycl %s -o %t.out |
| 2 | +// RUN: %t.out |
| 3 | + |
| 4 | +#include <numeric> |
| 5 | + |
| 6 | +#include <sycl/ext/oneapi/experimental/user_defined_reductions.hpp> |
| 7 | +#include <sycl/sycl.hpp> |
| 8 | + |
| 9 | +template <typename T> struct UserDefinedSum { |
| 10 | + T operator()(T a, T b) { return a + b; } |
| 11 | +}; |
| 12 | + |
| 13 | +template <typename T> struct UserDefinedMax { |
| 14 | + T operator()(T a, T b) { return (a < b) ? b : a; } |
| 15 | +}; |
| 16 | + |
| 17 | +template <typename T> struct UserDefinedBitAnd { |
| 18 | + T operator()(const T &a, const T &b) const { return a & b; } |
| 19 | +}; |
| 20 | + |
| 21 | +template <typename T> struct UserDefinedMultiplies { |
| 22 | + T operator()(const T &a, const T &b) const { return a * b; } |
| 23 | +}; |
| 24 | + |
| 25 | +struct custom_type_nested { |
| 26 | + static constexpr int default_i_value = 6; |
| 27 | + static constexpr float default_f_value = 1.5; |
| 28 | + |
| 29 | + constexpr custom_type_nested() = default; |
| 30 | + constexpr custom_type_nested(int i, float f) : i(i), f(f) {} |
| 31 | + |
| 32 | + int i = default_i_value; |
| 33 | + float f = default_f_value; |
| 34 | +}; |
| 35 | + |
| 36 | +inline bool operator==(const custom_type_nested &lhs, |
| 37 | + const custom_type_nested &rhs) { |
| 38 | + return lhs.i == rhs.i && lhs.f == rhs.f; |
| 39 | +} |
| 40 | + |
| 41 | +inline std::ostream &operator<<(std::ostream &out, |
| 42 | + const custom_type_nested &v) { |
| 43 | + return out << "custom_type_nested { .i = " << v.i << ", .f = " << v.f << "}"; |
| 44 | +} |
| 45 | + |
| 46 | +struct custom_type { |
| 47 | + static constexpr unsigned long long default_ull_value = 42; |
| 48 | + |
| 49 | + constexpr custom_type() = default; |
| 50 | + constexpr custom_type(int i, float f, unsigned long long ull) |
| 51 | + : n(i, f), ull(ull) {} |
| 52 | + |
| 53 | + custom_type_nested n; |
| 54 | + unsigned long long ull = default_ull_value; |
| 55 | +}; |
| 56 | + |
| 57 | +inline bool operator==(const custom_type &lhs, const custom_type &rhs) { |
| 58 | + return lhs.n == rhs.n && lhs.ull == rhs.ull; |
| 59 | +} |
| 60 | + |
| 61 | +inline custom_type operator+(const custom_type &lhs, const custom_type &rhs) { |
| 62 | + return custom_type(lhs.n.i + rhs.n.i, lhs.n.f + rhs.n.f, lhs.ull + rhs.ull); |
| 63 | +} |
| 64 | + |
| 65 | +struct custom_type_wo_default_ctor { |
| 66 | + static constexpr unsigned long long default_ull_value = 42; |
| 67 | + |
| 68 | + constexpr custom_type_wo_default_ctor() = delete; |
| 69 | + constexpr custom_type_wo_default_ctor(int i, float f, unsigned long long ull) |
| 70 | + : n(i, f), ull(ull) {} |
| 71 | + |
| 72 | + custom_type_nested n; |
| 73 | + unsigned long long ull = default_ull_value; |
| 74 | +}; |
| 75 | + |
| 76 | +inline bool operator==(const custom_type_wo_default_ctor &lhs, |
| 77 | + const custom_type_wo_default_ctor &rhs) { |
| 78 | + return lhs.n == rhs.n && lhs.ull == rhs.ull; |
| 79 | +} |
| 80 | + |
| 81 | +inline custom_type_wo_default_ctor |
| 82 | +operator+(const custom_type_wo_default_ctor &lhs, |
| 83 | + const custom_type_wo_default_ctor &rhs) { |
| 84 | + return custom_type_wo_default_ctor(lhs.n.i + rhs.n.i, lhs.n.f + rhs.n.f, |
| 85 | + lhs.ull + rhs.ull); |
| 86 | +} |
| 87 | + |
| 88 | +template <typename T, std::size_t... Is> |
| 89 | +constexpr std::array<T, sizeof...(Is)> init_array(T value, |
| 90 | + std::index_sequence<Is...>) { |
| 91 | + return {{(static_cast<void>(Is), value)...}}; |
| 92 | +} |
| 93 | + |
| 94 | +using namespace sycl; |
| 95 | + |
| 96 | +template <typename InputContainer, typename OutputContainer, |
| 97 | + class BinaryOperation> |
| 98 | +void test(queue q, InputContainer input, OutputContainer output, |
| 99 | + BinaryOperation binary_op, size_t workgroup_size, |
| 100 | + typename OutputContainer::value_type identity, |
| 101 | + typename OutputContainer::value_type init) { |
| 102 | + using InputT = typename InputContainer::value_type; |
| 103 | + using OutputT = typename OutputContainer::value_type; |
| 104 | + constexpr size_t N = input.size(); |
| 105 | + { |
| 106 | + buffer<InputT> in_buf(input.data(), input.size()); |
| 107 | + buffer<OutputT> out_buf(output.data(), output.size()); |
| 108 | + |
| 109 | + q.submit([&](handler &cgh) { |
| 110 | + accessor in{in_buf, cgh, sycl::read_only}; |
| 111 | + accessor out{out_buf, cgh, sycl::write_only, sycl::no_init}; |
| 112 | + |
| 113 | + size_t temp_memory_size = workgroup_size * sizeof(InputT); |
| 114 | + auto scratch = sycl::local_accessor<std::byte, 1>(temp_memory_size, cgh); |
| 115 | + cgh.parallel_for( |
| 116 | + nd_range<1>(workgroup_size, workgroup_size), [=](nd_item<1> it) { |
| 117 | + // Create a handle that associates the group with an allocation it |
| 118 | + // can use |
| 119 | + auto handle = |
| 120 | + sycl::ext::oneapi::experimental::group_with_scratchpad( |
| 121 | + it.get_group(), sycl::span(&scratch[0], temp_memory_size)); |
| 122 | + |
| 123 | + InputT *first = in.get_pointer(); |
| 124 | + InputT *last = first + N; |
| 125 | + // check reduce_over_group w/o init |
| 126 | + out[0] = sycl::ext::oneapi::experimental::reduce_over_group( |
| 127 | + handle, in[it.get_global_id(0)], binary_op); |
| 128 | + |
| 129 | + // check reduce_over_group with init |
| 130 | + out[1] = sycl::ext::oneapi::experimental::reduce_over_group( |
| 131 | + handle, in[it.get_global_id(0)], init, binary_op); |
| 132 | + |
| 133 | + // check joint_reduce w/o init |
| 134 | + out[2] = sycl::ext::oneapi::experimental::joint_reduce( |
| 135 | + handle, first, last, binary_op); |
| 136 | + |
| 137 | + // check joint_reduce with init |
| 138 | + out[3] = sycl::ext::oneapi::experimental::joint_reduce( |
| 139 | + handle, first, last, init, binary_op); |
| 140 | + }); |
| 141 | + }); |
| 142 | + q.wait(); |
| 143 | + } |
| 144 | + assert(output[0] == std::reduce(input.begin(), input.begin() + workgroup_size, |
| 145 | + identity, binary_op)); |
| 146 | + assert(output[1] == std::reduce(input.begin(), input.begin() + workgroup_size, |
| 147 | + init, binary_op)); |
| 148 | + assert(output[2] == |
| 149 | + std::reduce(input.begin(), input.end(), identity, binary_op)); |
| 150 | + assert(output[3] == std::reduce(input.begin(), input.end(), init, binary_op)); |
| 151 | +} |
| 152 | + |
| 153 | +int main() { |
| 154 | + queue q; |
| 155 | + |
| 156 | + constexpr int N = 128; |
| 157 | + std::array<int, N> input; |
| 158 | + std::array<int, 4> output; |
| 159 | + std::iota(input.begin(), input.end(), 0); |
| 160 | + std::fill(output.begin(), output.end(), 0); |
| 161 | + |
| 162 | + // queue, input array, output array, binary_op, WG size, identity, init |
| 163 | + test(q, input, output, UserDefinedSum<int>{}, 64, 0, 42); |
| 164 | + test(q, input, output, UserDefinedSum<int>{}, 32, 0, 42); |
| 165 | + test(q, input, output, UserDefinedSum<int>{}, 5, 0, 42); |
| 166 | + test(q, input, output, UserDefinedMax<int>{}, 64, |
| 167 | + std::numeric_limits<int>::lowest(), 42); |
| 168 | + test(q, input, output, UserDefinedMultiplies<int>(), 64, 1, 42); |
| 169 | + test(q, input, output, UserDefinedBitAnd<int>{}, 64, ~0, 42); |
| 170 | + |
| 171 | + test(q, input, output, sycl::plus<int>(), 64, 0, 42); |
| 172 | + test(q, input, output, sycl::maximum<>(), 64, |
| 173 | + std::numeric_limits<int>::lowest(), 42); |
| 174 | + test(q, input, output, sycl::minimum<int>(), 64, |
| 175 | + std::numeric_limits<int>::max(), 42); |
| 176 | + |
| 177 | + test(q, input, output, sycl::multiplies<int>(), 64, 1, 42); |
| 178 | + test(q, input, output, sycl::bit_or<int>(), 64, 0, 42); |
| 179 | + test(q, input, output, sycl::bit_xor<int>(), 64, 0, 42); |
| 180 | + test(q, input, output, sycl::bit_and<int>(), 64, ~0, 42); |
| 181 | + |
| 182 | + std::array<custom_type, N> input_custom; |
| 183 | + std::array<custom_type, 4> output_custom; |
| 184 | + test(q, input_custom, output_custom, UserDefinedSum<custom_type>{}, 64, |
| 185 | + custom_type(0, 0., 0), custom_type(10, 0., 5)); |
| 186 | + |
| 187 | + custom_type_wo_default_ctor value(1, 2.5, 3); |
| 188 | + std::array<custom_type_wo_default_ctor, N> input_custom_wo_default_ctor = |
| 189 | + init_array(value, std::make_index_sequence<N>()); |
| 190 | + std::array<custom_type_wo_default_ctor, 4> output_custom_wo_default_ctor = |
| 191 | + init_array(value, std::make_index_sequence<4>()); |
| 192 | + test(q, input_custom_wo_default_ctor, output_custom_wo_default_ctor, |
| 193 | + UserDefinedSum<custom_type_wo_default_ctor>{}, 64, |
| 194 | + custom_type_wo_default_ctor(0, 0., 0), |
| 195 | + custom_type_wo_default_ctor(10, 0., 5)); |
| 196 | + |
| 197 | +#ifdef SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS |
| 198 | + std::array<std::complex<float>, N> input_cf; |
| 199 | + std::array<std::complex<float>, 4> output_cf; |
| 200 | + std::iota(input_cf.begin(), input_cf.end(), 0); |
| 201 | + std::fill(output_cf.begin(), output_cf.end(), 0); |
| 202 | + test(q, input_cf, output_cf, sycl::plus<std::complex<float>>(), 64, 0, 42); |
| 203 | + test(q, input_cf, output_cf, sycl::plus<>(), 64, 0, 42); |
| 204 | +#else |
| 205 | + static_assert(false, "SYCL_EXT_ONEAPI_COMPLEX_ALGORITHMS not defined"); |
| 206 | +#endif |
| 207 | + |
| 208 | + return 0; |
| 209 | +} |
0 commit comments