Skip to content

Commit c9f1c2a

Browse files
authored
SUM add MKL kernel (#122)
* SUM add MKL kernel
1 parent ba95cbd commit c9f1c2a

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

dpnp/backend/custom_kernels_reduction.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include <backend_iface.hpp>
3030
#include "queue_sycl.hpp"
3131

32+
namespace mkl_stats = oneapi::mkl::stats;
33+
3234
template <typename _KernelNameSpecialization>
3335
class custom_sum_c_kernel;
3436

@@ -44,28 +46,32 @@ void custom_sum_c(void* array1_in, void* result1, size_t size)
4446
_DataType* result = reinterpret_cast<_DataType*>(result1);
4547

4648
#if 1 // naive algorithm
47-
// cl::sycl::range<1> gws(size);
48-
auto policy = oneapi::dpl::execution::make_device_policy<custom_sum_c_kernel<_DataType>>(DPNP_QUEUE);
49+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
50+
{
51+
// https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-make_dataset.html
52+
auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1, size, array_1);
4953

50-
// sycl::buffer<_DataType, 1> array_1_buf(array_1, gws);
51-
// auto it_begin = oneapi::dpl::begin(array_1_buf);
52-
// auto it_end = oneapi::dpl::end(array_1_buf);
54+
// https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-raw_sum.html
55+
cl::sycl::event event = mkl_stats::raw_sum(DPNP_QUEUE, dataset, result);
5356

54-
_DataType accumulator = 0;
55-
accumulator = std::reduce(policy, array_1, array_1 + size, _DataType(0), std::plus<_DataType>());
57+
event.wait();
58+
}
59+
else
60+
{
61+
// cl::sycl::range<1> gws(size);
62+
auto policy = oneapi::dpl::execution::make_device_policy<custom_sum_c_kernel<_DataType>>(DPNP_QUEUE);
5663

57-
policy.queue().wait();
64+
// sycl::buffer<_DataType, 1> array_1_buf(array_1, gws);
65+
// auto it_begin = oneapi::dpl::begin(array_1_buf);
66+
// auto it_end = oneapi::dpl::end(array_1_buf);
5867

59-
#if 0 // verification
60-
accumulator = 0;
61-
for (size_t i = 0; i < size; ++i)
62-
{
63-
accumulator += array_1[i];
64-
}
65-
// std::cout << "result: " << accumulator << std::endl;
66-
#endif
68+
_DataType accumulator = 0;
69+
accumulator = std::reduce(policy, array_1, array_1 + size, _DataType(0), std::plus<_DataType>());
6770

68-
result[0] = accumulator;
71+
policy.queue().wait();
72+
73+
result[0] = accumulator;
74+
}
6975

7076
return;
7177

0 commit comments

Comments
 (0)