|
10 | 10 | //
|
11 | 11 | //===----------------------------------------------------------------------===//
|
12 | 12 |
|
13 |
| -#include "vec_convert.hpp" |
| 13 | +#include <CL/sycl.hpp> |
| 14 | + |
| 15 | +#include <cassert> |
14 | 16 |
|
15 | 17 | // TODO make the convertion on CPU and HOST identical
|
16 | 18 |
|
| 19 | +using namespace cl::sycl; |
| 20 | + |
| 21 | +template <typename T, typename convertT, int roundingMode> |
| 22 | +class kernel_name; |
| 23 | + |
| 24 | +template <int N> |
| 25 | +struct helper; |
| 26 | + |
| 27 | +template <> |
| 28 | +struct helper<0> { |
| 29 | + template <typename T, int NumElements> |
| 30 | + static void compare(const vec<T, NumElements> &x, |
| 31 | + const vec<T, NumElements> &y) { |
| 32 | + const T xs = x.template swizzle<0>(); |
| 33 | + const T ys = y.template swizzle<0>(); |
| 34 | + assert(xs == ys); |
| 35 | + } |
| 36 | +}; |
| 37 | + |
| 38 | +template <int N> |
| 39 | +struct helper { |
| 40 | + template <typename T, int NumElements> |
| 41 | + static void compare(const vec<T, NumElements> &x, |
| 42 | + const vec<T, NumElements> &y) { |
| 43 | + const T xs = x.template swizzle<N>(); |
| 44 | + const T ys = y.template swizzle<N>(); |
| 45 | + helper<N - 1>::compare(x, y); |
| 46 | + assert(xs == ys); |
| 47 | + } |
| 48 | +}; |
| 49 | + |
| 50 | +template <typename T, typename convertT, int NumElements, |
| 51 | + rounding_mode roundingMode> |
| 52 | +void test(const vec<T, NumElements> &ToConvert, |
| 53 | + const vec<convertT, NumElements> &Expected) { |
| 54 | + vec<convertT, NumElements> Converted{0}; |
| 55 | + { |
| 56 | + buffer<vec<convertT, NumElements>, 1> Buffer{&Converted, range<1>{1}}; |
| 57 | + queue Queue; |
| 58 | + |
| 59 | + cl::sycl::device D = Queue.get_device(); |
| 60 | + if (!D.has_extension("cl_khr_fp16")) |
| 61 | + exit(0); |
| 62 | + |
| 63 | + Queue.submit([&](handler &CGH) { |
| 64 | + accessor<vec<convertT, NumElements>, 1, access::mode::write> Accessor( |
| 65 | + Buffer, CGH); |
| 66 | + CGH.single_task<class kernel_name<T, convertT, static_cast<int>(roundingMode)>>([=]() { |
| 67 | + Accessor[0] = ToConvert.template convert<convertT, roundingMode>(); |
| 68 | + }); |
| 69 | + }); |
| 70 | + } |
| 71 | + helper<NumElements - 1>::compare(Converted, Expected); |
| 72 | +} |
| 73 | + |
17 | 74 | int main() {
|
18 | 75 | //automatic
|
19 | 76 | test<double, half, 4, rounding_mode::automatic>(
|
|
0 commit comments