Skip to content

Commit beae23d

Browse files
authored
MEAN add MKL kernel (#119)
* MEAN add MKL kernel
1 parent 13de40d commit beae23d

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

dpnp/backend/custom_kernels_statistics.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525

2626
#include <iostream>
2727
#include <mkl_blas_sycl.hpp>
28+
#include <mkl_stats_sycl.hpp>
2829

2930
#include <backend_iface.hpp>
3031
#include "backend_pstl.hpp"
3132
#include "backend_utils.hpp"
3233
#include "queue_sycl.hpp"
3334

3435
namespace mkl_blas = oneapi::mkl::blas::row_major;
36+
namespace mkl_stats = oneapi::mkl::stats;
3537

3638
template <typename _DataType>
3739
class custom_cov_c_kernel;
@@ -198,17 +200,28 @@ void custom_mean_c(void* array1_in, void* result1, const size_t* shape, size_t n
198200
return;
199201
}
200202

201-
_DataType* sum = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(1 * sizeof(_DataType)));
203+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
204+
{
205+
_ResultType* array = reinterpret_cast<_DataType*>(array1_in);
202206

203-
custom_sum_c<_DataType>(array1_in, sum, size);
207+
// https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-make_dataset.html
208+
auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1, size, array);
204209

205-
result[0] = static_cast<_ResultType>(sum[0]) / static_cast<_ResultType>(size);
210+
// https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-mean.html
211+
cl::sycl::event event = mkl_stats::mean(DPNP_QUEUE, dataset, result);
206212

207-
dpnp_memory_free_c(sum);
213+
event.wait();
214+
}
215+
else
216+
{
217+
_DataType* sum = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(1 * sizeof(_DataType)));
208218

209-
#if 0
210-
std::cout << "mean result " << result[0] << "\n";
211-
#endif
219+
custom_sum_c<_DataType>(array1_in, sum, size);
220+
221+
result[0] = static_cast<_ResultType>(sum[0]) / static_cast<_ResultType>(size);
222+
223+
dpnp_memory_free_c(sum);
224+
}
212225
}
213226

214227
template void custom_mean_c<double, double>(

0 commit comments

Comments
 (0)