29
29
#include < backend_iface.hpp>
30
30
#include " queue_sycl.hpp"
31
31
32
+ namespace mkl_stats = oneapi::mkl::stats;
33
+
32
34
template <typename _KernelNameSpecialization>
33
35
class custom_sum_c_kernel ;
34
36
@@ -44,28 +46,32 @@ void custom_sum_c(void* array1_in, void* result1, size_t size)
44
46
_DataType* result = reinterpret_cast <_DataType*>(result1);
45
47
46
48
#if 1 // naive algorithm
47
- // cl::sycl::range<1> gws(size);
48
- auto policy = oneapi::dpl::execution::make_device_policy<custom_sum_c_kernel<_DataType>>(DPNP_QUEUE);
49
+ if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value)
50
+ {
51
+ // https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-make_dataset.html
52
+ auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1 , size, array_1);
49
53
50
- // sycl::buffer<_DataType, 1> array_1_buf(array_1, gws);
51
- // auto it_begin = oneapi::dpl::begin(array_1_buf);
52
- // auto it_end = oneapi::dpl::end(array_1_buf);
54
+ // https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-raw_sum.html
55
+ cl::sycl::event event = mkl_stats::raw_sum (DPNP_QUEUE, dataset, result);
53
56
54
- _DataType accumulator = 0 ;
55
- accumulator = std::reduce (policy, array_1, array_1 + size, _DataType (0 ), std::plus<_DataType>());
57
+ event.wait ();
58
+ }
59
+ else
60
+ {
61
+ // cl::sycl::range<1> gws(size);
62
+ auto policy = oneapi::dpl::execution::make_device_policy<custom_sum_c_kernel<_DataType>>(DPNP_QUEUE);
56
63
57
- policy.queue ().wait ();
64
+ // sycl::buffer<_DataType, 1> array_1_buf(array_1, gws);
65
+ // auto it_begin = oneapi::dpl::begin(array_1_buf);
66
+ // auto it_end = oneapi::dpl::end(array_1_buf);
58
67
59
- #if 0 // verification
60
- accumulator = 0;
61
- for (size_t i = 0; i < size; ++i)
62
- {
63
- accumulator += array_1[i];
64
- }
65
- // std::cout << "result: " << accumulator << std::endl;
66
- #endif
68
+ _DataType accumulator = 0 ;
69
+ accumulator = std::reduce (policy, array_1, array_1 + size, _DataType (0 ), std::plus<_DataType>());
67
70
68
- result[0 ] = accumulator;
71
+ policy.queue ().wait ();
72
+
73
+ result[0 ] = accumulator;
74
+ }
69
75
70
76
return ;
71
77
0 commit comments