|
25 | 25 |
|
26 | 26 | #include <iostream>
|
27 | 27 | #include <mkl_blas_sycl.hpp>
|
| 28 | +#include <mkl_stats_sycl.hpp> |
28 | 29 |
|
29 | 30 | #include <backend_iface.hpp>
|
30 | 31 | #include "backend_pstl.hpp"
|
31 | 32 | #include "backend_utils.hpp"
|
32 | 33 | #include "queue_sycl.hpp"
|
33 | 34 |
|
34 | 35 | namespace mkl_blas = oneapi::mkl::blas::row_major;
|
| 36 | +namespace mkl_stats = oneapi::mkl::stats; |
35 | 37 |
|
36 | 38 | template <typename _DataType>
|
37 | 39 | class custom_cov_c_kernel;
|
@@ -198,17 +200,28 @@ void custom_mean_c(void* array1_in, void* result1, const size_t* shape, size_t n
|
198 | 200 | return;
|
199 | 201 | }
|
200 | 202 |
|
201 |
| - _DataType* sum = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(1 * sizeof(_DataType))); |
| 203 | + if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value) |
| 204 | + { |
| 205 | + _ResultType* array = reinterpret_cast<_DataType*>(array1_in); |
202 | 206 |
|
203 |
| - custom_sum_c<_DataType>(array1_in, sum, size); |
| 207 | + // https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-make_dataset.html |
| 208 | + auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1, size, array); |
204 | 209 |
|
205 |
| - result[0] = static_cast<_ResultType>(sum[0]) / static_cast<_ResultType>(size); |
| 210 | + // https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-mean.html |
| 211 | + cl::sycl::event event = mkl_stats::mean(DPNP_QUEUE, dataset, result); |
206 | 212 |
|
207 |
| - dpnp_memory_free_c(sum); |
| 213 | + event.wait(); |
| 214 | + } |
| 215 | + else |
| 216 | + { |
| 217 | + _DataType* sum = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(1 * sizeof(_DataType))); |
208 | 218 |
|
209 |
| -#if 0 |
210 |
| - std::cout << "mean result " << result[0] << "\n"; |
211 |
| -#endif |
| 219 | + custom_sum_c<_DataType>(array1_in, sum, size); |
| 220 | + |
| 221 | + result[0] = static_cast<_ResultType>(sum[0]) / static_cast<_ResultType>(size); |
| 222 | + |
| 223 | + dpnp_memory_free_c(sum); |
| 224 | + } |
212 | 225 | }
|
213 | 226 |
|
214 | 227 | template void custom_mean_c<double, double>(
|
|
0 commit comments