@@ -158,16 +158,28 @@ void custom_max_c(void* array1_in, void* result1, const size_t* shape, size_t nd
158
158
size *= shape[i];
159
159
}
160
160
161
- auto policy = oneapi::dpl::execution::make_device_policy<class custom_max_c_kernel <_DataType>>(DPNP_QUEUE);
161
+ if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value)
162
+ {
163
+ // Required initializing the result before call the function
164
+ result[0 ] = array_1[0 ];
162
165
163
- _DataType* res = std::max_element (policy, array_1, array_1 + size);
164
- policy. queue (). wait ( );
166
+ // https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-make_dataset.html
167
+ auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>( 1 , size, array_1 );
165
168
166
- result[0 ] = *res;
169
+ // https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-max.html
170
+ cl::sycl::event event = mkl_stats::max (DPNP_QUEUE, dataset, result);
167
171
168
- #if 0
169
- std::cout << "max result " << result[0] << "\n";
170
- #endif
172
+ event.wait ();
173
+ }
174
+ else
175
+ {
176
+ auto policy = oneapi::dpl::execution::make_device_policy<class custom_max_c_kernel <_DataType>>(DPNP_QUEUE);
177
+
178
+ _DataType* res = std::max_element (policy, array_1, array_1 + size);
179
+ policy.queue ().wait ();
180
+
181
+ result[0 ] = *res;
182
+ }
171
183
}
172
184
173
185
template void custom_max_c<double >(
@@ -288,17 +300,28 @@ void custom_min_c(void* array1_in, void* result1, const size_t* shape, size_t nd
288
300
{
289
301
size *= shape[i];
290
302
}
303
+ if constexpr (std::is_same<_DataType, double >::value || std::is_same<_DataType, float >::value)
304
+ {
305
+ // Required initializing the result before call the function
306
+ result[0 ] = array_1[0 ];
307
+
308
+ // https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-make_dataset.html
309
+ auto dataset = mkl_stats::make_dataset<mkl_stats::layout::row_major>(1 , size, array_1);
291
310
292
- auto policy = oneapi::dpl::execution::make_device_policy<class custom_min_c_kernel <_DataType>>(DPNP_QUEUE);
311
+ // https://docs.oneapi.com/versions/latest/onemkl/mkl-stats-min.html
312
+ cl::sycl::event event = mkl_stats::min (DPNP_QUEUE, dataset, result);
293
313
294
- _DataType* res = std::min_element (policy, array_1, array_1 + size);
295
- policy.queue ().wait ();
314
+ event.wait ();
315
+ }
316
+ else
317
+ {
318
+ auto policy = oneapi::dpl::execution::make_device_policy<class custom_min_c_kernel <_DataType>>(DPNP_QUEUE);
296
319
297
- result[0 ] = *res;
320
+ _DataType* res = std::min_element (policy, array_1, array_1 + size);
321
+ policy.queue ().wait ();
298
322
299
- #if 0
300
- std::cout << "min result " << result[0] << "\n";
301
- #endif
323
+ result[0 ] = *res;
324
+ }
302
325
}
303
326
304
327
template void custom_min_c<double >(
0 commit comments