Skip to content

Commit ba95cbd

Browse files
authored
MIN and MAX add MKL kernel (#124)
1 parent 04a1161 commit ba95cbd

File tree

1 file changed

+37
-14
lines changed

1 file changed

+37
-14
lines changed

dpnp/backend/custom_kernels_statistics.cpp

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,28 @@ void custom_max_c(void* array1_in, void* result1, const size_t* shape, size_t nd
158158
size *= shape[i];
159159
}
160160

161-
auto policy = oneapi::dpl::execution::make_device_policy<class custom_max_c_kernel<_DataType>>(DPNP_QUEUE);
161+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
162+
{
163+
// Required initializing the result before call the function
164+
result[0] = array_1[0];
162165

163-
_DataType* res = std::max_element(policy, array_1, array_1 + size);
164-
policy.queue().wait();
166+
// https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-make_dataset.html
167+
auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1, size, array_1);
165168

166-
result[0] = *res;
169+
// https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-max.html
170+
cl::sycl::event event = mkl_stats::max(DPNP_QUEUE, dataset, result);
167171

168-
#if 0
169-
std::cout << "max result " << result[0] << "\n";
170-
#endif
172+
event.wait();
173+
}
174+
else
175+
{
176+
auto policy = oneapi::dpl::execution::make_device_policy<class custom_max_c_kernel<_DataType>>(DPNP_QUEUE);
177+
178+
_DataType* res = std::max_element(policy, array_1, array_1 + size);
179+
policy.queue().wait();
180+
181+
result[0] = *res;
182+
}
171183
}
172184

173185
template void custom_max_c<double>(
@@ -288,17 +300,28 @@ void custom_min_c(void* array1_in, void* result1, const size_t* shape, size_t nd
288300
{
289301
size *= shape[i];
290302
}
303+
if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
304+
{
305+
// Required initializing the result before call the function
306+
result[0] = array_1[0];
307+
308+
// https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-make_dataset.html
309+
auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1, size, array_1);
291310

292-
auto policy = oneapi::dpl::execution::make_device_policy<class custom_min_c_kernel<_DataType>>(DPNP_QUEUE);
311+
// https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-min.html
312+
cl::sycl::event event = mkl_stats::min(DPNP_QUEUE, dataset, result);
293313

294-
_DataType* res = std::min_element(policy, array_1, array_1 + size);
295-
policy.queue().wait();
314+
event.wait();
315+
}
316+
else
317+
{
318+
auto policy = oneapi::dpl::execution::make_device_policy<class custom_min_c_kernel<_DataType>>(DPNP_QUEUE);
296319

297-
result[0] = *res;
320+
_DataType* res = std::min_element(policy, array_1, array_1 + size);
321+
policy.queue().wait();
298322

299-
#if 0
300-
std::cout << "min result " << result[0] << "\n";
301-
#endif
323+
result[0] = *res;
324+
}
302325
}
303326

304327
template void custom_min_c<double>(

0 commit comments

Comments
 (0)